vmbus_client/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod hvsock;
8pub mod saved_state;
9
10pub use self::saved_state::SavedState;
11use anyhow::Context as _;
12use anyhow::Result;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures::future::OptionFuture;
16use futures::stream::SelectAll;
17use futures_concurrency::future::Race;
18use guid::Guid;
19use inspect::Inspect;
20use mesh::rpc::FailableRpc;
21use mesh::rpc::Rpc;
22use mesh::rpc::RpcSend;
23use pal_async::task::Spawn;
24use pal_async::task::Task;
25use pal_event::Event;
26use std::collections::HashMap;
27use std::collections::VecDeque;
28use std::collections::hash_map;
29use std::convert::TryInto;
30use std::future::Future;
31use std::future::poll_fn;
32use std::ops::Deref;
33use std::ops::DerefMut;
34use std::pin::pin;
35use std::sync::Arc;
36use std::task::Context;
37use std::task::Poll;
38use thiserror::Error;
39use vmbus_async::async_dgram::AsyncRecv;
40use vmbus_async::async_dgram::AsyncRecvExt;
41use vmbus_channel::bus::GpadlRequest;
42use vmbus_channel::bus::ModifyRequest;
43use vmbus_channel::bus::OpenData;
44use vmbus_channel::gpadl::GpadlId;
45use vmbus_core::HvsockConnectRequest;
46use vmbus_core::OutgoingMessage;
47use vmbus_core::TaggedStream;
48use vmbus_core::VersionInfo;
49use vmbus_core::protocol;
50use vmbus_core::protocol::ChannelId;
51use vmbus_core::protocol::ConnectionState;
52use vmbus_core::protocol::FeatureFlags;
53use vmbus_core::protocol::Message;
54use vmbus_core::protocol::OpenChannelFlags;
55use vmbus_core::protocol::Version;
56use vmcore::interrupt::Interrupt;
57use vmcore::synic::MonitorPageGpas;
58use zerocopy::Immutable;
59use zerocopy::IntoBytes;
60use zerocopy::KnownLayout;
61
/// Synthetic interrupt source (SINT) advertised to the server in the
/// `InitiateContact` target info.
const SINT: u8 = 2;
/// VTL advertised to the server in the `InitiateContact` target info.
const VTL: u8 = 0;
/// Protocol versions the client can negotiate, ordered oldest to newest.
///
/// Connecting starts with the last (newest) entry. Each time the server
/// rejects a version, the client retries with the preceding entry until the
/// list is exhausted.
const SUPPORTED_VERSIONS: &[Version] = &[Version::Iron, Version::Copper];
/// Feature flags requested from the server. They are only sent when
/// negotiating `Version::Copper` or newer; older versions request no features.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true);
71
/// The client interface for synic events.
///
/// Held by the client task's synic state and shared via `Arc`, hence the
/// `Send + Sync` bound.
pub trait SynicEventClient: Send + Sync {
    /// Maps an incoming event signal on SINT7 to `event`.
    // NOTE(review): the message SINT used by this client is `SINT` (2); confirm
    // that SINT7 is really the SINT on which incoming events arrive.
    fn map_event(&self, event_flag: u16, event: &Event) -> std::io::Result<()>;

    /// Unmaps an event previously mapped with `map_event`.
    fn unmap_event(&self, event_flag: u16);

    /// Signals an event on the synic for `connection_id`, using `event_flag`.
    fn signal_event(&self, connection_id: u32, event_flag: u16) -> std::io::Result<()>;
}
83
/// A stream of vmbus messages that can be paused and resumed.
///
/// Both control methods default to doing nothing, for sources that cannot be
/// paused.
pub trait VmbusMessageSource: AsyncRecv + Send {
    /// Stop accepting new messages from the synic. After this is called, the message source must
    /// return any pending messages already in the queue, and then return EOF.
    fn pause_message_stream(&mut self) {}

    /// Resume accepting new messages from the synic.
    fn resume_message_stream(&mut self) {}
}
93
/// Posts outgoing vmbus messages on behalf of the client.
pub trait PollPostMessage: Send {
    /// Attempts to post a message of type `typ` with payload `msg` to
    /// `connection_id`.
    ///
    /// Returns `Poll::Ready(())` once the message has been posted.
    // NOTE(review): presumably returns `Poll::Pending` and registers `cx`'s
    // waker when the message cannot be posted yet; confirm against
    // implementations.
    fn poll_post_message(
        &mut self,
        cx: &mut Context<'_>,
        connection_id: u32,
        typ: u32,
        msg: &[u8],
    ) -> Poll<()>;
}
103
/// A vmbus client, created by [`VmbusClientBuilder::build`].
///
/// Owns the background task that runs the client state machine. Requests are
/// forwarded to that task over channels.
pub struct VmbusClient {
    /// Task-level requests (start, stop, save, restore, inspect).
    task_send: mesh::Sender<TaskRequest>,
    /// Cloneable handle for client-level requests (connect, unload, modify, hvsock).
    access: VmbusClientAccess,
    /// The spawned client task; it yields its `ClientTask` state when it finishes.
    task: Task<ClientTask>,
}
109
/// Errors returned by [`VmbusClient::connect`].
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// The client was not in the disconnected state.
    #[error("invalid state to connect to the server")]
    InvalidState,
    /// The server rejected every version in the supported version list.
    #[error("no supported protocol versions")]
    NoSupportedVersions,
    /// The server accepted the requested version but reported an unsuccessful
    /// connection state.
    #[error("failed to connect to the server: {0:?}")]
    FailedToConnect(ConnectionState),
}
119
/// A cloneable handle for sending client-level requests (such as connection
/// modification and hvsock connections) to the client task.
#[derive(Clone)]
pub struct VmbusClientAccess {
    /// Sends requests to the client task.
    client_request_send: mesh::Sender<ClientRequest>,
}
124
/// A builder for creating a [`VmbusClient`].
///
/// Holds the components the client task needs. A builder with the same
/// components is reconstructed when the client task is severed.
pub struct VmbusClientBuilder {
    /// Maps, unmaps, and signals synic events.
    event_client: Arc<dyn SynicEventClient>,
    /// Source of incoming vmbus messages.
    msg_source: Box<dyn VmbusMessageSource>,
    /// Poster for outgoing vmbus messages.
    msg_client: Box<dyn PollPostMessage>,
}
131
132impl VmbusClientBuilder {
133    /// Creates a new instance of the builder with the given synic input.
134    pub fn new(
135        event_client: impl SynicEventClient + 'static,
136        msg_source: impl VmbusMessageSource + 'static,
137        msg_client: impl PollPostMessage + 'static,
138    ) -> Self {
139        Self {
140            event_client: Arc::new(event_client),
141            msg_source: Box::new(msg_source),
142            msg_client: Box::new(msg_client),
143        }
144    }
145
146    /// Creates a new instance with a receiver for incoming synic messages.
147    pub fn build(self, spawner: &impl Spawn) -> VmbusClient {
148        let (task_send, task_recv) = mesh::channel();
149        let (client_request_send, client_request_recv) = mesh::channel();
150
151        let inner = ClientTaskInner {
152            messages: OutgoingMessages {
153                poster: self.msg_client,
154                queued: VecDeque::new(),
155                state: OutgoingMessageState::Paused,
156            },
157            teardown_gpadls: HashMap::new(),
158            channel_requests: SelectAll::new(),
159            synic: SynicState {
160                event_flag_state: Vec::new(),
161                event_client: self.event_client,
162            },
163        };
164
165        let mut task = ClientTask {
166            inner,
167            channels: ChannelList::default(),
168            task_recv,
169            running: false,
170            msg_source: self.msg_source,
171            client_request_recv,
172            state: ClientState::Disconnected,
173            modify_request: None,
174            hvsock_tracker: hvsock::HvsockRequestTracker::new(),
175        };
176
177        let task = spawner.spawn("vmbus client", async move {
178            task.run().await;
179            task
180        });
181
182        VmbusClient {
183            access: VmbusClientAccess {
184                client_request_send,
185            },
186            task_send,
187            task,
188        }
189    }
190}
191
192impl VmbusClient {
193    /// Connects to the server, negotiating the protocol version and retrieving
194    /// the initial list of channel offers.
195    pub async fn connect(
196        &mut self,
197        target_message_vp: u32,
198        monitor_page: Option<MonitorPageGpas>,
199        client_id: Guid,
200    ) -> Result<ConnectResult, ConnectError> {
201        let request = ConnectRequest {
202            target_message_vp,
203            monitor_page,
204            client_id,
205        };
206
207        self.access
208            .client_request_send
209            .call(ClientRequest::Connect, request)
210            .await
211            .unwrap()
212    }
213
214    pub async fn unload(self) {
215        self.access
216            .client_request_send
217            .call(ClientRequest::Unload, ())
218            .await
219            .unwrap();
220
221        self.sever().await;
222    }
223
224    pub fn access(&self) -> &VmbusClientAccess {
225        &self.access
226    }
227
228    pub fn start(&mut self) {
229        self.task_send.send(TaskRequest::Start);
230    }
231
232    pub async fn stop(&mut self) {
233        self.task_send
234            .call(TaskRequest::Stop, ())
235            .await
236            .expect("Failed to send stop request");
237    }
238
239    pub async fn save(&self) -> SavedState {
240        self.task_send
241            .call(TaskRequest::Save, ())
242            .await
243            .expect("Failed to send save request")
244    }
245
246    pub async fn restore(
247        &mut self,
248        state: SavedState,
249    ) -> Result<Option<ConnectResult>, RestoreError> {
250        self.task_send
251            .call(TaskRequest::Restore, state)
252            .await
253            .expect("Failed to send restore request")
254    }
255
256    pub async fn post_restore(&mut self) {
257        self.task_send
258            .call(TaskRequest::PostRestore, ())
259            .await
260            .expect("Failed to send post-restore request");
261    }
262
263    async fn sever(self) -> VmbusClientBuilder {
264        drop(self.task_send);
265        let task = self.task.await;
266        VmbusClientBuilder {
267            event_client: task.inner.synic.event_client,
268            msg_source: task.msg_source,
269            msg_client: task.inner.messages.poster,
270        }
271    }
272}
273
274impl Inspect for VmbusClient {
275    fn inspect(&self, req: inspect::Request<'_>) {
276        self.task_send.send(TaskRequest::Inspect(req.defer()));
277    }
278}
279
/// The outcome of a successful connection to the server.
#[derive(Debug)]
pub struct ConnectResult {
    /// The negotiated protocol version and feature flags.
    pub version: VersionInfo,
    /// Offers received while the client was requesting the initial offers.
    pub offers: Vec<OfferInfo>,
    /// Receives offers that arrive after the initial set has been delivered.
    pub offer_recv: mesh::Receiver<OfferInfo>,
}
286
287impl VmbusClientAccess {
288    pub async fn modify(&self, request: ModifyConnectionRequest) -> ConnectionState {
289        self.client_request_send
290            .call(ClientRequest::Modify, request)
291            .await
292            .expect("Failed to send modify request")
293    }
294
295    pub fn connect_hvsock(
296        &self,
297        request: HvsockConnectRequest,
298    ) -> impl Future<Output = Option<OfferInfo>> + use<> {
299        self.client_request_send
300            .call(ClientRequest::HvsockConnect, request)
301            .map(|r| r.ok().flatten())
302    }
303}
304
/// Parameters for [`ChannelRequest::Open`].
#[derive(Debug)]
pub struct OpenRequest {
    /// The channel open parameters.
    pub open_data: OpenData,
    /// Optional event for incoming channel interrupts.
    // NOTE(review): presumably tied to channel interrupt redirection — confirm.
    pub incoming_event: Option<Event>,
    /// Whether to use the VTL2 connection id.
    // NOTE(review): exact semantics are handled outside this view — confirm.
    pub use_vtl2_connection_id: bool,
}
311
/// Parameters for [`ChannelRequest::Restore`], used to reclaim a channel that
/// was open before a save/restore.
#[derive(Debug)]
pub struct RestoreRequest {
    /// Optional event for incoming channel interrupts.
    pub incoming_event: Option<Event>,
    // FUTURE: move to saved state, don't rely on the caller.
    /// The redirected event flag the channel was using, if any.
    pub redirected_event_flag: Option<u16>,
    // FUTURE: ditto
    /// The connection id the channel was using.
    pub connection_id: u32,
}
320
/// Expresses an operation requested of the client for a single channel, sent
/// through [`OfferInfo::request_send`].
pub enum ChannelRequest {
    /// Open the channel; on success yields an [`OpenOutput`].
    Open(FailableRpc<OpenRequest, OpenOutput>),
    /// Reclaim a restored channel; on success yields an [`OpenOutput`].
    Restore(FailableRpc<RestoreRequest, OpenOutput>),
    /// Close the channel.
    Close(Rpc<(), ()>),
    /// Create a GPADL; completes when the server reports the creation result.
    Gpadl(FailableRpc<GpadlRequest, ()>),
    /// Tear down a GPADL.
    TeardownGpadl(Rpc<GpadlId, ()>),
    /// Modify the channel; completes with the status reported for the request.
    Modify(Rpc<ModifyRequest, i32>),
}
330
/// The result of successfully opening (or restoring) a channel.
#[derive(Debug)]
pub struct OpenOutput {
    // FUTURE: remove this once it's part of the saved state.
    /// The redirected event flag allocated for the channel, if any.
    pub redirected_event_flag: Option<u16>,
    /// Interrupt that signals the server for this channel's connection id.
    pub guest_to_host_signal: Interrupt,
}
337
338impl std::fmt::Display for ChannelRequest {
339    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340        let s = match self {
341            ChannelRequest::Open(_) => "Open",
342            ChannelRequest::Close(_) => "Close",
343            ChannelRequest::Restore(_) => "Restore",
344            ChannelRequest::Gpadl(_) => "Gpadl",
345            ChannelRequest::TeardownGpadl(_) => "TeardownGpadl",
346            ChannelRequest::Modify(_) => "Modify",
347        };
348        fmt.pad(s)
349    }
350}
351
/// Errors returned by [`VmbusClient::restore`].
#[derive(Debug, Error)]
pub enum RestoreError {
    /// The saved protocol version is not supported.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// The saved feature flags are not supported.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    /// The saved state lists the same channel id more than once.
    #[error("duplicate channel id {0}")]
    DuplicateChannelId(u32),

    /// The saved state lists the same GPADL id more than once.
    #[error("duplicate gpadl id {0}")]
    DuplicateGpadlId(u32),

    /// A saved GPADL refers to a channel that is not in the saved state.
    #[error("gpadl for unknown channel id {0}")]
    GpadlForUnknownChannelId(u32),

    /// A saved pending message is too large to be a valid message.
    #[error("invalid pending message")]
    InvalidPendingMessage(#[source] vmbus_core::MessageTooLarge),

    /// A restored channel could not be offered.
    #[error("failed to offer restored channel")]
    OfferFailed(#[source] anyhow::Error),
}
375
/// A channel offer, as delivered to the user of the client.
///
/// Contains the offer details from the server, a sender for issuing
/// [`ChannelRequest`]s for the channel, and a oneshot receiver that is
/// signaled when the server rescinds the offer.
#[derive(Debug, Inspect)]
pub struct OfferInfo {
    /// The offer as received from the server.
    pub offer: protocol::OfferChannel,
    /// Sends requests for this channel to the client task.
    #[inspect(skip)]
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Signaled when the server revokes (rescinds) the channel.
    #[inspect(skip)]
    pub revoke_recv: mesh::OneshotReceiver<()>,
}
386
/// Client-level requests sent to the client task through
/// [`VmbusClientAccess`].
#[derive(Debug)]
enum ClientRequest {
    /// Negotiate a version with the server and retrieve the initial offers.
    Connect(Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>),
    /// Unload from the server.
    Unload(Rpc<(), ()>),
    /// Modify connection-wide parameters.
    Modify(Rpc<ModifyConnectionRequest, ConnectionState>),
    /// Request an hvsock connection; completes with the matching offer, if any.
    HvsockConnect(Rpc<HvsockConnectRequest, Option<OfferInfo>>),
}
394
395impl std::fmt::Display for ClientRequest {
396    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
397        let s = match self {
398            ClientRequest::Connect(..) => "Connect",
399            ClientRequest::Unload { .. } => "Unload",
400            ClientRequest::Modify(..) => "Modify",
401            ClientRequest::HvsockConnect(..) => "HvsockConnect",
402        };
403        fmt.pad(s)
404    }
405}
406
/// Task-level requests sent to the client task by [`VmbusClient`].
enum TaskRequest {
    /// Respond to an inspection request.
    Inspect(inspect::Deferred),
    /// Capture the client's state.
    Save(Rpc<(), SavedState>),
    /// Restore from saved state.
    // NOTE(review): presumably yields `None` when the saved state was not
    // connected — confirm in the handler.
    Restore(Rpc<SavedState, Result<Option<ConnectResult>, RestoreError>>),
    /// Finish a restore.
    PostRestore(Rpc<(), ()>),
    /// Start the task.
    Start,
    /// Stop the task, acknowledging once stopped.
    Stop(Rpc<(), ()>),
}
415
/// The overall state machine used to drive which actions the client can legally
/// take. This primarily pertains to overall client activity but has a
/// side-effect of limiting whether or not channels can perform actions.
///
/// A successful connection moves through `Disconnected` -> `Connecting` ->
/// `RequestingOffers` -> `Connected`; an unload moves to `Disconnecting`.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ClientState {
    /// The client has yet to connect to the server.
    Disconnected,
    /// The client has sent `InitiateContact` for `version` and awaits the
    /// server's version response.
    Connecting {
        version: Version,
        #[inspect(skip)]
        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
    },
    /// The client has negotiated the protocol version with the server and all
    /// initial offers have been delivered; later offers go to `offer_send`.
    Connected {
        version: VersionInfo,
        #[inspect(skip)]
        offer_send: mesh::Sender<OfferInfo>,
    },
    /// The version was accepted and the client has requested offers from the
    /// server; offers accumulate in `offers` until delivery completes.
    RequestingOffers {
        version: VersionInfo,
        #[inspect(skip)]
        rpc: Rpc<(), Result<ConnectResult, ConnectError>>,
        #[inspect(skip)]
        offers: Vec<OfferInfo>,
    },
    /// The client has initiated an unload from the server.
    Disconnecting {
        version: VersionInfo,
        #[inspect(skip)]
        rpc: Rpc<(), ()>,
    },
}
451
452impl ClientState {
453    fn get_version(&self) -> Option<VersionInfo> {
454        match self {
455            ClientState::Connected { version, .. } => Some(*version),
456            ClientState::RequestingOffers { version, .. } => Some(*version),
457            ClientState::Disconnecting { version, .. } => Some(*version),
458            ClientState::Disconnected | ClientState::Connecting { .. } => None,
459        }
460    }
461}
462
463impl std::fmt::Display for ClientState {
464    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
465        let s = match self {
466            ClientState::Disconnected => "Disconnected",
467            ClientState::Connecting { .. } => "Connecting",
468            ClientState::Connected { .. } => "Connected",
469            ClientState::RequestingOffers { .. } => "RequestingOffers",
470            ClientState::Disconnecting { .. } => "Disconnecting",
471        };
472        fmt.pad(s)
473    }
474}
475
/// Parameters for [`ClientRequest::Connect`], placed in the `InitiateContact` message.
#[derive(Copy, Clone, Debug, Default)]
struct ConnectRequest {
    /// VP the server should target with messages.
    target_message_vp: u32,
    /// Monitor page GPAs; `None` sends the default GPAs.
    monitor_page: Option<MonitorPageGpas>,
    /// Client id; only sent when negotiating Copper or newer.
    client_id: Guid,
}
482
/// Parameters for [`VmbusClientAccess::modify`].
#[derive(Copy, Clone, Debug, Default)]
pub struct ModifyConnectionRequest {
    /// New monitor page GPAs; `None` sends `MonitorPageGpas::default()`.
    pub monitor_page: Option<MonitorPageGpas>,
}
487
488impl From<ModifyConnectionRequest> for protocol::ModifyConnection {
489    fn from(value: ModifyConnectionRequest) -> Self {
490        let monitor_page = value.monitor_page.unwrap_or_default();
491
492        Self {
493            parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
494            child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
495        }
496    }
497}
498
/// The per-channel state, which dictates whether or not a channel can request
/// an Open/Close. GPADLs are created and torn down independently of this state
/// machine, so no state here is tied to GPADL actions.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum ChannelState {
    /// The channel has been offered to the client.
    Offered,
    /// The channel has requested the server to be opened and awaits the open result.
    Opening {
        connection_id: u32,
        redirected_event_flag: Option<u16>,
        #[inspect(skip)]
        redirected_event: Option<Event>,
        #[inspect(skip)]
        rpc: FailableRpc<(), OpenOutput>,
    },
    /// The channel has been restored but not claimed.
    Restored,
    /// The channel has been successfully opened.
    Opened {
        connection_id: u32,
        redirected_event_flag: Option<u16>,
        #[inspect(skip)]
        redirected_event: Option<Event>,
    },
    /// The channel has been revoked by the server.
    Revoked,
}
528
529impl std::fmt::Display for ChannelState {
530    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
531        let s = match self {
532            ChannelState::Opening { .. } => "Opening",
533            ChannelState::Offered => "Offered",
534            ChannelState::Opened { .. } => "Opened",
535            ChannelState::Restored => "Restored",
536            ChannelState::Revoked => "Revoked",
537        };
538        fmt.pad(s)
539    }
540}
541
/// Client-side state for a single offered channel.
#[derive(Debug, Inspect)]
struct Channel {
    /// The offer received from the server.
    offer: protocol::OfferChannel,
    // Taken and signaled when the server rescinds the channel, notifying the
    // holder of `OfferInfo::revoke_recv`. NOTE(review): dropping the sender
    // presumably also notifies the receiver — confirm mesh oneshot semantics.
    #[inspect(skip)]
    revoke_send: Option<mesh::OneshotSender<()>>,
    /// Open/close state of the channel.
    state: ChannelState,
    /// The pending channel modify request, if any.
    #[inspect(with = "|x| x.is_some()")]
    modify_response_send: Option<Rpc<(), i32>>,
    /// GPADLs for this channel, including ones being created or torn down.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    gpadls: HashMap<GpadlId, GpadlState>,
    // NOTE(review): presumably set once the user has released the channel;
    // consulted by `try_release`, which is outside this view.
    is_client_released: bool,
}
555
556impl Channel {
557    fn pending_request(&self) -> Option<&'static str> {
558        if self.modify_response_send.is_some() {
559            return Some("modify");
560        }
561        self.gpadls.iter().find_map(|(_, gpadl)| match gpadl {
562            GpadlState::Offered(_) => Some("creating gpadl"),
563            GpadlState::Created => None,
564            GpadlState::TearingDown { .. } => Some("tearing down gpadl"),
565        })
566    }
567}
568
/// State owned by the background client task spawned in
/// [`VmbusClientBuilder::build`].
#[derive(Inspect)]
struct ClientTask {
    /// Outgoing messages, GPADL teardown tracking, per-channel request streams,
    /// and synic state.
    #[inspect(flatten)]
    inner: ClientTaskInner,
    /// Channels known to the client, keyed by channel id.
    channels: ChannelList,
    /// The client's connection state machine.
    state: ClientState,
    /// Hvsock connection requests awaiting a matching offer.
    hvsock_tracker: hvsock::HvsockRequestTracker,
    // NOTE(review): presumably set by `TaskRequest::Start` and cleared by
    // `TaskRequest::Stop`; those handlers are outside this view.
    running: bool,
    /// The in-flight connection modify request; at most one at a time.
    #[inspect(with = "|x| x.is_some()")]
    modify_request: Option<Rpc<ModifyConnectionRequest, ConnectionState>>,
    /// Source of incoming vmbus messages.
    #[inspect(skip)]
    msg_source: Box<dyn VmbusMessageSource>,
    /// Task-level requests from [`VmbusClient`].
    #[inspect(skip)]
    task_recv: mesh::Receiver<TaskRequest>,
    /// Client-level requests from [`VmbusClientAccess`].
    #[inspect(skip)]
    client_request_recv: mesh::Receiver<ClientRequest>,
}
586
587impl ClientTask {
588    fn handle_initiate_contact(
589        &mut self,
590        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
591        version: Version,
592    ) {
593        let ClientState::Disconnected = self.state else {
594            tracing::warn!(client_state = %self.state, "invalid client state for InitiateContact");
595            rpc.complete(Err(ConnectError::InvalidState));
596            return;
597        };
598        let feature_flags = if version >= Version::Copper {
599            SUPPORTED_FEATURE_FLAGS
600        } else {
601            FeatureFlags::new()
602        };
603
604        let request = rpc.input();
605
606        tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting");
607        let target_info = protocol::TargetInfo::new()
608            .with_sint(SINT)
609            .with_vtl(VTL)
610            .with_feature_flags(feature_flags.into());
611        let monitor_page = request.monitor_page.unwrap_or_default();
612        let msg = protocol::InitiateContact2 {
613            initiate_contact: protocol::InitiateContact {
614                version_requested: version as u32,
615                target_message_vp: request.target_message_vp,
616                interrupt_page_or_target_info: target_info.into(),
617                parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
618                child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
619            },
620            client_id: request.client_id,
621        };
622
623        self.state = ClientState::Connecting { version, rpc };
624        if version < Version::Copper {
625            self.inner.messages.send(&msg.initiate_contact)
626        } else {
627            self.inner.messages.send(&msg);
628        }
629    }
630
631    fn handle_unload(&mut self, rpc: Rpc<(), ()>) {
632        tracing::debug!(%self.state, "VmBus client disconnecting");
633        self.state = ClientState::Disconnecting {
634            version: self.state.get_version().expect("invalid state for unload"),
635            rpc,
636        };
637
638        self.inner.messages.send(&protocol::Unload {});
639    }
640
641    fn handle_modify(&mut self, request: Rpc<ModifyConnectionRequest, ConnectionState>) {
642        if !matches!(self.state, ClientState::Connected { version, .. }
643            if version.feature_flags.modify_connection())
644        {
645            tracing::warn!("ModifyConnection not supported");
646            request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
647            return;
648        }
649
650        if self.modify_request.is_some() {
651            tracing::warn!("Duplicate ModifyConnection request");
652            request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
653            return;
654        }
655
656        let message = protocol::ModifyConnection::from(*request.input());
657        self.modify_request = Some(request);
658        self.inner.messages.send(&message);
659    }
660
661    fn handle_tl_connect(&mut self, rpc: Rpc<HvsockConnectRequest, Option<OfferInfo>>) {
662        // The client only supports protocol versions which use the newer message format.
663        // The host will not send a TlConnectRequestResult message on success, so a response to this
664        // message is not guaranteed.
665        let message = protocol::TlConnectRequest2::from(*rpc.input());
666        self.hvsock_tracker.add_request(rpc);
667        self.inner.messages.send(&message);
668    }
669
670    fn handle_client_request(&mut self, request: ClientRequest) {
671        match request {
672            ClientRequest::Connect(rpc) => {
673                self.handle_initiate_contact(rpc, *SUPPORTED_VERSIONS.last().unwrap());
674            }
675            ClientRequest::Unload(rpc) => {
676                self.handle_unload(rpc);
677            }
678            ClientRequest::Modify(request) => self.handle_modify(request),
679            ClientRequest::HvsockConnect(request) => self.handle_tl_connect(request),
680        }
681    }
682
683    fn handle_version_response(&mut self, msg: protocol::VersionResponse2) {
684        let old_state = std::mem::replace(&mut self.state, ClientState::Disconnected);
685        let ClientState::Connecting { version, rpc } = old_state else {
686            self.state = old_state;
687            tracing::warn!(
688                client_state = %self.state,
689                "invalid client state to handle VersionResponse"
690            );
691            return;
692        };
693        if msg.version_response.version_supported > 0 {
694            if msg.version_response.connection_state != ConnectionState::SUCCESSFUL {
695                rpc.complete(Err(ConnectError::FailedToConnect(
696                    msg.version_response.connection_state,
697                )));
698                return;
699            }
700
701            let feature_flags = if version >= Version::Copper {
702                FeatureFlags::from(msg.supported_features)
703            } else {
704                FeatureFlags::new()
705            };
706
707            let version = VersionInfo {
708                version,
709                feature_flags,
710            };
711
712            self.inner.messages.send(&protocol::RequestOffers {});
713            self.state = ClientState::RequestingOffers {
714                version,
715                rpc: rpc.split().1,
716                offers: Vec::new(),
717            };
718            tracing::info!(?version, "VmBus client connected, requesting offers");
719        } else {
720            let index = SUPPORTED_VERSIONS
721                .iter()
722                .position(|v| *v == version)
723                .unwrap();
724
725            if index == 0 {
726                rpc.complete(Err(ConnectError::NoSupportedVersions));
727                return;
728            }
729            let next_version = SUPPORTED_VERSIONS[index - 1];
730            tracing::debug!(
731                version = version as u32,
732                next_version = next_version as u32,
733                "Unsupported version, retrying"
734            );
735            self.handle_initiate_contact(rpc, next_version);
736        }
737    }
738
739    fn create_channel(&mut self, offer: protocol::OfferChannel) -> Result<OfferInfo> {
740        self.create_channel_core(offer, ChannelState::Offered)
741    }
742
743    fn create_channel_core(
744        &mut self,
745        offer: protocol::OfferChannel,
746        state: ChannelState,
747    ) -> Result<OfferInfo> {
748        if self.channels.0.contains_key(&offer.channel_id) {
749            anyhow::bail!("channel {:?} exists", offer.channel_id);
750        }
751        let (request_send, request_recv) = mesh::channel();
752        let (revoke_send, revoke_recv) = mesh::oneshot();
753
754        self.channels.0.insert(
755            offer.channel_id,
756            Channel {
757                revoke_send: Some(revoke_send),
758                offer,
759                state,
760                modify_response_send: None,
761                gpadls: HashMap::new(),
762                is_client_released: false,
763            },
764        );
765
766        self.inner
767            .channel_requests
768            .push(TaggedStream::new(offer.channel_id, request_recv));
769
770        Ok(OfferInfo {
771            offer,
772            revoke_recv,
773            request_send,
774        })
775    }
776
777    fn handle_offer(&mut self, offer: protocol::OfferChannel) {
778        let offer_info = self
779            .create_channel(offer)
780            .expect("channel should not exist");
781
782        tracing::info!(
783                state = %self.state,
784                channel_id = offer.channel_id.0,
785                interface_id = %offer.interface_id,
786                instance_id = %offer.instance_id,
787                subchannel_index = offer.subchannel_index,
788                "received offer");
789
790        if let Some(offer) = self.hvsock_tracker.check_offer(&offer_info.offer) {
791            offer.complete(Some(offer_info));
792        } else {
793            match &mut self.state {
794                ClientState::Connected { offer_send, .. } => {
795                    offer_send.send(offer_info);
796                }
797                ClientState::RequestingOffers { offers, .. } => {
798                    offers.push(offer_info);
799                }
800                state => unreachable!("invalid client state for OfferChannel: {state}"),
801            }
802        }
803    }
804
805    fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) -> TriedRelease {
806        tracing::info!(state = %self.state, channel_id = rescind.channel_id.0, "received rescind");
807
808        let mut channel = self.channels.get_mut(rescind.channel_id);
809        let event_flag = match std::mem::replace(&mut channel.state, ChannelState::Revoked) {
810            ChannelState::Offered => None,
811            ChannelState::Opening {
812                connection_id: _,
813                redirected_event_flag,
814                redirected_event: _,
815                rpc,
816            } => {
817                rpc.fail(anyhow::anyhow!("channel revoked"));
818                redirected_event_flag
819            }
820            ChannelState::Restored => None,
821            ChannelState::Opened {
822                connection_id: _,
823                redirected_event_flag,
824                redirected_event: _,
825            } => redirected_event_flag,
826            ChannelState::Revoked => {
827                panic!("channel id {:?} already revoked", rescind.channel_id);
828            }
829        };
830        if let Some(event_flag) = event_flag {
831            self.inner.synic.free_event_flag(event_flag);
832        }
833
834        // Drop the channel and send the revoked message to the client.
835        channel.revoke_send.take().unwrap().send(());
836
837        channel.try_release(&mut self.inner.messages)
838    }
839
840    fn handle_offers_delivered(&mut self) {
841        match std::mem::replace(&mut self.state, ClientState::Disconnected) {
842            ClientState::RequestingOffers {
843                version,
844                rpc,
845                offers,
846            } => {
847                tracing::info!(version = ?version, "VmBus client connected, offers delivered");
848                let (offer_send, offer_recv) = mesh::channel();
849                self.state = ClientState::Connected {
850                    version,
851                    offer_send,
852                };
853                rpc.complete(Ok(ConnectResult {
854                    version,
855                    offers,
856                    offer_recv,
857                }));
858            }
859            state => {
860                tracing::warn!(client_state = %state, "invalid client state for OffersDelivered");
861                self.state = state;
862            }
863        }
864    }
865
866    fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) -> TriedRelease {
867        let mut channel = self.channels.get_mut(request.channel_id);
868        let Some(gpadl_state) = channel.gpadls.get_mut(&request.gpadl_id) else {
869            panic!("GpadlCreated for unknown gpadl {:#x}", request.gpadl_id.0);
870        };
871
872        let rpc = match std::mem::replace(gpadl_state, GpadlState::Created) {
873            GpadlState::Offered(rpc) => rpc,
874            old_state => {
875                panic!(
876                    "invalid state {old_state:?} for gpadl {:#x}:{:#x}",
877                    request.channel_id.0, request.gpadl_id.0
878                );
879            }
880        };
881
882        let gpadl_created = request.status == protocol::STATUS_SUCCESS;
883        if gpadl_created {
884            rpc.complete(Ok(()));
885        } else {
886            channel.gpadls.remove(&request.gpadl_id).unwrap();
887            rpc.fail(anyhow::anyhow!(
888                "gpadl creation failed: {:#x}",
889                request.status
890            ));
891        };
892        channel.try_release(&mut self.inner.messages)
893    }
894
895    fn handle_open_result(&mut self, result: protocol::OpenResult) {
896        tracing::debug!(
897            channel_id = result.channel_id.0,
898            result = result.status,
899            "received open result"
900        );
901
902        let mut channel = self.channels.get_mut(result.channel_id);
903
904        let channel_opened = result.status == protocol::STATUS_SUCCESS as u32;
905        let old_state = std::mem::replace(&mut channel.state, ChannelState::Offered);
906        let ChannelState::Opening {
907            connection_id,
908            redirected_event_flag,
909            redirected_event,
910            rpc,
911        } = old_state
912        else {
913            tracing::warn!(
914                old_state = ?channel.state,
915                channel_opened,
916                "invalid state for open result"
917            );
918            channel.state = old_state;
919            return;
920        };
921
922        if !channel_opened {
923            if let Some(event_flag) = redirected_event_flag {
924                self.inner.synic.free_event_flag(event_flag);
925            }
926            rpc.fail(anyhow::anyhow!("open failed: {:#x}", result.status));
927            return;
928        }
929
930        channel.state = ChannelState::Opened {
931            connection_id,
932            redirected_event_flag,
933            redirected_event,
934        };
935
936        rpc.complete(Ok(OpenOutput {
937            redirected_event_flag,
938            guest_to_host_signal: self.inner.synic.guest_to_host_interrupt(connection_id),
939        }));
940    }
941
942    fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) -> TriedRelease {
943        let Some(channel_id) = self.inner.teardown_gpadls.remove(&request.gpadl_id) else {
944            panic!("gpadl {:#x} not in teardown list", request.gpadl_id.0);
945        };
946
947        tracing::debug!(
948            gpadl_id = request.gpadl_id.0,
949            channel_id = channel_id.0,
950            "Received GpadlTorndown"
951        );
952
953        let mut channel = self.channels.get_mut(channel_id);
954        let gpadl_state = channel
955            .gpadls
956            .remove(&request.gpadl_id)
957            .expect("gpadl validated above");
958
959        let GpadlState::TearingDown { rpcs } = gpadl_state else {
960            panic!("gpadl should be tearing down if in teardown list, state = {gpadl_state:?}");
961        };
962
963        for rpc in rpcs {
964            rpc.complete(());
965        }
966        channel.try_release(&mut self.inner.messages)
967    }
968
969    fn handle_unload_complete(&mut self) {
970        match std::mem::replace(&mut self.state, ClientState::Disconnected) {
971            ClientState::Disconnecting { version: _, rpc } => {
972                tracing::info!("VmBus client disconnected");
973                rpc.complete(());
974            }
975            state => {
976                tracing::warn!(client_state = %state, "invalid client state for UnloadComplete");
977            }
978        }
979    }
980
981    fn handle_modify_complete(&mut self, response: protocol::ModifyConnectionResponse) {
982        if let Some(request) = self.modify_request.take() {
983            request.complete(response.connection_state)
984        } else {
985            tracing::warn!("Unexpected modify complete request");
986        }
987    }
988
989    fn handle_modify_channel_response(
990        &mut self,
991        response: protocol::ModifyChannelResponse,
992    ) -> TriedRelease {
993        let mut channel = self.channels.get_mut(response.channel_id);
994        let Some(sender) = channel.modify_response_send.take() else {
995            panic!(
996                "unexpected modify channel response for channel {:#x}",
997                response.channel_id.0
998            );
999        };
1000
1001        sender.complete(response.status);
1002        channel.try_release(&mut self.inner.messages)
1003    }
1004
1005    fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) {
1006        if let Some(rpc) = self.hvsock_tracker.check_result(&response) {
1007            rpc.complete(None);
1008        }
1009    }
1010
1011    /// Returns false if the message was a pause complete message.
1012    fn handle_synic_message(&mut self, data: &[u8]) -> bool {
1013        let msg = Message::parse(data, self.state.get_version()).unwrap();
1014        tracing::trace!(?msg, "received client message from synic");
1015
1016        match msg {
1017            Message::VersionResponse2(version_response, ..) => {
1018                self.handle_version_response(version_response);
1019            }
1020            Message::VersionResponse(version_response, ..) => {
1021                self.handle_version_response(version_response.into());
1022            }
1023            Message::OfferChannel(offer, ..) => {
1024                self.handle_offer(offer);
1025            }
1026            Message::AllOffersDelivered(..) => {
1027                self.handle_offers_delivered();
1028            }
1029            Message::UnloadComplete(..) => {
1030                self.handle_unload_complete();
1031            }
1032            Message::ModifyConnectionResponse(response, ..) => {
1033                self.handle_modify_complete(response);
1034            }
1035            Message::GpadlCreated(gpadl, ..) => {
1036                self.handle_gpadl_created(gpadl);
1037            }
1038            Message::OpenResult(result, ..) => {
1039                self.handle_open_result(result);
1040            }
1041            Message::GpadlTorndown(gpadl, ..) => {
1042                self.handle_gpadl_torndown(gpadl);
1043            }
1044            Message::RescindChannelOffer(rescind, ..) => {
1045                self.handle_rescind(rescind);
1046            }
1047            Message::ModifyChannelResponse(response, ..) => {
1048                self.handle_modify_channel_response(response);
1049            }
1050            Message::TlConnectResult(response, ..) => self.handle_tl_connect_result(response),
1051            // Unsupported messages.
1052            Message::CloseReservedChannelResponse(..) => {
1053                todo!("Unsupported message {msg:?}")
1054            }
1055            Message::PauseResponse(..) => {
1056                return false;
1057            }
1058            // Messages that should only be received by a vmbus server.
1059            Message::RequestOffers(..)
1060            | Message::OpenChannel2(..)
1061            | Message::OpenChannel(..)
1062            | Message::CloseChannel(..)
1063            | Message::GpadlHeader(..)
1064            | Message::GpadlBody(..)
1065            | Message::GpadlTeardown(..)
1066            | Message::RelIdReleased(..)
1067            | Message::InitiateContact(..)
1068            | Message::InitiateContact2(..)
1069            | Message::Unload(..)
1070            | Message::OpenReservedChannel(..)
1071            | Message::CloseReservedChannel(..)
1072            | Message::TlConnectRequest2(..)
1073            | Message::TlConnectRequest(..)
1074            | Message::ModifyChannel(..)
1075            | Message::ModifyConnection(..)
1076            | Message::Pause(..)
1077            | Message::Resume(..) => {
1078                unreachable!("Client received server message {msg:?}");
1079            }
1080        }
1081        true
1082    }
1083
1084    fn handle_open_channel(
1085        &mut self,
1086        channel_id: ChannelId,
1087        rpc: FailableRpc<OpenRequest, OpenOutput>,
1088    ) {
1089        let mut channel = self.channels.get_mut(channel_id);
1090        match &channel.state {
1091            ChannelState::Offered => {}
1092            ChannelState::Revoked => {
1093                rpc.fail(anyhow::anyhow!("channel revoked"));
1094                return;
1095            }
1096            state => {
1097                rpc.fail(anyhow::anyhow!("invalid channel state: {}", state));
1098                return;
1099            }
1100        }
1101
1102        tracing::info!(channel_id = channel_id.0, "opening channel on host");
1103        let (request, rpc) = rpc.split();
1104        let open_data = &request.open_data;
1105
1106        let supports_interrupt_redirection =
1107            if let ClientState::Connected { version, .. } = self.state {
1108                version.feature_flags.guest_specified_signal_parameters()
1109                    || version.feature_flags.channel_interrupt_redirection()
1110            } else {
1111                false
1112            };
1113
1114        if !supports_interrupt_redirection && open_data.event_flag != channel_id.0 as u16 {
1115            rpc.fail(anyhow::anyhow!(
1116                "host does not support specifying the event flag"
1117            ));
1118            return;
1119        }
1120
1121        let open_channel = protocol::OpenChannel {
1122            channel_id,
1123            open_id: 0,
1124            ring_buffer_gpadl_id: open_data.ring_gpadl_id,
1125            target_vp: open_data.target_vp,
1126            downstream_ring_buffer_page_offset: open_data.ring_offset,
1127            user_data: open_data.user_data,
1128        };
1129
1130        let connection_id = if request.use_vtl2_connection_id {
1131            if !supports_interrupt_redirection {
1132                rpc.fail(anyhow::anyhow!(
1133                    "host does not support specfiying the connection ID"
1134                ));
1135                return;
1136            }
1137            protocol::ConnectionId::new(channel_id.0, 2.try_into().unwrap(), 7).0
1138        } else {
1139            open_data.connection_id
1140        };
1141
1142        // No failure paths after the one for allocating the event flag, since
1143        // otherwise we would need to free the event flag.
1144        let mut flags = OpenChannelFlags::new();
1145        let event_flag = if let Some(event) = &request.incoming_event {
1146            if !supports_interrupt_redirection {
1147                rpc.fail(anyhow::anyhow!(
1148                    "host does not support redirecting interrupts"
1149                ));
1150                return;
1151            }
1152
1153            flags.set_redirect_interrupt(true);
1154            match self.inner.synic.allocate_event_flag(event) {
1155                Ok(flag) => flag,
1156                Err(err) => {
1157                    rpc.fail(err.context("failed to allocate event flag"));
1158                    return;
1159                }
1160            }
1161        } else {
1162            open_data.event_flag
1163        };
1164
1165        if supports_interrupt_redirection {
1166            self.inner.messages.send(&protocol::OpenChannel2 {
1167                open_channel,
1168                connection_id,
1169                event_flag,
1170                flags,
1171            });
1172        } else {
1173            self.inner.messages.send(&open_channel);
1174        }
1175
1176        channel.state = ChannelState::Opening {
1177            connection_id,
1178            redirected_event_flag: (request.incoming_event.is_some()).then_some(event_flag),
1179            redirected_event: request.incoming_event,
1180            rpc,
1181        }
1182    }
1183
1184    fn handle_restore_channel(
1185        &mut self,
1186        channel_id: ChannelId,
1187        request: RestoreRequest,
1188    ) -> Result<OpenOutput> {
1189        let mut channel = self.channels.get_mut(channel_id);
1190        if !matches!(channel.state, ChannelState::Restored) {
1191            anyhow::bail!("invalid channel state: {}", channel.state);
1192        }
1193
1194        if request.incoming_event.is_some() != request.redirected_event_flag.is_some() {
1195            anyhow::bail!("incoming event and redirected event flag must both be set or unset");
1196        }
1197
1198        if let Some((flag, event)) = request
1199            .redirected_event_flag
1200            .zip(request.incoming_event.as_ref())
1201        {
1202            self.inner.synic.restore_event_flag(flag, event)?;
1203        }
1204
1205        channel.state = ChannelState::Opened {
1206            connection_id: request.connection_id,
1207            redirected_event_flag: request.redirected_event_flag,
1208            redirected_event: request.incoming_event,
1209        };
1210        Ok(OpenOutput {
1211            redirected_event_flag: request.redirected_event_flag,
1212            guest_to_host_signal: self
1213                .inner
1214                .synic
1215                .guest_to_host_interrupt(request.connection_id),
1216        })
1217    }
1218
1219    fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: FailableRpc<GpadlRequest, ()>) {
1220        let (request, rpc) = rpc.split();
1221        let mut channel = self.channels.get_mut(channel_id);
1222        if channel
1223            .gpadls
1224            .insert(request.id, GpadlState::Offered(rpc))
1225            .is_some()
1226        {
1227            panic!(
1228                "duplicate gpadl ID {:?} for channel {:?}.",
1229                request.id, channel_id
1230            );
1231        }
1232
1233        tracing::trace!(
1234            channel_id = channel_id.0,
1235            gpadl_id = request.id.0,
1236            count = request.count,
1237            len = request.buf.len(),
1238            "received gpadl request"
1239        );
1240
1241        // Split off the values that fit in the header.
1242        let (first, remaining) = if request.buf.len() > protocol::GpadlHeader::MAX_DATA_VALUES {
1243            request.buf.split_at(protocol::GpadlHeader::MAX_DATA_VALUES)
1244        } else {
1245            (request.buf.as_slice(), [].as_slice())
1246        };
1247
1248        let message = protocol::GpadlHeader {
1249            channel_id,
1250            gpadl_id: request.id,
1251            len: (request.buf.len() * size_of::<u64>())
1252                .try_into()
1253                .expect("Too many GPA values"),
1254            count: request.count,
1255        };
1256
1257        self.inner
1258            .messages
1259            .send_with_data(&message, first.as_bytes());
1260
1261        // Send GpadlBody messages for the remaining values.
1262        let message = protocol::GpadlBody {
1263            rsvd: 0,
1264            gpadl_id: request.id,
1265        };
1266        for chunk in remaining.chunks(protocol::GpadlBody::MAX_DATA_VALUES) {
1267            self.inner
1268                .messages
1269                .send_with_data(&message, chunk.as_bytes());
1270        }
1271    }
1272
1273    fn handle_gpadl_teardown(&mut self, channel_id: ChannelId, rpc: Rpc<GpadlId, ()>) {
1274        let (gpadl_id, rpc) = rpc.split();
1275        let mut channel = self.channels.get_mut(channel_id);
1276        let Some(gpadl_state) = channel.gpadls.get_mut(&gpadl_id) else {
1277            tracing::warn!(
1278                gpadl_id = gpadl_id.0,
1279                channel_id = channel_id.0,
1280                "Gpadl teardown for unknown gpadl or revoked channel"
1281            );
1282            return;
1283        };
1284
1285        match gpadl_state {
1286            GpadlState::Offered(_) => {
1287                tracing::warn!(
1288                    gpadl_id = gpadl_id.0,
1289                    channel_id = channel_id.0,
1290                    "gpadl teardown for offered gpadl"
1291                );
1292            }
1293            GpadlState::Created => {
1294                *gpadl_state = GpadlState::TearingDown { rpcs: vec![rpc] };
1295                // The caller must guarantee that GPADL teardown requests are only made
1296                // for unique GPADL IDs. This is currently enforced in vmbus_server by
1297                // blocking GPADL teardown messages for reserved channels.
1298                assert!(
1299                    self.inner
1300                        .teardown_gpadls
1301                        .insert(gpadl_id, channel_id)
1302                        .is_none(),
1303                    "Gpadl state validated above"
1304                );
1305
1306                self.inner.messages.send(&protocol::GpadlTeardown {
1307                    channel_id,
1308                    gpadl_id,
1309                });
1310            }
1311            GpadlState::TearingDown { rpcs } => {
1312                rpcs.push(rpc);
1313            }
1314        }
1315    }
1316
1317    fn handle_close_channel(&mut self, channel_id: ChannelId) {
1318        let mut channel = self.channels.get_mut(channel_id);
1319        self.inner.close_channel(channel_id, &mut channel);
1320    }
1321
1322    fn handle_modify_channel(&mut self, channel_id: ChannelId, rpc: Rpc<ModifyRequest, i32>) {
1323        // The client doesn't support versions below Iron, so we always expect the host to send a
1324        // ModifyChannelResponse. This means we don't need to worry about sending a ChannelResponse
1325        // if that weren't supported.
1326        assert!(self.check_version(Version::Iron));
1327        let mut channel = self.channels.get_mut(channel_id);
1328        if channel.modify_response_send.is_some() {
1329            panic!("duplicate channel modify request {channel_id:?}");
1330        }
1331
1332        let (request, response) = rpc.split();
1333        channel.modify_response_send = Some(response);
1334        let payload = match request {
1335            ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel {
1336                channel_id,
1337                target_vp,
1338            },
1339        };
1340
1341        self.inner.messages.send(&payload);
1342    }
1343
1344    fn handle_channel_request(&mut self, channel_id: ChannelId, request: ChannelRequest) {
1345        match request {
1346            ChannelRequest::Open(rpc) => self.handle_open_channel(channel_id, rpc),
1347            ChannelRequest::Restore(rpc) => {
1348                rpc.handle_failable_sync(|request| self.handle_restore_channel(channel_id, request))
1349            }
1350            ChannelRequest::Gpadl(req) => self.handle_gpadl(channel_id, req),
1351            ChannelRequest::TeardownGpadl(req) => self.handle_gpadl_teardown(channel_id, req),
1352            ChannelRequest::Close(req) => {
1353                req.handle_sync(|()| self.handle_close_channel(channel_id))
1354            }
1355            ChannelRequest::Modify(req) => self.handle_modify_channel(channel_id, req),
1356        }
1357    }
1358
1359    async fn handle_task(&mut self, task: TaskRequest) {
1360        match task {
1361            TaskRequest::Inspect(deferred) => {
1362                deferred.inspect(&*self);
1363            }
1364            TaskRequest::Save(rpc) => rpc.handle_sync(|()| self.handle_save()),
1365            TaskRequest::Restore(rpc) => {
1366                rpc.handle_sync(|saved_state| self.handle_restore(saved_state))
1367            }
1368            TaskRequest::PostRestore(rpc) => rpc.handle_sync(|()| self.handle_post_restore()),
1369            TaskRequest::Start => self.handle_start(),
1370            TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
1371        }
1372    }
1373
1374    /// Makes sure a channel is closed if the channel request stream was dropped.
1375    fn handle_device_removal(&mut self, channel_id: ChannelId) -> TriedRelease {
1376        let mut channel = self.channels.get_mut(channel_id);
1377        channel.is_client_released = true;
1378        // Close the channel if it is still open.
1379        if let ChannelState::Opened { .. } = channel.state {
1380            tracing::warn!(
1381                channel_id = channel_id.0,
1382                "Channel dropped without closing first"
1383            );
1384            self.inner.close_channel(channel_id, &mut channel);
1385        }
1386        channel.try_release(&mut self.inner.messages)
1387    }
1388
1389    /// Determines if the client is connected with at least the specified version.
1390    fn check_version(&self, version: Version) -> bool {
1391        matches!(self.state, ClientState::Connected { version: v, .. } if v.version >= version)
1392    }
1393
1394    fn handle_start(&mut self) {
1395        assert!(!self.running);
1396        self.msg_source.resume_message_stream();
1397        self.inner.messages.resume();
1398        self.running = true;
1399    }
1400
    /// Stops message processing, draining in-flight host traffic first.
    ///
    /// Repeats until a pass completes with no revoked channel waiting on a
    /// host response:
    /// 1. Process host messages until no revoked channel has a pending request.
    /// 2. Pause the message stream: in-band when the host supports pause/resume,
    ///    otherwise by masking the sint.
    /// 3. Process messages until the pause response (or end of stream).
    ///
    /// If new pending requests appeared during the pause, everything is resumed
    /// and the loop runs again. Panics if the task is not running.
    async fn handle_stop(&mut self) {
        assert!(self.running);

        loop {
            // Process messages until there are no more channels waiting for
            // responses. This is necessary to ensure that the saved state does
            // not have to support encoding revoked channels for which we are
            // waiting for GPADL or modify responses.
            while let Some((id, request)) = self.channels.revoked_channel_with_pending_request() {
                tracelimit::info_ratelimited!(
                    channel_id = id.0,
                    request,
                    "waiting for responses for channel"
                );
                // No pause has been requested yet in this pass, so the stream
                // is required not to end or report a pause response here.
                assert!(self.process_next_message().await);
            }

            if self.can_pause_resume() {
                // In-band pause: the next flush sends a Pause message, and a
                // Resume message is queued to go out first once resumed.
                self.inner.messages.pause();
            } else {
                // Mask the sint to pause the message stream. The host will
                // retry any queued messages after the sint is unmasked.
                self.msg_source.pause_message_stream();
                self.inner.messages.force_pause();
            }

            // Continue processing messages until we hit EOF or get a pause
            // response.
            while self.process_next_message().await {}

            // Ensure there are still no pending requests. If there are, resume
            // and go around again.
            if self
                .channels
                .revoked_channel_with_pending_request()
                .is_none()
            {
                break;
            }
            // Undo the pause applied above, using the same mechanism.
            if !self.can_pause_resume() {
                self.msg_source.resume_message_stream();
            }
            self.inner.messages.resume();
        }

        tracing::debug!("messages drained");
        // Because the run loop awaits all async operations, there is no need for rundown.
        self.running = false;
    }
1450
1451    async fn process_next_message(&mut self) -> bool {
1452        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
1453        let recv = self.msg_source.recv(&mut buf);
1454        // Concurrently flush until there is no more work to do, since pending
1455        // messages may be blocking responses from the host.
1456        let flush = async {
1457            self.inner.messages.flush_messages().await;
1458            std::future::pending().await
1459        };
1460        let size = (recv, flush)
1461            .race()
1462            .await
1463            .expect("Fatal error reading messages from synic");
1464        if size == 0 {
1465            return false;
1466        }
1467        self.handle_synic_message(&buf[..size])
1468    }
1469
1470    /// Returns whether the server supports in-band messages to pause/resume the
1471    /// message stream.
1472    ///
1473    /// For hosts where this is not supported, we mask the sint to pause new
1474    /// messages being queued to the sint, then drain the messages. This does
1475    /// not work with some host implementations, which cannot support draining
1476    /// the message queue while the sint is masked (due to the use of
1477    /// HvPostMessageDirect).
1478    fn can_pause_resume(&self) -> bool {
1479        if let ClientState::Connected { version, .. } = self.state {
1480            version.feature_flags.pause_resume()
1481        } else {
1482            false
1483        }
1484    }
1485
    /// Main event loop of the client task.
    ///
    /// Multiplexes task control requests, client requests, per-channel requests,
    /// incoming synic messages, and flushing of the outgoing queue. While the
    /// task is stopped (`running == false`), only task control requests are
    /// serviced. The loop exits when the task or client request stream ends, or
    /// when every source is terminated.
    async fn run(&mut self) {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        loop {
            // Only receive host messages while running; an empty OptionFuture
            // is treated as terminated by `select!` and never chosen.
            let mut message_recv =
                OptionFuture::from(self.running.then(|| self.msg_source.recv(&mut buf).fuse()));

            // If there are pending outgoing messages, the host is backed up.
            // Try to flush the queue, and in the meantime, stop generating new
            // messages by stopping processing client requests, so as to avoid
            // the outgoing message queue growing without bound.
            //
            // We still need to process incoming messages when in this state,
            // even though they may generate additional outgoing messages, to
            // avoid a deadlock with the host. The host can always DoS the
            // guest, so this is not an attack vector.
            let host_backed_up = !self.inner.messages.is_empty();
            let mut flush_messages = OptionFuture::from(
                (self.running && host_backed_up)
                    .then(|| self.inner.messages.flush_messages().fuse()),
            );

            let mut client_request_recv = OptionFuture::from(
                (self.running && !host_backed_up).then(|| self.client_request_recv.next()),
            );

            let mut channel_requests = OptionFuture::from(
                (self.running && !host_backed_up)
                    .then(|| self.inner.channel_requests.select_next_some()),
            );

            futures::select! { // merge semantics
                // Flush progress only; the loop re-evaluates the backlog next pass.
                _r = pin!(flush_messages) => {}
                // Task control requests; the loop exits when the sender is gone.
                r = self.task_recv.next() => {
                    if let Some(task) = r {
                        self.handle_task(task).await;
                    } else {
                        break;
                    }
                }
                // Client requests; the loop exits when the client stream ends.
                r = client_request_recv => {
                    if let Some(Some(request)) = r {
                        self.handle_client_request(request);
                    } else {
                        break;
                    }
                }
                // Per-channel requests; a `None` item means that channel's
                // request stream ended.
                r = channel_requests => {
                    match r.unwrap() {
                        (id, Some(request)) => self.handle_channel_request(id, request),
                        (id, _) => {
                            self.handle_device_removal(id);
                        }
                    }
                }
                // Host messages; EOF or a read error is fatal while running.
                r = message_recv => {
                    match r.unwrap() {
                        Ok(size) => {
                            if size == 0 {
                                panic!("Unexpected end of file reading messages from synic.");
                            }

                            self.handle_synic_message(&buf[..size]);
                        }
                        Err(err) => {
                            panic!("Error reading messages from synic: {err:?}");
                        }
                    }
                }
                complete => break,
            }
        }
    }
1558}
1559
1560impl ClientTaskInner {
1561    fn close_channel(&mut self, channel_id: ChannelId, channel: &mut Channel) {
1562        if let ChannelState::Opened {
1563            redirected_event_flag,
1564            ..
1565        } = channel.state
1566        {
1567            if let Some(flag) = redirected_event_flag {
1568                self.synic.free_event_flag(flag);
1569            }
1570            tracing::info!(channel_id = channel_id.0, "closing channel on host");
1571            self.messages.send(&protocol::CloseChannel { channel_id });
1572            channel.state = ChannelState::Offered;
1573        } else {
1574            tracing::warn!(
1575                id = %channel_id.0,
1576                channel_state = %channel.state,
1577                "invalid channel state for close channel"
1578            );
1579        }
1580    }
1581}
1582
1583#[derive(Debug, Inspect)]
1584#[inspect(external_tag)]
1585enum GpadlState {
1586    /// GpadlHeader has been sent to the host.
1587    Offered(#[inspect(skip)] FailableRpc<(), ()>),
1588    /// Host has responded with GpadlCreated.
1589    Created,
1590    /// GpadlTeardown message has been sent to the host.
1591    TearingDown {
1592        #[inspect(skip)]
1593        rpcs: Vec<Rpc<(), ()>>,
1594    },
1595}
1596
1597#[derive(Inspect)]
1598struct OutgoingMessages {
1599    #[inspect(skip)]
1600    poster: Box<dyn PollPostMessage>,
1601    #[inspect(with = "|x| x.len()")]
1602    queued: VecDeque<OutgoingMessage>,
1603    state: OutgoingMessageState,
1604}
1605
1606#[derive(Inspect, PartialEq, Eq, Debug)]
1607enum OutgoingMessageState {
1608    Running,
1609    SendingPauseMessage,
1610    Paused,
1611}
1612
1613impl OutgoingMessages {
1614    fn send<T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout>(
1615        &mut self,
1616        msg: &T,
1617    ) {
1618        self.send_with_data(msg, &[])
1619    }
1620
1621    fn send_with_data<
1622        T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
1623    >(
1624        &mut self,
1625        msg: &T,
1626        data: &[u8],
1627    ) {
1628        tracing::trace!(typ = ?T::MESSAGE_TYPE, "Sending message to host");
1629        let msg = OutgoingMessage::with_data(msg, data);
1630        if self.queued.is_empty() && self.state == OutgoingMessageState::Running {
1631            let r = self.poster.poll_post_message(
1632                &mut Context::from_waker(std::task::Waker::noop()),
1633                protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
1634                1,
1635                msg.data(),
1636            );
1637            if let Poll::Ready(()) = r {
1638                return;
1639            }
1640        }
1641        tracing::trace!("queueing message");
1642        self.queued.push_back(msg);
1643    }
1644
1645    async fn flush_messages(&mut self) {
1646        let mut send = async |msg: &OutgoingMessage| {
1647            poll_fn(|cx| {
1648                self.poster.poll_post_message(
1649                    cx,
1650                    protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
1651                    1,
1652                    msg.data(),
1653                )
1654            })
1655            .await
1656        };
1657        match self.state {
1658            OutgoingMessageState::Running => {
1659                while let Some(msg) = self.queued.front() {
1660                    send(msg).await;
1661                    tracing::trace!("sent queued message");
1662                    self.queued.pop_front();
1663                }
1664            }
1665            OutgoingMessageState::SendingPauseMessage => {
1666                send(&OutgoingMessage::new(&protocol::Pause)).await;
1667                tracing::trace!("sent pause message");
1668                self.state = OutgoingMessageState::Paused;
1669            }
1670            OutgoingMessageState::Paused => {}
1671        }
1672    }
1673
1674    /// Pause by sending a pause message to the host. This will cause the host
1675    /// to stop sending messages after sending a pause response.
1676    fn pause(&mut self) {
1677        assert_eq!(self.state, OutgoingMessageState::Running);
1678        self.state = OutgoingMessageState::SendingPauseMessage;
1679        // Queue a resume message to be sent later.
1680        self.queued
1681            .push_front(OutgoingMessage::new(&protocol::Resume));
1682    }
1683
1684    /// Force a pause by setting the state to Paused. This is used when the
1685    /// host does not support in-band pause/resume messages, in which case
1686    /// the SINT is masked to force the host to stop sending messages.
1687    fn force_pause(&mut self) {
1688        assert_eq!(self.state, OutgoingMessageState::Running);
1689        self.state = OutgoingMessageState::Paused;
1690    }
1691
1692    fn resume(&mut self) {
1693        assert_eq!(self.state, OutgoingMessageState::Paused);
1694        self.state = OutgoingMessageState::Running;
1695    }
1696
1697    fn is_empty(&self) -> bool {
1698        self.queued.is_empty()
1699    }
1700}
1701
/// Groups part of the client task's state: the outgoing message queue, GPADL
/// teardown tracking, the merged per-channel request streams, and SynIC
/// event-flag bookkeeping.
#[derive(Inspect)]
struct ClientTaskInner {
    /// Queue of messages to post to the host.
    messages: OutgoingMessages,
    /// Maps each GPADL ID in this table to a channel ID.
    /// NOTE(review): presumably GPADLs whose teardown is in progress, keyed to
    /// their owning channel, per the field name. Confirm at the insert sites.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")]
    teardown_gpadls: HashMap<GpadlId, ChannelId>,
    /// Request streams from all channels, merged into a single stream. Each
    /// item is tagged with a `ChannelId`.
    #[inspect(skip)]
    channel_requests: SelectAll<TaggedStream<ChannelId, mesh::Receiver<ChannelRequest>>>,
    /// SynIC event client and event-flag allocation table.
    synic: SynicState,
}
1711
/// SynIC event bookkeeping: the event client plus the table of allocated
/// event flags.
#[derive(Inspect)]
struct SynicState {
    /// Used to map events to event flags and to signal events to the host.
    #[inspect(skip)]
    event_client: Arc<dyn SynicEventClient>,
    /// Allocation table: entry `i` is `true` while event flag `i + 1` is in
    /// use. Flag 0 is never allocated.
    #[inspect(iter_by_index)]
    event_flag_state: Vec<bool>,
}
1719
/// The client's channels, keyed by channel ID.
///
/// Inspection is transparent, so the map shows up directly with each key
/// rendered as the raw channel ID number.
#[derive(Inspect, Default)]
#[inspect(transparent)]
struct ChannelList(
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] HashMap<ChannelId, Channel>,
);
1725
/// A mutable reference to a channel stored in a [`ChannelList`].
///
/// It wraps the occupied map entry, not a plain `&mut Channel`. That means it
/// can also remove the channel from the map, as [`ChannelRef::try_release`]
/// does. The channel itself is reached through `Deref`/`DerefMut`.
struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>);
1729
/// A tag value indicating that [`ChannelRef::try_release`] has been called.
///
/// Useful as a return value for methods that might transition a channel into
/// a fully released state. Returning it shows that the release check was
/// performed; `try_release` consumes the `ChannelRef` while doing so.
struct TriedRelease(());
1734
1735impl ChannelRef<'_> {
1736    /// If the channel has been fully released (revoked, released by the client,
1737    /// no pending requests), notifes the server and removes this channel from
1738    /// the map.
1739    fn try_release(self, messages: &mut OutgoingMessages) -> TriedRelease {
1740        if self.is_client_released
1741            && matches!(self.state, ChannelState::Revoked)
1742            && self.pending_request().is_none()
1743        {
1744            let channel_id = *self.0.key();
1745            tracelimit::info_ratelimited!(channel_id = channel_id.0, "releasing channel");
1746            messages.send(&protocol::RelIdReleased { channel_id });
1747            self.0.remove();
1748        }
1749        TriedRelease(())
1750    }
1751}
1752
1753impl Deref for ChannelRef<'_> {
1754    type Target = Channel;
1755
1756    fn deref(&self) -> &Self::Target {
1757        self.0.get()
1758    }
1759}
1760
1761impl DerefMut for ChannelRef<'_> {
1762    fn deref_mut(&mut self) -> &mut Self::Target {
1763        self.0.get_mut()
1764    }
1765}
1766
1767impl ChannelList {
1768    fn revoked_channel_with_pending_request(&self) -> Option<(ChannelId, &'static str)> {
1769        self.0.iter().find_map(|(&id, channel)| {
1770            if !matches!(channel.state, ChannelState::Revoked) {
1771                return None;
1772            }
1773            Some((id, channel.pending_request()?))
1774        })
1775    }
1776
1777    #[track_caller]
1778    fn get_mut(&mut self, channel_id: ChannelId) -> ChannelRef<'_> {
1779        match self.0.entry(channel_id) {
1780            hash_map::Entry::Occupied(entry) => ChannelRef(entry),
1781            hash_map::Entry::Vacant(_) => {
1782                panic!("channel {:?} not found", channel_id);
1783            }
1784        }
1785    }
1786}
1787
1788impl SynicState {
1789    fn guest_to_host_interrupt(&self, connection_id: u32) -> Interrupt {
1790        Interrupt::from_fn({
1791            let event_client = self.event_client.clone();
1792            move || {
1793                if let Err(err) = event_client.signal_event(connection_id, 0) {
1794                    tracelimit::warn_ratelimited!(
1795                        error = &err as &dyn std::error::Error,
1796                        "failed to signal event"
1797                    );
1798                }
1799            }
1800        })
1801    }
1802
1803    const MAX_EVENT_FLAGS: u16 = 2047;
1804
1805    fn allocate_event_flag(&mut self, event: &Event) -> Result<u16> {
1806        let i = self
1807            .event_flag_state
1808            .iter()
1809            .position(|&used| !used)
1810            .ok_or(())
1811            .or_else(|()| {
1812                if self.event_flag_state.len() >= Self::MAX_EVENT_FLAGS as usize {
1813                    anyhow::bail!("out of event flags");
1814                }
1815                self.event_flag_state.push(false);
1816                Ok(self.event_flag_state.len() - 1)
1817            })?;
1818
1819        let event_flag = (i + 1) as u16;
1820        self.event_client
1821            .map_event(event_flag, event)
1822            .context("failed to map event")?;
1823        self.event_flag_state[i] = true;
1824        Ok(event_flag)
1825    }
1826
1827    fn restore_event_flag(&mut self, flag: u16, event: &Event) -> Result<()> {
1828        let i = (flag as usize)
1829            .checked_sub(1)
1830            .context("invalid event flag")?;
1831        if i >= Self::MAX_EVENT_FLAGS as usize {
1832            anyhow::bail!("invalid event flag");
1833        }
1834        if self.event_flag_state.len() <= i {
1835            self.event_flag_state.resize(i + 1, false);
1836        }
1837        if self.event_flag_state[i] {
1838            anyhow::bail!("event flag already in use");
1839        }
1840        self.event_client
1841            .map_event(flag, event)
1842            .context("failed to map event")?;
1843        self.event_flag_state[i] = true;
1844        Ok(())
1845    }
1846
1847    fn free_event_flag(&mut self, flag: u16) {
1848        let i = flag as usize - 1;
1849        assert!(i < self.event_flag_state.len());
1850        self.event_flag_state[i] = false;
1851    }
1852}
1853
1854#[cfg(test)]
1855mod tests {
1856    use super::*;
1857    use futures_concurrency::future::Join;
1858    use guid::Guid;
1859    use pal_async::DefaultDriver;
1860    use pal_async::async_test;
1861    use pal_async::timer::PolledTimer;
1862    use protocol::TargetInfo;
1863    use std::fmt::Debug;
1864    use std::task::ready;
1865    use std::time::Duration;
1866    use test_with_tracing::test;
1867    use vmbus_core::protocol::MessageHeader;
1868    use vmbus_core::protocol::MessageType;
1869    use vmbus_core::protocol::OfferFlags;
1870    use vmbus_core::protocol::UserDefinedData;
1871    use vmbus_core::protocol::VmbusMessage;
1872    use zerocopy::FromBytes;
1873    use zerocopy::FromZeros;
1874    use zerocopy::Immutable;
1875    use zerocopy::IntoBytes;
1876    use zerocopy::KnownLayout;
1877
    /// Arbitrary, easily recognizable client ID used by tests that set one.
    const VMBUS_TEST_CLIENT_ID: Guid = guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6");
1879
1880    fn in_msg<T: IntoBytes + Immutable + KnownLayout>(message_type: MessageType, t: T) -> Vec<u8> {
1881        let mut data = Vec::new();
1882        data.extend_from_slice(&message_type.0.to_ne_bytes());
1883        data.extend_from_slice(&0u32.to_ne_bytes());
1884        data.extend_from_slice(t.as_bytes());
1885        data
1886    }
1887
1888    #[track_caller]
1889    fn check_message<T>(msg: OutgoingMessage, chk: T)
1890    where
1891        T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1892    {
1893        check_message_with_data(msg, chk, &[]);
1894    }
1895
1896    #[track_caller]
1897    fn check_message_with_data<T>(msg: OutgoingMessage, chk: T, data: &[u8])
1898    where
1899        T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1900    {
1901        let chk_data = OutgoingMessage::with_data(&chk, data);
1902        if msg.data() != chk_data.data() {
1903            let (header, rest) = MessageHeader::read_from_prefix(msg.data()).unwrap();
1904            assert_eq!(header.message_type(), <T as VmbusMessage>::MESSAGE_TYPE);
1905            let (msg, rest) = T::read_from_prefix(rest).expect("incorrect message size");
1906            if msg.as_bytes() != chk.as_bytes() {
1907                panic!("mismatched messages, expected {:#?}, got {:#?}", chk, msg);
1908            }
1909            if rest != data {
1910                panic!("mismatched data, expected {:#?}, got {:#?}", data, rest);
1911            }
1912        }
1913    }
1914
    /// Host side of the connection used by the tests (wired up in `test_init`).
    struct TestServer {
        /// Messages the client under test posts, forwarded by `TestServerClient`.
        messages: mesh::Receiver<OutgoingMessage>,
        /// Raw messages to deliver to the client through `TestMessageSource`.
        send: mesh::Sender<Vec<u8>>,
    }
1919
1920    impl TestServer {
1921        async fn next(&mut self) -> Option<OutgoingMessage> {
1922            self.messages.next().await
1923        }
1924
1925        fn send(&self, msg: Vec<u8>) {
1926            self.send.send(msg);
1927        }
1928
1929        async fn connect(&mut self, client: &mut VmbusClient) -> ConnectResult {
1930            self.connect_with_channels(client, |_| {}).await
1931        }
1932
1933        async fn connect_with_channels(
1934            &mut self,
1935            client: &mut VmbusClient,
1936            send_offers: impl FnOnce(&mut Self),
1937        ) -> ConnectResult {
1938            let client_connect = client.connect(0, None, Guid::ZERO);
1939
1940            let server_connect = async {
1941                let _ = self.next().await.unwrap();
1942
1943                self.send(in_msg(
1944                    MessageType::VERSION_RESPONSE,
1945                    protocol::VersionResponse2 {
1946                        version_response: protocol::VersionResponse {
1947                            version_supported: 1,
1948                            connection_state: ConnectionState::SUCCESSFUL,
1949                            padding: 0,
1950                            selected_version_or_connection_id: 0,
1951                        },
1952                        supported_features: SUPPORTED_FEATURE_FLAGS.into(),
1953                    },
1954                ));
1955
1956                check_message(self.next().await.unwrap(), protocol::RequestOffers {});
1957
1958                send_offers(self);
1959                self.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
1960            };
1961
1962            let (connection, ()) = (client_connect, server_connect).join().await;
1963
1964            let connection = connection.unwrap();
1965            assert_eq!(connection.version.version, Version::Copper);
1966            assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
1967            connection
1968        }
1969
1970        async fn get_channel(&mut self, client: &mut VmbusClient) -> OfferInfo {
1971            let [channel] = self
1972                .get_channels(client, 1)
1973                .await
1974                .offers
1975                .try_into()
1976                .unwrap();
1977            channel
1978        }
1979
1980        async fn get_channels(&mut self, client: &mut VmbusClient, count: usize) -> ConnectResult {
1981            self.connect_with_channels(client, |this| {
1982                for i in 0..count {
1983                    let offer = protocol::OfferChannel {
1984                        interface_id: Guid::new_random(),
1985                        instance_id: Guid::new_random(),
1986                        rsvd: [0; 4],
1987                        flags: OfferFlags::new(),
1988                        mmio_megabytes: 0,
1989                        user_defined: UserDefinedData::new_zeroed(),
1990                        subchannel_index: 0,
1991                        mmio_megabytes_optional: 0,
1992                        channel_id: ChannelId(i as u32),
1993                        monitor_id: 0,
1994                        monitor_allocated: 0,
1995                        is_dedicated: 0,
1996                        connection_id: 0,
1997                    };
1998
1999                    this.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2000                }
2001            })
2002            .await
2003        }
2004
2005        async fn stop_client(&mut self, client: &mut VmbusClient) {
2006            let client_stop = client.stop();
2007            let server_stop = async {
2008                check_message(self.next().await.unwrap(), protocol::Pause);
2009                self.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
2010            };
2011            (client_stop, server_stop).join().await;
2012        }
2013
2014        async fn start_client(&mut self, client: &mut VmbusClient) {
2015            client.start();
2016            check_message(self.next().await.unwrap(), protocol::Resume);
2017        }
2018    }
2019
    /// Message poster for tests: forwards each posted message to the test
    /// server, randomly delaying some of them first.
    struct TestServerClient {
        /// Forwards posted messages to `TestServer::messages`.
        sender: mesh::Sender<OutgoingMessage>,
        /// Timer used to wait out a delay.
        timer: PolledTimer,
        /// When set, posting is held back until this instant.
        deadline: Option<pal_async::timer::Instant>,
    }
2025
2026    impl PollPostMessage for TestServerClient {
2027        fn poll_post_message(
2028            &mut self,
2029            cx: &mut Context<'_>,
2030            _connection_id: u32,
2031            _typ: u32,
2032            msg: &[u8],
2033        ) -> Poll<()> {
2034            loop {
2035                if let Some(deadline) = self.deadline {
2036                    ready!(self.timer.poll_until(cx, deadline));
2037                    self.deadline = None;
2038                }
2039                // Randomly choose whether to delay the message.
2040                //
2041                // FUTURE: use some kind of deterministic test framework for this to
2042                // allow for reproducible tests.
2043                let mut b = [0];
2044                getrandom::fill(&mut b).unwrap();
2045                if b[0] % 4 == 0 {
2046                    self.deadline =
2047                        Some(pal_async::timer::Instant::now() + Duration::from_millis(10));
2048                } else {
2049                    let msg = OutgoingMessage::from_message(msg).unwrap();
2050                    tracing::info!(
2051                        msg = ?MessageHeader::read_from_prefix(msg.data()),
2052                        "sending message"
2053                    );
2054                    self.sender.send(msg);
2055                    break Poll::Ready(());
2056                }
2057            }
2058        }
2059    }
2060
    /// Event client for tests: accepts event mappings but cannot signal events.
    struct NoopSynicEvents;
2062
2063    impl SynicEventClient for NoopSynicEvents {
2064        fn map_event(&self, _event_flag: u16, _event: &Event) -> std::io::Result<()> {
2065            Ok(())
2066        }
2067
2068        fn unmap_event(&self, _event_flag: u16) {}
2069
2070        fn signal_event(&self, _connection_id: u32, _event_flag: u16) -> std::io::Result<()> {
2071            Err(std::io::ErrorKind::Unsupported.into())
2072        }
2073    }
2074
    /// Message source for tests: delivers the raw messages `TestServer` sends.
    struct TestMessageSource {
        /// Raw incoming messages from the test server.
        msg_recv: mesh::Receiver<Vec<u8>>,
        /// While set, an empty queue completes reads with zero bytes instead
        /// of pending (see `poll_recv`).
        paused: bool,
    }
2079
2080    impl AsyncRecv for TestMessageSource {
2081        fn poll_recv(
2082            &mut self,
2083            cx: &mut Context<'_>,
2084            mut bufs: &mut [std::io::IoSliceMut<'_>],
2085        ) -> Poll<std::io::Result<usize>> {
2086            let value = match self.msg_recv.poll_recv(cx) {
2087                Poll::Ready(v) => v.unwrap(),
2088                Poll::Pending => {
2089                    if self.paused {
2090                        return Poll::Ready(Ok(0));
2091                    } else {
2092                        return Poll::Pending;
2093                    }
2094                }
2095            };
2096            let mut remaining = value.as_slice();
2097            let mut total_size = 0;
2098            while !remaining.is_empty() && !bufs.is_empty() {
2099                let size = bufs[0].len().min(remaining.len());
2100                bufs[0][..size].copy_from_slice(&remaining[..size]);
2101                remaining = &remaining[size..];
2102                bufs = &mut bufs[1..];
2103                total_size += size;
2104            }
2105
2106            Ok(total_size).into()
2107        }
2108    }
2109
2110    impl VmbusMessageSource for TestMessageSource {
2111        fn pause_message_stream(&mut self) {
2112            self.paused = true;
2113        }
2114
2115        fn resume_message_stream(&mut self) {
2116            self.paused = false;
2117        }
2118    }
2119
2120    fn test_init(driver: &DefaultDriver) -> (TestServer, VmbusClient) {
2121        let (msg_send, msg_recv) = mesh::channel();
2122        let (synic_send, synic_recv) = mesh::channel();
2123        let server = TestServer {
2124            messages: synic_recv,
2125            send: msg_send,
2126        };
2127        let mut client = VmbusClientBuilder::new(
2128            NoopSynicEvents,
2129            TestMessageSource {
2130                msg_recv,
2131                paused: false,
2132            },
2133            TestServerClient {
2134                sender: synic_send,
2135                deadline: None,
2136                timer: PolledTimer::new(driver),
2137            },
2138        )
2139        .build(driver);
2140        client.start();
2141        (server, client)
2142    }
2143
2144    #[async_test]
2145    async fn test_initiate_contact_success(driver: DefaultDriver) {
2146        let (mut server, client) = test_init(&driver);
2147        let _recv = client
2148            .access
2149            .client_request_send
2150            .call(ClientRequest::Connect, ConnectRequest::default());
2151        check_message(
2152            server.next().await.unwrap(),
2153            protocol::InitiateContact2 {
2154                initiate_contact: protocol::InitiateContact {
2155                    version_requested: Version::Copper as u32,
2156                    target_message_vp: 0,
2157                    interrupt_page_or_target_info: TargetInfo::new()
2158                        .with_sint(2)
2159                        .with_vtl(0)
2160                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2161                        .into(),
2162                    parent_to_child_monitor_page_gpa: 0,
2163                    child_to_parent_monitor_page_gpa: 0,
2164                },
2165                ..FromZeros::new_zeroed()
2166            },
2167        );
2168    }
2169
2170    #[async_test]
2171    async fn test_connect_success(driver: DefaultDriver) {
2172        let (mut server, mut client) = test_init(&driver);
2173        let client_connect = client.connect(0, None, Guid::ZERO);
2174
2175        let server_connect = async {
2176            check_message(
2177                server.next().await.unwrap(),
2178                protocol::InitiateContact2 {
2179                    initiate_contact: protocol::InitiateContact {
2180                        version_requested: Version::Copper as u32,
2181                        target_message_vp: 0,
2182                        interrupt_page_or_target_info: TargetInfo::new()
2183                            .with_sint(2)
2184                            .with_vtl(0)
2185                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2186                            .into(),
2187                        parent_to_child_monitor_page_gpa: 0,
2188                        child_to_parent_monitor_page_gpa: 0,
2189                    },
2190                    ..FromZeros::new_zeroed()
2191                },
2192            );
2193
2194            server.send(in_msg(
2195                MessageType::VERSION_RESPONSE,
2196                protocol::VersionResponse2 {
2197                    version_response: protocol::VersionResponse {
2198                        version_supported: 1,
2199                        connection_state: ConnectionState::SUCCESSFUL,
2200                        padding: 0,
2201                        selected_version_or_connection_id: 0,
2202                    },
2203                    supported_features: SUPPORTED_FEATURE_FLAGS.into_bits(),
2204                },
2205            ));
2206
2207            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2208            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2209        };
2210
2211        let (connection, ()) = (client_connect, server_connect).join().await;
2212        let connection = connection.unwrap();
2213
2214        assert_eq!(connection.version.version, Version::Copper);
2215        assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
2216    }
2217
2218    #[async_test]
2219    async fn test_feature_flags(driver: DefaultDriver) {
2220        let (mut server, mut client) = test_init(&driver);
2221        let client_connect = client.connect(0, None, Guid::ZERO);
2222
2223        let server_connect = async {
2224            check_message(
2225                server.next().await.unwrap(),
2226                protocol::InitiateContact2 {
2227                    initiate_contact: protocol::InitiateContact {
2228                        version_requested: Version::Copper as u32,
2229                        target_message_vp: 0,
2230                        interrupt_page_or_target_info: TargetInfo::new()
2231                            .with_sint(2)
2232                            .with_vtl(0)
2233                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2234                            .into(),
2235                        parent_to_child_monitor_page_gpa: 0,
2236                        child_to_parent_monitor_page_gpa: 0,
2237                    },
2238                    ..FromZeros::new_zeroed()
2239                },
2240            );
2241
2242            // Report the server doesn't support some of the feature flags, and make
2243            // sure this is reflected in the returned version.
2244            server.send(in_msg(
2245                MessageType::VERSION_RESPONSE,
2246                protocol::VersionResponse2 {
2247                    version_response: protocol::VersionResponse {
2248                        version_supported: 1,
2249                        connection_state: ConnectionState::SUCCESSFUL,
2250                        padding: 0,
2251                        selected_version_or_connection_id: 0,
2252                    },
2253                    supported_features: 2,
2254                },
2255            ));
2256
2257            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2258            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2259        };
2260
2261        let (connection, ()) = (client_connect, server_connect).join().await;
2262        let connection = connection.unwrap();
2263
2264        assert_eq!(connection.version.version, Version::Copper);
2265        assert_eq!(
2266            connection.version.feature_flags,
2267            FeatureFlags::new().with_channel_interrupt_redirection(true)
2268        );
2269    }
2270
2271    #[async_test]
2272    async fn test_client_id(driver: DefaultDriver) {
2273        let (mut server, client) = test_init(&driver);
2274        let initiate_contact = ConnectRequest {
2275            client_id: VMBUS_TEST_CLIENT_ID,
2276            ..Default::default()
2277        };
2278        let _recv = client
2279            .access
2280            .client_request_send
2281            .call(ClientRequest::Connect, initiate_contact);
2282
2283        check_message(
2284            server.next().await.unwrap(),
2285            protocol::InitiateContact2 {
2286                initiate_contact: protocol::InitiateContact {
2287                    version_requested: Version::Copper as u32,
2288                    target_message_vp: 0,
2289                    interrupt_page_or_target_info: TargetInfo::new()
2290                        .with_sint(2)
2291                        .with_vtl(0)
2292                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2293                        .into(),
2294                    parent_to_child_monitor_page_gpa: 0,
2295                    child_to_parent_monitor_page_gpa: 0,
2296                },
2297                client_id: VMBUS_TEST_CLIENT_ID,
2298            },
2299        );
2300    }
2301
2302    #[async_test]
2303    async fn test_version_negotiation(driver: DefaultDriver) {
2304        let (mut server, mut client) = test_init(&driver);
2305        let client_connect = client.connect(0, None, Guid::ZERO);
2306
2307        let server_connect = async {
2308            check_message(
2309                server.next().await.unwrap(),
2310                protocol::InitiateContact2 {
2311                    initiate_contact: protocol::InitiateContact {
2312                        version_requested: Version::Copper as u32,
2313                        target_message_vp: 0,
2314                        interrupt_page_or_target_info: TargetInfo::new()
2315                            .with_sint(2)
2316                            .with_vtl(0)
2317                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2318                            .into(),
2319                        parent_to_child_monitor_page_gpa: 0,
2320                        child_to_parent_monitor_page_gpa: 0,
2321                    },
2322                    ..FromZeros::new_zeroed()
2323                },
2324            );
2325
2326            server.send(in_msg(
2327                MessageType::VERSION_RESPONSE,
2328                protocol::VersionResponse {
2329                    version_supported: 0,
2330                    connection_state: ConnectionState::SUCCESSFUL,
2331                    padding: 0,
2332                    selected_version_or_connection_id: 0,
2333                },
2334            ));
2335
2336            check_message(
2337                server.next().await.unwrap(),
2338                protocol::InitiateContact {
2339                    version_requested: Version::Iron as u32,
2340                    target_message_vp: 0,
2341                    interrupt_page_or_target_info: TargetInfo::new()
2342                        .with_sint(2)
2343                        .with_vtl(0)
2344                        .with_feature_flags(FeatureFlags::new().into())
2345                        .into(),
2346                    parent_to_child_monitor_page_gpa: 0,
2347                    child_to_parent_monitor_page_gpa: 0,
2348                },
2349            );
2350
2351            server.send(in_msg(
2352                MessageType::VERSION_RESPONSE,
2353                protocol::VersionResponse {
2354                    version_supported: 1,
2355                    connection_state: ConnectionState::SUCCESSFUL,
2356                    padding: 0,
2357                    selected_version_or_connection_id: 0,
2358                },
2359            ));
2360
2361            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2362            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2363        };
2364
2365        let (connection, ()) = (client_connect, server_connect).join().await;
2366        let connection = connection.unwrap();
2367
2368        assert_eq!(connection.version.version, Version::Iron);
2369        assert_eq!(connection.version.feature_flags, FeatureFlags::new());
2370    }
2371
2372    #[async_test]
2373    async fn test_open_channel_success(driver: DefaultDriver) {
2374        let (mut server, mut client) = test_init(&driver);
2375        let channel = server.get_channel(&mut client).await;
2376
2377        let recv = channel.request_send.call(
2378            ChannelRequest::Open,
2379            OpenRequest {
2380                open_data: OpenData {
2381                    target_vp: 0,
2382                    ring_offset: 0,
2383                    ring_gpadl_id: GpadlId(0),
2384                    event_flag: 0,
2385                    connection_id: 0,
2386                    user_data: UserDefinedData::new_zeroed(),
2387                },
2388                incoming_event: None,
2389                use_vtl2_connection_id: false,
2390            },
2391        );
2392
2393        check_message(
2394            server.next().await.unwrap(),
2395            protocol::OpenChannel2 {
2396                open_channel: protocol::OpenChannel {
2397                    channel_id: ChannelId(0),
2398                    open_id: 0,
2399                    ring_buffer_gpadl_id: GpadlId(0),
2400                    target_vp: 0,
2401                    downstream_ring_buffer_page_offset: 0,
2402                    user_data: UserDefinedData::new_zeroed(),
2403                },
2404                connection_id: 0,
2405                event_flag: 0,
2406                flags: Default::default(),
2407            },
2408        );
2409
2410        server.send(in_msg(
2411            MessageType::OPEN_CHANNEL_RESULT,
2412            protocol::OpenResult {
2413                channel_id: ChannelId(0),
2414                open_id: 0,
2415                status: protocol::STATUS_SUCCESS as u32,
2416            },
2417        ));
2418
2419        recv.await.unwrap().unwrap();
2420    }
2421
2422    #[async_test]
2423    async fn test_open_channel_fail(driver: DefaultDriver) {
2424        let (mut server, mut client) = test_init(&driver);
2425        let channel = server.get_channel(&mut client).await;
2426
2427        let recv = channel.request_send.call(
2428            ChannelRequest::Open,
2429            OpenRequest {
2430                open_data: OpenData {
2431                    target_vp: 0,
2432                    ring_offset: 0,
2433                    ring_gpadl_id: GpadlId(0),
2434                    event_flag: 0,
2435                    connection_id: 0,
2436                    user_data: UserDefinedData::new_zeroed(),
2437                },
2438                incoming_event: None,
2439                use_vtl2_connection_id: false,
2440            },
2441        );
2442
2443        check_message(
2444            server.next().await.unwrap(),
2445            protocol::OpenChannel2 {
2446                open_channel: protocol::OpenChannel {
2447                    channel_id: ChannelId(0),
2448                    open_id: 0,
2449                    ring_buffer_gpadl_id: GpadlId(0),
2450                    target_vp: 0,
2451                    downstream_ring_buffer_page_offset: 0,
2452                    user_data: UserDefinedData::new_zeroed(),
2453                },
2454                connection_id: 0,
2455                event_flag: 0,
2456                flags: Default::default(),
2457            },
2458        );
2459
2460        server.send(in_msg(
2461            MessageType::OPEN_CHANNEL_RESULT,
2462            protocol::OpenResult {
2463                channel_id: ChannelId(0),
2464                open_id: 0,
2465                status: protocol::STATUS_UNSUCCESSFUL as u32,
2466            },
2467        ));
2468
2469        recv.await.unwrap().unwrap_err();
2470    }
2471
2472    #[async_test]
2473    async fn test_modify_channel(driver: DefaultDriver) {
2474        let (mut server, mut client) = test_init(&driver);
2475        let channel = server.get_channel(&mut client).await;
2476
2477        // N.B. A real server requires the channel to be open before sending this, but the test
2478        //      server doesn't care.
2479        let recv = channel.request_send.call(
2480            ChannelRequest::Modify,
2481            ModifyRequest::TargetVp { target_vp: 1 },
2482        );
2483
2484        check_message(
2485            server.next().await.unwrap(),
2486            protocol::ModifyChannel {
2487                channel_id: ChannelId(0),
2488                target_vp: 1,
2489            },
2490        );
2491
2492        server.send(in_msg(
2493            MessageType::MODIFY_CHANNEL_RESPONSE,
2494            protocol::ModifyChannelResponse {
2495                channel_id: ChannelId(0),
2496                status: protocol::STATUS_SUCCESS,
2497            },
2498        ));
2499
2500        let status = recv.await.unwrap();
2501        assert_eq!(status, protocol::STATUS_SUCCESS);
2502    }
2503
2504    #[async_test]
2505    async fn test_save_restore_connected(driver: DefaultDriver) {
2506        let (mut server, mut client) = test_init(&driver);
2507        server.connect(&mut client).await;
2508        server.stop_client(&mut client).await;
2509        let s0 = client.save().await;
2510        let builder = client.sever().await;
2511        let mut client = builder.build(&driver);
2512        client.restore(s0.clone()).await.unwrap();
2513
2514        let s1 = client.save().await;
2515
2516        assert_eq!(s0, s1);
2517    }
2518
2519    #[async_test]
2520    async fn test_save_restore_connected_with_channel(driver: DefaultDriver) {
2521        let (mut server, mut client) = test_init(&driver);
2522        let c0 = server.get_channel(&mut client).await;
2523        server.stop_client(&mut client).await;
2524        let s0 = client.save().await;
2525        let builder = client.sever().await;
2526        let mut client = builder.build(&driver);
2527        let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2528        let s1 = client.save().await;
2529        assert_eq!(s0, s1);
2530        assert_eq!(connection.offers[0].offer, c0.offer);
2531    }
2532
    // Save/restore with a channel that was rescinded while a Modify request on
    // it was still outstanding. After restore the revoked channel is not
    // reported as an offer, and once the client is started again the test
    // expects it to send RelIdReleased for that channel.
    #[async_test]
    async fn test_save_restore_connected_with_revoked_channel(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let c0 = server.get_channel(&mut client).await;
        // Rescind the only channel and wait for the client to report the revoke.
        server.send(in_msg(
            MessageType::RESCIND_CHANNEL_OFFER,
            protocol::RescindChannelOffer {
                channel_id: ChannelId(0),
            },
        ));
        c0.revoke_recv.await.unwrap();
        // Issue a Modify after the revoke; the test expects the client to still
        // forward it to the server.
        let rpc = c0.request_send.call(
            ChannelRequest::Modify,
            ModifyRequest::TargetVp { target_vp: 1 },
        );

        check_message(
            server.next().await.unwrap(),
            protocol::ModifyChannel {
                channel_id: ChannelId(0),
                target_vp: 1,
            },
        );

        // Stop the client while the server side, concurrently, answers the
        // pending Modify and then acknowledges the Pause that stopping sends.
        let client_stop = client.stop();
        let server_stop = async {
            server.send(in_msg(
                MessageType::MODIFY_CHANNEL_RESPONSE,
                protocol::ModifyChannelResponse {
                    channel_id: ChannelId(0),
                    status: protocol::STATUS_SUCCESS,
                },
            ));
            check_message(server.next().await.unwrap(), protocol::Pause);
            server.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
        };
        (client_stop, server_stop).join().await;

        // The outstanding Modify completed before the client was saved.
        rpc.await.unwrap();

        // Round-trip the saved state through a severed and rebuilt client.
        let s0 = client.save().await;
        let builder = client.sever().await;
        let mut client = builder.build(&driver);
        let connection = client.restore(s0.clone()).await.unwrap().unwrap();
        let s1 = client.save().await;
        assert_eq!(s0, s1);
        // The revoked channel must not be re-offered after restore.
        assert!(connection.offers.is_empty());
        // After the client is started again it releases the revoked relid.
        server.start_client(&mut client).await;
        check_message(
            server.next().await.unwrap(),
            protocol::RelIdReleased {
                channel_id: ChannelId(0),
            },
        );
    }
2588
2589    #[async_test]
2590    async fn test_connect_fails_on_incorrect_state(driver: DefaultDriver) {
2591        let (mut server, mut client) = test_init(&driver);
2592        server.connect(&mut client).await;
2593        let err = client.connect(0, None, Guid::ZERO).await.unwrap_err();
2594        assert!(matches!(err, ConnectError::InvalidState), "{:?}", err);
2595    }
2596
2597    #[async_test]
2598    async fn test_hot_add_remove(driver: DefaultDriver) {
2599        let (mut server, mut client) = test_init(&driver);
2600
2601        let mut connection = server.connect(&mut client).await;
2602        let offer = protocol::OfferChannel {
2603            interface_id: Guid::new_random(),
2604            instance_id: Guid::new_random(),
2605            rsvd: [0; 4],
2606            flags: OfferFlags::new(),
2607            mmio_megabytes: 0,
2608            user_defined: UserDefinedData::new_zeroed(),
2609            subchannel_index: 0,
2610            mmio_megabytes_optional: 0,
2611            channel_id: ChannelId(5),
2612            monitor_id: 0,
2613            monitor_allocated: 0,
2614            is_dedicated: 0,
2615            connection_id: 0,
2616        };
2617
2618        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2619        let info = connection.offer_recv.next().await.unwrap();
2620
2621        assert_eq!(offer, info.offer);
2622
2623        server.send(in_msg(
2624            MessageType::RESCIND_CHANNEL_OFFER,
2625            protocol::RescindChannelOffer {
2626                channel_id: ChannelId(5),
2627            },
2628        ));
2629
2630        info.revoke_recv.await.unwrap();
2631        drop(info.request_send);
2632
2633        check_message(
2634            server.next().await.unwrap(),
2635            protocol::RelIdReleased {
2636                channel_id: ChannelId(5),
2637            },
2638        );
2639    }
2640
2641    #[async_test]
2642    async fn test_gpadl_success(driver: DefaultDriver) {
2643        let (mut server, mut client) = test_init(&driver);
2644        let channel = server.get_channel(&mut client).await;
2645        let recv = channel.request_send.call(
2646            ChannelRequest::Gpadl,
2647            GpadlRequest {
2648                id: GpadlId(1),
2649                count: 1,
2650                buf: vec![5],
2651            },
2652        );
2653
2654        check_message_with_data(
2655            server.next().await.unwrap(),
2656            protocol::GpadlHeader {
2657                channel_id: ChannelId(0),
2658                gpadl_id: GpadlId(1),
2659                len: 8,
2660                count: 1,
2661            },
2662            0x5u64.as_bytes(),
2663        );
2664
2665        server.send(in_msg(
2666            MessageType::GPADL_CREATED,
2667            protocol::GpadlCreated {
2668                channel_id: ChannelId(0),
2669                gpadl_id: GpadlId(1),
2670                status: protocol::STATUS_SUCCESS,
2671            },
2672        ));
2673
2674        recv.await.unwrap().unwrap();
2675
2676        let rpc = channel
2677            .request_send
2678            .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2679
2680        check_message(
2681            server.next().await.unwrap(),
2682            protocol::GpadlTeardown {
2683                channel_id: ChannelId(0),
2684                gpadl_id: GpadlId(1),
2685            },
2686        );
2687
2688        server.send(in_msg(
2689            MessageType::GPADL_TORNDOWN,
2690            protocol::GpadlTorndown {
2691                gpadl_id: GpadlId(1),
2692            },
2693        ));
2694
2695        rpc.await.unwrap();
2696    }
2697
2698    #[async_test]
2699    async fn test_gpadl_fail(driver: DefaultDriver) {
2700        let (mut server, mut client) = test_init(&driver);
2701        let channel = server.get_channel(&mut client).await;
2702        let recv = channel.request_send.call(
2703            ChannelRequest::Gpadl,
2704            GpadlRequest {
2705                id: GpadlId(1),
2706                count: 1,
2707                buf: vec![7],
2708            },
2709        );
2710
2711        check_message_with_data(
2712            server.next().await.unwrap(),
2713            protocol::GpadlHeader {
2714                channel_id: ChannelId(0),
2715                gpadl_id: GpadlId(1),
2716                len: 8,
2717                count: 1,
2718            },
2719            0x7u64.as_bytes(),
2720        );
2721
2722        server.send(in_msg(
2723            MessageType::GPADL_CREATED,
2724            protocol::GpadlCreated {
2725                channel_id: ChannelId(0),
2726                gpadl_id: GpadlId(1),
2727                status: protocol::STATUS_UNSUCCESSFUL,
2728            },
2729        ));
2730
2731        recv.await.unwrap().unwrap_err();
2732    }
2733
    // Exercises GPADL handling across a channel rescind. GPADLs 1-3 are created
    // first. A teardown of GPADL 1 is pending when the server rescinds the
    // channel. A GPADL created after the rescind is failed by the server. GPADL
    // 2 is torn down after the revoke is observed. The test then expects
    // RelIdReleased once the request sender is dropped and that teardown
    // completes. GPADL 3 is never torn down explicitly.
    #[async_test]
    async fn test_gpadl_with_revoke(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let channel = server.get_channel(&mut client).await;
        let channel_id = ChannelId(0);
        // Successfully create GPADLs 1, 2 and 3 on the channel.
        for gpadl_id in [1, 2, 3].map(GpadlId) {
            let recv = channel.request_send.call(
                ChannelRequest::Gpadl,
                GpadlRequest {
                    id: gpadl_id,
                    count: 1,
                    buf: vec![3],
                },
            );

            check_message_with_data(
                server.next().await.unwrap(),
                protocol::GpadlHeader {
                    channel_id,
                    gpadl_id,
                    len: 8,
                    count: 1,
                },
                0x3u64.as_bytes(),
            );

            server.send(in_msg(
                MessageType::GPADL_CREATED,
                protocol::GpadlCreated {
                    channel_id,
                    gpadl_id,
                    status: protocol::STATUS_SUCCESS,
                },
            ));

            recv.await.unwrap().unwrap();
        }

        // Start tearing down GPADL 1; the server does not answer until later.
        let rpc = channel
            .request_send
            .call(ChannelRequest::TeardownGpadl, GpadlId(1));

        check_message(
            server.next().await.unwrap(),
            protocol::GpadlTeardown {
                channel_id,
                gpadl_id: GpadlId(1),
            },
        );

        // Rescind the channel while the GPADL 1 teardown is still pending.
        server.send(in_msg(
            MessageType::RESCIND_CHANNEL_OFFER,
            protocol::RescindChannelOffer { channel_id },
        ));

        // Request another GPADL after the rescind. The test expects the client
        // to still forward the header, and the server then fails it.
        let recv = channel.request_send.call_failable(
            ChannelRequest::Gpadl,
            GpadlRequest {
                id: GpadlId(4),
                count: 1,
                buf: vec![3],
            },
        );

        check_message_with_data(
            server.next().await.unwrap(),
            protocol::GpadlHeader {
                channel_id,
                gpadl_id: GpadlId(4),
                len: 8,
                count: 1,
            },
            0x3u64.as_bytes(),
        );

        server.send(in_msg(
            MessageType::GPADL_CREATED,
            protocol::GpadlCreated {
                channel_id,
                gpadl_id: GpadlId(4),
                status: protocol::STATUS_UNSUCCESSFUL,
            },
        ));

        // Complete the pending GPADL 1 teardown.
        server.send(in_msg(
            MessageType::GPADL_TORNDOWN,
            protocol::GpadlTorndown {
                gpadl_id: GpadlId(1),
            },
        ));

        // Teardown of GPADL 1 succeeds; creation of GPADL 4 fails.
        rpc.await.unwrap();
        recv.await.unwrap_err();

        channel.revoke_recv.await.unwrap();

        // After the revoke is observed, GPADL 2 can still be torn down. The
        // request sender is dropped right after issuing the request.
        let rpc = channel
            .request_send
            .call(ChannelRequest::TeardownGpadl, GpadlId(2));
        drop(channel.request_send);

        check_message(
            server.next().await.unwrap(),
            protocol::GpadlTeardown {
                channel_id,
                gpadl_id: GpadlId(2),
            },
        );

        server.send(in_msg(
            MessageType::GPADL_TORNDOWN,
            protocol::GpadlTorndown {
                gpadl_id: GpadlId(2),
            },
        ));

        rpc.await.unwrap();

        // With the sender dropped and the teardown done, the relid is released.
        check_message(
            server.next().await.unwrap(),
            protocol::RelIdReleased { channel_id },
        );
    }
2857
2858    #[async_test]
2859    async fn test_modify_connection(driver: DefaultDriver) {
2860        let (mut server, mut client) = test_init(&driver);
2861        server.connect(&mut client).await;
2862        let call = client.access.client_request_send.call(
2863            ClientRequest::Modify,
2864            ModifyConnectionRequest {
2865                monitor_page: Some(MonitorPageGpas {
2866                    child_to_parent: 5,
2867                    parent_to_child: 6,
2868                }),
2869            },
2870        );
2871
2872        check_message(
2873            server.next().await.unwrap(),
2874            protocol::ModifyConnection {
2875                child_to_parent_monitor_page_gpa: 5,
2876                parent_to_child_monitor_page_gpa: 6,
2877            },
2878        );
2879
2880        server.send(in_msg(
2881            MessageType::MODIFY_CONNECTION_RESPONSE,
2882            protocol::ModifyConnectionResponse {
2883                connection_state: ConnectionState::FAILED_LOW_RESOURCES,
2884            },
2885        ));
2886
2887        let result = call.await.unwrap();
2888        assert_eq!(ConnectionState::FAILED_LOW_RESOURCES, result);
2889    }
2890
2891    #[async_test]
2892    async fn test_hvsock(driver: DefaultDriver) {
2893        let (mut server, mut client) = test_init(&driver);
2894        server.connect(&mut client).await;
2895        let request = HvsockConnectRequest {
2896            service_id: Guid::new_random(),
2897            endpoint_id: Guid::new_random(),
2898            silo_id: Guid::new_random(),
2899            hosted_silo_unaware: false,
2900        };
2901
2902        let resp = client.access().connect_hvsock(request);
2903        check_message(
2904            server.next().await.unwrap(),
2905            protocol::TlConnectRequest2 {
2906                base: protocol::TlConnectRequest {
2907                    service_id: request.service_id,
2908                    endpoint_id: request.endpoint_id,
2909                },
2910                silo_id: request.silo_id,
2911            },
2912        );
2913
2914        // Now send a failure result.
2915        server.send(in_msg(
2916            MessageType::TL_CONNECT_REQUEST_RESULT,
2917            protocol::TlConnectResult {
2918                service_id: request.service_id,
2919                endpoint_id: request.endpoint_id,
2920                status: protocol::STATUS_CONNECTION_REFUSED,
2921            },
2922        ));
2923
2924        let result = resp.await;
2925        assert!(result.is_none());
2926    }
2927
2928    #[async_test]
2929    async fn test_synic_event_flags(driver: DefaultDriver) {
2930        let (mut server, mut client) = test_init(&driver);
2931        let connection = server.get_channels(&mut client, 5).await;
2932        let event = Event::new();
2933
2934        for _ in 0..5 {
2935            for (i, channel) in connection.offers.iter().enumerate() {
2936                let recv = channel.request_send.call(
2937                    ChannelRequest::Open,
2938                    OpenRequest {
2939                        open_data: OpenData {
2940                            target_vp: 0,
2941                            ring_offset: 0,
2942                            ring_gpadl_id: GpadlId(0),
2943                            event_flag: 0,
2944                            connection_id: 0,
2945                            user_data: UserDefinedData::new_zeroed(),
2946                        },
2947                        incoming_event: Some(event.clone()),
2948                        use_vtl2_connection_id: false,
2949                    },
2950                );
2951
2952                let expected_event_flag = i as u16 + 1;
2953
2954                check_message(
2955                    server.next().await.unwrap(),
2956                    protocol::OpenChannel2 {
2957                        open_channel: protocol::OpenChannel {
2958                            channel_id: channel.offer.channel_id,
2959                            open_id: 0,
2960                            ring_buffer_gpadl_id: GpadlId(0),
2961                            target_vp: 0,
2962                            downstream_ring_buffer_page_offset: 0,
2963                            user_data: UserDefinedData::new_zeroed(),
2964                        },
2965                        connection_id: 0,
2966                        event_flag: expected_event_flag,
2967                        flags: OpenChannelFlags::new().with_redirect_interrupt(true),
2968                    },
2969                );
2970
2971                server.send(in_msg(
2972                    MessageType::OPEN_CHANNEL_RESULT,
2973                    protocol::OpenResult {
2974                        channel_id: channel.offer.channel_id,
2975                        open_id: 0,
2976                        status: protocol::STATUS_SUCCESS as u32,
2977                    },
2978                ));
2979
2980                let output = recv.await.unwrap().unwrap();
2981                assert_eq!(output.redirected_event_flag, Some(expected_event_flag));
2982            }
2983
2984            for (i, channel) in connection.offers.iter().enumerate() {
2985                // Close the channel to prepare for the next iteration of the loop.
2986                // The event flag should be the same each time.
2987                channel
2988                    .request_send
2989                    .call(ChannelRequest::Close, ())
2990                    .await
2991                    .unwrap();
2992
2993                check_message(
2994                    server.next().await.unwrap(),
2995                    protocol::CloseChannel {
2996                        channel_id: ChannelId(i as u32),
2997                    },
2998                );
2999            }
3000        }
3001    }
3002
3003    #[async_test]
3004    async fn test_revoke(driver: DefaultDriver) {
3005        let (mut server, mut client) = test_init(&driver);
3006        let channel = server.get_channel(&mut client).await;
3007
3008        server.send(in_msg(
3009            MessageType::RESCIND_CHANNEL_OFFER,
3010            protocol::RescindChannelOffer {
3011                channel_id: ChannelId(0),
3012            },
3013        ));
3014
3015        channel.revoke_recv.await.unwrap();
3016
3017        channel
3018            .request_send
3019            .call_failable(
3020                ChannelRequest::Open,
3021                OpenRequest {
3022                    open_data: OpenData {
3023                        target_vp: 0,
3024                        ring_offset: 0,
3025                        ring_gpadl_id: GpadlId(0),
3026                        event_flag: 0,
3027                        connection_id: 0,
3028                        user_data: UserDefinedData::new_zeroed(),
3029                    },
3030                    incoming_event: None,
3031                    use_vtl2_connection_id: false,
3032                },
3033            )
3034            .await
3035            .unwrap_err();
3036    }
3037
3038    #[async_test]
3039    #[should_panic(expected = "channel should not exist")]
3040    async fn test_reoffer_in_use_rel_id(driver: DefaultDriver) {
3041        let (mut server, mut client) = test_init(&driver);
3042        let mut connection = server.get_channels(&mut client, 1).await;
3043        let [channel] = connection.offers.try_into().unwrap();
3044
3045        server.send(in_msg(
3046            MessageType::RESCIND_CHANNEL_OFFER,
3047            protocol::RescindChannelOffer {
3048                channel_id: ChannelId(0),
3049            },
3050        ));
3051
3052        channel.revoke_recv.await.unwrap();
3053
3054        // This offer will cause a panic since the rel id is still in use.
3055        let offer = protocol::OfferChannel {
3056            interface_id: Guid::new_random(),
3057            instance_id: Guid::new_random(),
3058            rsvd: [0; 4],
3059            flags: OfferFlags::new(),
3060            mmio_megabytes: 0,
3061            user_defined: UserDefinedData::new_zeroed(),
3062            subchannel_index: 0,
3063            mmio_megabytes_optional: 0,
3064            channel_id: ChannelId(0),
3065            monitor_id: 0,
3066            monitor_allocated: 0,
3067            is_dedicated: 0,
3068            connection_id: 0,
3069        };
3070
3071        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3072
3073        connection.offer_recv.next().await;
3074    }
3075
3076    #[async_test]
3077    async fn test_revoke_release_and_reoffer(driver: DefaultDriver) {
3078        let (mut server, mut client) = test_init(&driver);
3079        let mut connection = server.get_channels(&mut client, 1).await;
3080        let [channel] = connection.offers.try_into().unwrap();
3081
3082        server.send(in_msg(
3083            MessageType::RESCIND_CHANNEL_OFFER,
3084            protocol::RescindChannelOffer {
3085                channel_id: ChannelId(0),
3086            },
3087        ));
3088
3089        channel.revoke_recv.await.unwrap();
3090        drop(channel.request_send);
3091
3092        check_message(
3093            server.next().await.unwrap(),
3094            protocol::RelIdReleased {
3095                channel_id: ChannelId(0),
3096            },
3097        );
3098
3099        let offer = protocol::OfferChannel {
3100            interface_id: Guid::new_random(),
3101            instance_id: Guid::new_random(),
3102            rsvd: [0; 4],
3103            flags: OfferFlags::new(),
3104            mmio_megabytes: 0,
3105            user_defined: UserDefinedData::new_zeroed(),
3106            subchannel_index: 0,
3107            mmio_megabytes_optional: 0,
3108            channel_id: ChannelId(0),
3109            monitor_id: 0,
3110            monitor_allocated: 0,
3111            is_dedicated: 0,
3112            connection_id: 0,
3113        };
3114
3115        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3116
3117        connection.offer_recv.next().await.unwrap();
3118    }
3119
3120    #[async_test]
3121    async fn test_release_revoke_and_reoffer(driver: DefaultDriver) {
3122        let (mut server, mut client) = test_init(&driver);
3123        let mut connection = server.get_channels(&mut client, 1).await;
3124        let [channel] = connection.offers.try_into().unwrap();
3125
3126        let open = channel.request_send.call_failable(
3127            ChannelRequest::Open,
3128            OpenRequest {
3129                open_data: OpenData {
3130                    target_vp: 0,
3131                    ring_offset: 0,
3132                    ring_gpadl_id: GpadlId(0),
3133                    event_flag: 0,
3134                    connection_id: 0,
3135                    user_data: UserDefinedData::new_zeroed(),
3136                },
3137                incoming_event: None,
3138                use_vtl2_connection_id: false,
3139            },
3140        );
3141
3142        let server_open = async {
3143            check_message(
3144                server.next().await.unwrap(),
3145                protocol::OpenChannel2 {
3146                    open_channel: protocol::OpenChannel {
3147                        channel_id: ChannelId(0),
3148                        open_id: 0,
3149                        ring_buffer_gpadl_id: GpadlId(0),
3150                        target_vp: 0,
3151                        downstream_ring_buffer_page_offset: 0,
3152                        user_data: UserDefinedData::new_zeroed(),
3153                    },
3154                    connection_id: 0,
3155                    event_flag: 0,
3156                    flags: Default::default(),
3157                },
3158            );
3159            server.send(in_msg(
3160                MessageType::OPEN_CHANNEL_RESULT,
3161                protocol::OpenResult {
3162                    channel_id: ChannelId(0),
3163                    open_id: 0,
3164                    status: protocol::STATUS_SUCCESS as u32,
3165                },
3166            ));
3167        };
3168
3169        (open, server_open).join().await.0.unwrap();
3170
3171        // This will close the channel but won't release it yet.
3172        drop(channel);
3173
3174        check_message(
3175            server.next().await.unwrap(),
3176            protocol::CloseChannel {
3177                channel_id: ChannelId(0),
3178            },
3179        );
3180
3181        server.send(in_msg(
3182            MessageType::RESCIND_CHANNEL_OFFER,
3183            protocol::RescindChannelOffer {
3184                channel_id: ChannelId(0),
3185            },
3186        ));
3187
3188        // Should be released.
3189        check_message(
3190            server.next().await.unwrap(),
3191            protocol::RelIdReleased {
3192                channel_id: ChannelId(0),
3193            },
3194        );
3195
3196        let offer = protocol::OfferChannel {
3197            interface_id: Guid::new_random(),
3198            instance_id: Guid::new_random(),
3199            rsvd: [0; 4],
3200            flags: OfferFlags::new(),
3201            mmio_megabytes: 0,
3202            user_defined: UserDefinedData::new_zeroed(),
3203            subchannel_index: 0,
3204            mmio_megabytes_optional: 0,
3205            channel_id: ChannelId(0),
3206            monitor_id: 0,
3207            monitor_allocated: 0,
3208            is_dedicated: 0,
3209            connection_id: 0,
3210        };
3211
3212        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3213
3214        // New offer should come through.
3215        connection.offer_recv.next().await.unwrap();
3216    }
3217}