vmbus_client/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Client driver for the Hyper-V Virtual Machine Bus (VmBus).
5
6#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9pub mod driver;
10pub mod filter;
11mod hvsock;
12pub mod saved_state;
13
14pub use self::saved_state::SavedState;
15use anyhow::Context as _;
16use anyhow::Result;
17use futures::FutureExt;
18use futures::StreamExt;
19use futures::future::OptionFuture;
20use futures::stream::SelectAll;
21use futures_concurrency::future::Race;
22use guid::Guid;
23use inspect::Inspect;
24use mesh::rpc::FailableRpc;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use pal_async::task::Spawn;
28use pal_async::task::Task;
29use pal_event::Event;
30use std::collections::HashMap;
31use std::collections::VecDeque;
32use std::collections::hash_map;
33use std::convert::TryInto;
34use std::future::Future;
35use std::future::poll_fn;
36use std::ops::Deref;
37use std::ops::DerefMut;
38use std::pin::pin;
39use std::sync::Arc;
40use std::sync::atomic::AtomicU32;
41use std::sync::atomic::Ordering;
42use std::task::Context;
43use std::task::Poll;
44use thiserror::Error;
45use vmbus_async::async_dgram::AsyncRecv;
46use vmbus_async::async_dgram::AsyncRecvExt;
47use vmbus_channel::bus::GpadlRequest;
48use vmbus_channel::bus::ModifyRequest;
49use vmbus_channel::bus::OpenData;
50use vmbus_channel::gpadl::GpadlId;
51use vmbus_core::HvsockConnectRequest;
52use vmbus_core::OutgoingMessage;
53use vmbus_core::TaggedStream;
54use vmbus_core::VersionInfo;
55use vmbus_core::protocol;
56use vmbus_core::protocol::ChannelId;
57use vmbus_core::protocol::ConnectionState;
58use vmbus_core::protocol::FeatureFlags;
59use vmbus_core::protocol::Message;
60use vmbus_core::protocol::OpenChannelFlags;
61use vmbus_core::protocol::Version;
62use vmcore::interrupt::Interrupt;
63use vmcore::synic::MonitorPageGpas;
64use zerocopy::Immutable;
65use zerocopy::IntoBytes;
66use zerocopy::KnownLayout;
67
68const SINT: u8 = 2;
69const VTL: u8 = 0;
70const SUPPORTED_VERSIONS: &[Version] = &[Version::Iron, Version::Copper];
71const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
72    .with_guest_specified_signal_parameters(true)
73    .with_channel_interrupt_redirection(true)
74    .with_modify_connection(true)
75    .with_client_id(true)
76    .with_pause_resume(true);
77
78/// The client interface synic events.
79pub trait SynicEventClient: Send + Sync {
80    /// Maps an incoming event signal on SINT7 to `event`.
81    fn map_event(&self, event_flag: u16, event: &Event) -> std::io::Result<()>;
82
83    /// Unmaps an event previously mapped with `map_event`.
84    fn unmap_event(&self, event_flag: u16);
85
86    /// Signals an event on the synic.
87    fn signal_event(&self, connection_id: u32, event_flag: u16) -> std::io::Result<()>;
88}
89
90/// A stream of vmbus messages that can be paused and resumed.
91pub trait VmbusMessageSource: AsyncRecv + Send {
92    /// Stop accepting new messages from the synic. After this is called, the message source must
93    /// return any pending messages already in the queue, and then return EOF.
94    fn pause_message_stream(&mut self) {}
95
96    /// Resume accepting new messages from the synic.
97    fn resume_message_stream(&mut self) {}
98}
99
/// Posts outgoing vmbus messages to the server.
pub trait PollPostMessage: Send {
    /// Attempts to post a message of type `typ` with payload `msg` on
    /// `connection_id`.
    ///
    /// NOTE(review): presumably follows the usual `poll_*` contract, where
    /// `Poll::Ready(())` means the message was posted and `Poll::Pending`
    /// means retry after `cx` is woken. Confirm against implementations.
    fn poll_post_message(
        &mut self,
        cx: &mut Context<'_>,
        connection_id: u32,
        typ: u32,
        msg: &[u8],
    ) -> Poll<()>;
}
109
/// Handle to the spawned vmbus client task, created by
/// [`VmbusClientBuilder::build`].
pub struct VmbusClient {
    /// Sends control requests (inspect, save, restore, post-restore, start,
    /// stop) to the task.
    task_send: mesh::Sender<TaskRequest>,
    /// Handle used for client-level requests (connect, unload, modify, hvsock).
    access: VmbusClientAccess,
    /// The spawned task; awaiting it yields the task's state back (see `sever`).
    task: Task<ClientTask>,
}
115
/// Errors returned by [`VmbusClient::connect`].
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// A connect was requested while the client was not disconnected.
    #[error("invalid state to connect to the server")]
    InvalidState,
    /// The server rejected every version in `SUPPORTED_VERSIONS`.
    #[error("no supported protocol versions")]
    NoSupportedVersions,
    /// The server accepted the version but reported a non-successful
    /// connection state.
    #[error("failed to connect to the server: {0:?}")]
    FailedToConnect(ConnectionState),
}
125
/// A cloneable handle for sending client-level requests (connect, unload,
/// modify connection, hvsock connect) to the vmbus client task.
#[derive(Clone)]
pub struct VmbusClientAccess {
    client_request_send: mesh::Sender<ClientRequest>,
}
130
/// A builder for creating a [`VmbusClient`].
///
/// Holds the synic event client, the incoming message source, and the
/// outgoing message poster that the spawned client task takes ownership of.
pub struct VmbusClientBuilder {
    event_client: Arc<dyn SynicEventClient>,
    msg_source: Box<dyn VmbusMessageSource>,
    msg_client: Box<dyn PollPostMessage>,
}
137
138impl VmbusClientBuilder {
139    /// Creates a new instance of the builder with the given synic input.
140    pub fn new(
141        event_client: impl SynicEventClient + 'static,
142        msg_source: impl VmbusMessageSource + 'static,
143        msg_client: impl PollPostMessage + 'static,
144    ) -> Self {
145        Self {
146            event_client: Arc::new(event_client),
147            msg_source: Box::new(msg_source),
148            msg_client: Box::new(msg_client),
149        }
150    }
151
152    /// Creates a new instance with a receiver for incoming synic messages.
153    pub fn build(self, spawner: &impl Spawn) -> VmbusClient {
154        let (task_send, task_recv) = mesh::channel();
155        let (client_request_send, client_request_recv) = mesh::channel();
156
157        let inner = ClientTaskInner {
158            messages: OutgoingMessages {
159                poster: self.msg_client,
160                queued: VecDeque::new(),
161                state: OutgoingMessageState::Paused,
162            },
163            teardown_gpadls: HashMap::new(),
164            channel_requests: SelectAll::new(),
165            synic: SynicState {
166                event_flag_state: Vec::new(),
167                event_client: self.event_client,
168            },
169        };
170
171        let mut task = ClientTask {
172            inner,
173            channels: ChannelList::default(),
174            task_recv,
175            running: false,
176            msg_source: self.msg_source,
177            client_request_recv,
178            state: ClientState::Disconnected,
179            modify_request: None,
180            hvsock_tracker: hvsock::HvsockRequestTracker::new(),
181        };
182
183        let task = spawner.spawn("vmbus client", async move {
184            task.run().await;
185            task
186        });
187
188        VmbusClient {
189            access: VmbusClientAccess {
190                client_request_send,
191            },
192            task_send,
193            task,
194        }
195    }
196}
197
198impl VmbusClient {
199    /// Connects to the server, negotiating the protocol version and retrieving
200    /// the initial list of channel offers.
201    pub async fn connect(
202        &mut self,
203        target_message_vp: u32,
204        monitor_page: Option<MonitorPageGpas>,
205        client_id: Guid,
206    ) -> Result<ConnectResult, ConnectError> {
207        let request = ConnectRequest {
208            target_message_vp,
209            monitor_page,
210            client_id,
211        };
212
213        self.access
214            .client_request_send
215            .call(ClientRequest::Connect, request)
216            .await
217            .unwrap()
218    }
219
220    pub async fn unload(self) {
221        self.access
222            .client_request_send
223            .call(ClientRequest::Unload, ())
224            .await
225            .unwrap();
226
227        self.sever().await;
228    }
229
230    pub fn access(&self) -> &VmbusClientAccess {
231        &self.access
232    }
233
234    pub fn start(&mut self) {
235        self.task_send.send(TaskRequest::Start);
236    }
237
238    pub async fn stop(&mut self) {
239        self.task_send
240            .call(TaskRequest::Stop, ())
241            .await
242            .expect("Failed to send stop request");
243    }
244
245    pub async fn save(&self) -> SavedState {
246        self.task_send
247            .call(TaskRequest::Save, ())
248            .await
249            .expect("Failed to send save request")
250    }
251
252    pub async fn restore(
253        &mut self,
254        state: SavedState,
255    ) -> Result<Option<ConnectResult>, RestoreError> {
256        self.task_send
257            .call(TaskRequest::Restore, state)
258            .await
259            .expect("Failed to send restore request")
260    }
261
262    pub async fn post_restore(&mut self) {
263        self.task_send
264            .call(TaskRequest::PostRestore, ())
265            .await
266            .expect("Failed to send post-restore request");
267    }
268
269    async fn sever(self) -> VmbusClientBuilder {
270        drop(self.task_send);
271        let task = self.task.await;
272        VmbusClientBuilder {
273            event_client: task.inner.synic.event_client,
274            msg_source: task.msg_source,
275            msg_client: task.inner.messages.poster,
276        }
277    }
278}
279
280impl Inspect for VmbusClient {
281    fn inspect(&self, req: inspect::Request<'_>) {
282        self.task_send.send(TaskRequest::Inspect(req.defer()));
283    }
284}
285
/// The outcome of a successful [`VmbusClient::connect`].
#[derive(Debug)]
pub struct ConnectResult {
    /// The negotiated protocol version and feature flags.
    pub version: VersionInfo,
    /// Offers received before the server reported that all offers were
    /// delivered.
    pub offers: Vec<OfferInfo>,
    /// Receives offers that arrive later, except those claimed by a pending
    /// hvsock connect request.
    pub offer_recv: mesh::Receiver<OfferInfo>,
}
292
293impl VmbusClientAccess {
294    pub async fn modify(&self, request: ModifyConnectionRequest) -> ConnectionState {
295        self.client_request_send
296            .call(ClientRequest::Modify, request)
297            .await
298            .expect("Failed to send modify request")
299    }
300
301    pub fn connect_hvsock(
302        &self,
303        request: HvsockConnectRequest,
304    ) -> impl Future<Output = Option<OfferInfo>> + use<> {
305        self.client_request_send
306            .call(ClientRequest::HvsockConnect, request)
307            .map(|r| r.ok().flatten())
308    }
309}
310
/// Parameters for a [`ChannelRequest::Open`] request.
#[derive(Debug)]
pub struct OpenRequest {
    /// The open parameters for the channel.
    pub open_data: OpenData,
    /// Event for incoming channel signals, if any (presumably mapped through
    /// `SynicEventClient::map_event`; confirm in the open handler).
    pub incoming_event: Option<Event>,
    /// Requests use of the VTL2 connection ID (per the field name; confirm in
    /// the open handler).
    pub use_vtl2_connection_id: bool,
}
317
/// Parameters for a [`ChannelRequest::Restore`] request.
#[derive(Debug)]
pub struct RestoreRequest {
    /// Event for incoming signals on the restored channel, if any.
    pub incoming_event: Option<Event>,
    // FUTURE: move to saved state, don't rely on the caller.
    /// The redirected event flag the channel used before save, as supplied
    /// by the caller.
    pub redirected_event_flag: Option<u16>,
    // FUTURE: ditto
    /// The connection ID the channel used before save, as supplied by the
    /// caller.
    pub connection_id: u32,
}
326
/// Expresses an operation requested of the client for a single channel.
///
/// Each variant carries the RPC used to report the result back to the
/// requester.
pub enum ChannelRequest {
    /// Open the channel.
    Open(FailableRpc<OpenRequest, OpenOutput>),
    /// Claim a channel that was restored from saved state.
    Restore(FailableRpc<RestoreRequest, OpenOutput>),
    /// Close the channel.
    Close(Rpc<(), ()>),
    /// Create a GPADL for the channel.
    Gpadl(FailableRpc<GpadlRequest, ()>),
    /// Tear down a previously created GPADL.
    TeardownGpadl(Rpc<GpadlId, ()>),
    /// Modify the channel; completes with an `i32` status.
    Modify(Rpc<ModifyRequest, i32>),
}
336
/// Result of a successful open or restore request.
#[derive(Debug)]
pub struct OpenOutput {
    // FUTURE: remove this once it's part of the saved state.
    /// The event flag reserved for interrupt redirection, if any.
    pub redirected_event_flag: Option<u16>,
}
342
343impl std::fmt::Display for ChannelRequest {
344    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345        let s = match self {
346            ChannelRequest::Open(_) => "Open",
347            ChannelRequest::Close(_) => "Close",
348            ChannelRequest::Restore(_) => "Restore",
349            ChannelRequest::Gpadl(_) => "Gpadl",
350            ChannelRequest::TeardownGpadl(_) => "TeardownGpadl",
351            ChannelRequest::Modify(_) => "Modify",
352        };
353        fmt.pad(s)
354    }
355}
356
/// Errors returned by [`VmbusClient::restore`].
///
/// NOTE(review): the restore code is outside this view; the variant docs
/// below follow the error messages.
#[derive(Debug, Error)]
pub enum RestoreError {
    /// The saved protocol version is not supported by this client.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// The saved feature flags are not supported by this client.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    /// A channel ID appears more than once.
    #[error("duplicate channel id {0}")]
    DuplicateChannelId(u32),

    /// A GPADL ID appears more than once.
    #[error("duplicate gpadl id {0}")]
    DuplicateGpadlId(u32),

    /// A GPADL refers to a channel ID that is not known.
    #[error("gpadl for unknown channel id {0}")]
    GpadlForUnknownChannelId(u32),

    /// A pending message could not be rebuilt because it is too large.
    #[error("invalid pending message")]
    InvalidPendingMessage(#[source] vmbus_core::MessageTooLarge),

    /// Offering a restored channel failed.
    #[error("failed to offer restored channel")]
    OfferFailed(#[source] anyhow::Error),
}
380
/// Provides the offer details from the server, together with a sender for
/// requesting client actions on the channel and a receiver that is notified
/// when the server revokes the channel.
#[derive(Debug, Inspect)]
pub struct OfferInfo {
    /// The offer as received from the server.
    pub offer: protocol::OfferChannel,
    /// Interrupt for signaling the host, built from the channel's shared
    /// connection ID.
    #[inspect(skip)]
    pub guest_to_host_interrupt: Interrupt,
    /// Sends [`ChannelRequest`]s for this channel to the client task.
    #[inspect(skip)]
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receives `()` when the server rescinds the channel.
    #[inspect(skip)]
    pub revoke_recv: mesh::OneshotReceiver<()>,
}
393
/// Client-level requests sent to the client task through
/// [`VmbusClientAccess`].
#[derive(Debug)]
enum ClientRequest {
    /// Negotiate a protocol version and retrieve the initial offers.
    Connect(Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>),
    /// Send `Unload` to the server.
    Unload(Rpc<(), ()>),
    /// Send `ModifyConnection` to the server.
    Modify(Rpc<ModifyConnectionRequest, ConnectionState>),
    /// Send an hvsock connect request and wait for a matching offer.
    HvsockConnect(Rpc<HvsockConnectRequest, Option<OfferInfo>>),
}
401
402impl std::fmt::Display for ClientRequest {
403    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404        let s = match self {
405            ClientRequest::Connect(..) => "Connect",
406            ClientRequest::Unload { .. } => "Unload",
407            ClientRequest::Modify(..) => "Modify",
408            ClientRequest::HvsockConnect(..) => "HvsockConnect",
409        };
410        fmt.pad(s)
411    }
412}
413
/// Control requests sent to the client task by [`VmbusClient`].
enum TaskRequest {
    /// Answer a deferred inspection request.
    Inspect(inspect::Deferred),
    /// Produce a [`SavedState`] snapshot.
    Save(Rpc<(), SavedState>),
    /// Restore from a [`SavedState`].
    Restore(Rpc<SavedState, Result<Option<ConnectResult>, RestoreError>>),
    /// Finish a previous restore.
    PostRestore(Rpc<(), ()>),
    /// Start the task; sent without waiting for a reply.
    Start,
    /// Stop the task; the RPC completes when the task responds.
    Stop(Rpc<(), ()>),
}
422
/// The overall state machine used to drive which actions the client can legally
/// take. This primarily pertains to overall client activity but has a
/// side-effect of limiting whether or not channels can perform actions.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ClientState {
    /// The client has yet to connect to the server.
    Disconnected,
    /// The client has initiated contact with the server.
    Connecting {
        /// The version proposed in the outstanding `InitiateContact`.
        version: Version,
        /// The caller's connect request; failed on error, or carried into
        /// `RequestingOffers` once the version is accepted.
        #[inspect(skip)]
        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
    },
    /// The client has negotiated the protocol version with the server.
    Connected {
        version: VersionInfo,
        /// Forwards offers that arrive after the initial set to the caller.
        #[inspect(skip)]
        offer_send: mesh::Sender<OfferInfo>,
    },
    /// The client has requested offers from the server.
    RequestingOffers {
        version: VersionInfo,
        /// Completed with the collected offers once the server reports that
        /// all offers were delivered.
        #[inspect(skip)]
        rpc: Rpc<(), Result<ConnectResult, ConnectError>>,
        /// Offers received so far.
        #[inspect(skip)]
        offers: Vec<OfferInfo>,
    },
    /// The client has initiated an unload from the server.
    Disconnecting {
        version: VersionInfo,
        /// The unload request (presumably completed once the server
        /// acknowledges the unload; handler not shown).
        #[inspect(skip)]
        rpc: Rpc<(), ()>,
    },
}
458
459impl ClientState {
460    fn get_version(&self) -> Option<VersionInfo> {
461        match self {
462            ClientState::Connected { version, .. } => Some(*version),
463            ClientState::RequestingOffers { version, .. } => Some(*version),
464            ClientState::Disconnecting { version, .. } => Some(*version),
465            ClientState::Disconnected | ClientState::Connecting { .. } => None,
466        }
467    }
468}
469
470impl std::fmt::Display for ClientState {
471    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472        let s = match self {
473            ClientState::Disconnected => "Disconnected",
474            ClientState::Connecting { .. } => "Connecting",
475            ClientState::Connected { .. } => "Connected",
476            ClientState::RequestingOffers { .. } => "RequestingOffers",
477            ClientState::Disconnecting { .. } => "Disconnecting",
478        };
479        fmt.pad(s)
480    }
481}
482
/// Parameters for [`VmbusClient::connect`], carried in the connect RPC.
#[derive(Copy, Clone, Debug, Default)]
struct ConnectRequest {
    /// Sent as `target_message_vp` in `InitiateContact`.
    target_message_vp: u32,
    /// Monitor pages to register; `None` is sent as
    /// `MonitorPageGpas::default()`.
    monitor_page: Option<MonitorPageGpas>,
    /// Sent in `InitiateContact2` (Copper and newer only).
    client_id: Guid,
}
489
/// Parameters for [`VmbusClientAccess::modify`].
#[derive(Copy, Clone, Debug, Default)]
pub struct ModifyConnectionRequest {
    /// New monitor pages; `None` is sent as `MonitorPageGpas::default()`.
    pub monitor_page: Option<MonitorPageGpas>,
}
494
495impl From<ModifyConnectionRequest> for protocol::ModifyConnection {
496    fn from(value: ModifyConnectionRequest) -> Self {
497        let monitor_page = value.monitor_page.unwrap_or_default();
498
499        Self {
500            parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
501            child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
502        }
503    }
504}
505
/// The per-channel state which dictates whether or not a channel can
/// request an Open/Close. As GPADLs can happen outside this loop there is no
/// state tied to GPADL actions.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum ChannelState {
    /// The channel has been offered to the client.
    Offered,
    /// The channel has requested the server to be opened.
    Opening {
        /// Event flag reserved for interrupt redirection, if any. It is freed
        /// if the open fails or the channel is rescinded.
        redirected_event_flag: Option<u16>,
        /// Event paired with the redirected flag, if any.
        #[inspect(skip)]
        redirected_event: Option<Event>,
        /// The pending open request. It is resolved by `handle_open_result`,
        /// or failed if the channel is rescinded first.
        #[inspect(skip)]
        rpc: FailableRpc<(), OpenOutput>,
    },
    /// The channel has been restored but not claimed.
    Restored,
    /// The channel has been successfully opened.
    Opened {
        /// Event flag in use for interrupt redirection, if any. It is freed
        /// when the channel is rescinded.
        redirected_event_flag: Option<u16>,
        #[inspect(skip)]
        redirected_event: Option<Event>,
    },
    /// The channel has been revoked by the server.
    Revoked,
}
533
534impl std::fmt::Display for ChannelState {
535    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536        let s = match self {
537            ChannelState::Opening { .. } => "Opening",
538            ChannelState::Offered => "Offered",
539            ChannelState::Opened { .. } => "Opened",
540            ChannelState::Restored => "Restored",
541            ChannelState::Revoked => "Revoked",
542        };
543        fmt.pad(s)
544    }
545}
546
/// Client-side bookkeeping for one channel offered by the server.
#[derive(Debug, Inspect)]
struct Channel {
    /// The offer as received from the server.
    offer: protocol::OfferChannel,
    // Notifies the owner that the channel was revoked: `handle_rescind` takes
    // it and sends `()`. Dropping it presumably also closes the receiver.
    #[inspect(skip)]
    revoke_send: Option<mesh::OneshotSender<()>>,
    /// Open/close state of the channel.
    state: ChannelState,
    /// Outstanding channel modify request, if any.
    #[inspect(with = "|x| x.is_some()")]
    modify_response_send: Option<Rpc<(), i32>>,
    /// GPADLs for this channel, including ones still being created or torn
    /// down.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    gpadls: HashMap<GpadlId, GpadlState>,
    /// NOTE(review): presumably set once the owner releases the channel and
    /// consulted by `try_release` (not shown); confirm.
    is_client_released: bool,
    /// Connection ID shared with the channel's guest-to-host interrupt;
    /// starts at 0.
    connection_id: Arc<AtomicU32>,
}
561
562impl Channel {
563    fn pending_request(&self) -> Option<&'static str> {
564        if self.modify_response_send.is_some() {
565            return Some("modify");
566        }
567        self.gpadls.iter().find_map(|(_, gpadl)| match gpadl {
568            GpadlState::Offered(_) => Some("creating gpadl"),
569            GpadlState::Created => None,
570            GpadlState::TearingDown { .. } => Some("tearing down gpadl"),
571        })
572    }
573}
574
/// State owned by the spawned vmbus client task.
#[derive(Inspect)]
struct ClientTask {
    /// Outgoing messages, GPADL teardown tracking, per-channel request
    /// streams, and synic event state.
    #[inspect(flatten)]
    inner: ClientTaskInner,
    /// Channels known to the client, keyed by channel ID.
    channels: ChannelList,
    /// Connection-level state machine.
    state: ClientState,
    /// Pending hvsock connect requests waiting for matching offers.
    hvsock_tracker: hvsock::HvsockRequestTracker,
    /// Whether the task is running; starts `false`.
    running: bool,
    /// The single outstanding `ModifyConnection` request, if any.
    #[inspect(with = "|x| x.is_some()")]
    modify_request: Option<Rpc<ModifyConnectionRequest, ConnectionState>>,
    /// Source of incoming synic messages.
    #[inspect(skip)]
    msg_source: Box<dyn VmbusMessageSource>,
    /// Control requests from [`VmbusClient`].
    #[inspect(skip)]
    task_recv: mesh::Receiver<TaskRequest>,
    /// Client-level requests from [`VmbusClientAccess`].
    #[inspect(skip)]
    client_request_recv: mesh::Receiver<ClientRequest>,
}
592
593impl ClientTask {
594    fn handle_initiate_contact(
595        &mut self,
596        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
597        version: Version,
598    ) {
599        let ClientState::Disconnected = self.state else {
600            tracing::warn!(client_state = %self.state, "invalid client state for InitiateContact");
601            rpc.complete(Err(ConnectError::InvalidState));
602            return;
603        };
604        let feature_flags = if version >= Version::Copper {
605            SUPPORTED_FEATURE_FLAGS
606        } else {
607            FeatureFlags::new()
608        };
609
610        let request = rpc.input();
611
612        tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting");
613        let target_info = protocol::TargetInfo::new()
614            .with_sint(SINT)
615            .with_vtl(VTL)
616            .with_feature_flags(feature_flags.into());
617        let monitor_page = request.monitor_page.unwrap_or_default();
618        let msg = protocol::InitiateContact2 {
619            initiate_contact: protocol::InitiateContact {
620                version_requested: version as u32,
621                target_message_vp: request.target_message_vp,
622                interrupt_page_or_target_info: target_info.into(),
623                parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
624                child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
625            },
626            client_id: request.client_id,
627        };
628
629        self.state = ClientState::Connecting { version, rpc };
630        if version < Version::Copper {
631            self.inner.messages.send(&msg.initiate_contact)
632        } else {
633            self.inner.messages.send(&msg);
634        }
635    }
636
637    fn handle_unload(&mut self, rpc: Rpc<(), ()>) {
638        tracing::debug!(%self.state, "VmBus client disconnecting");
639        self.state = ClientState::Disconnecting {
640            version: self.state.get_version().expect("invalid state for unload"),
641            rpc,
642        };
643
644        self.inner.messages.send(&protocol::Unload {});
645    }
646
647    fn handle_modify(&mut self, request: Rpc<ModifyConnectionRequest, ConnectionState>) {
648        if !matches!(self.state, ClientState::Connected { version, .. }
649            if version.feature_flags.modify_connection())
650        {
651            tracing::warn!("ModifyConnection not supported");
652            request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
653            return;
654        }
655
656        if self.modify_request.is_some() {
657            tracing::warn!("Duplicate ModifyConnection request");
658            request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
659            return;
660        }
661
662        let message = protocol::ModifyConnection::from(*request.input());
663        self.modify_request = Some(request);
664        self.inner.messages.send(&message);
665    }
666
667    fn handle_tl_connect(&mut self, rpc: Rpc<HvsockConnectRequest, Option<OfferInfo>>) {
668        // The client only supports protocol versions which use the newer message format.
669        // The host will not send a TlConnectRequestResult message on success, so a response to this
670        // message is not guaranteed.
671        let message = protocol::TlConnectRequest2::from(*rpc.input());
672        self.hvsock_tracker.add_request(rpc);
673        self.inner.messages.send(&message);
674    }
675
676    fn handle_client_request(&mut self, request: ClientRequest) {
677        match request {
678            ClientRequest::Connect(rpc) => {
679                self.handle_initiate_contact(rpc, *SUPPORTED_VERSIONS.last().unwrap());
680            }
681            ClientRequest::Unload(rpc) => {
682                self.handle_unload(rpc);
683            }
684            ClientRequest::Modify(request) => self.handle_modify(request),
685            ClientRequest::HvsockConnect(request) => self.handle_tl_connect(request),
686        }
687    }
688
689    fn handle_version_response(&mut self, msg: protocol::VersionResponse2) {
690        let old_state = std::mem::replace(&mut self.state, ClientState::Disconnected);
691        let ClientState::Connecting { version, rpc } = old_state else {
692            self.state = old_state;
693            tracing::warn!(
694                client_state = %self.state,
695                "invalid client state to handle VersionResponse"
696            );
697            return;
698        };
699        if msg.version_response.version_supported > 0 {
700            if msg.version_response.connection_state != ConnectionState::SUCCESSFUL {
701                rpc.complete(Err(ConnectError::FailedToConnect(
702                    msg.version_response.connection_state,
703                )));
704                return;
705            }
706
707            let feature_flags = if version >= Version::Copper {
708                FeatureFlags::from(msg.supported_features)
709            } else {
710                FeatureFlags::new()
711            };
712
713            let version = VersionInfo {
714                version,
715                feature_flags,
716            };
717
718            self.inner.messages.send(&protocol::RequestOffers {});
719            self.state = ClientState::RequestingOffers {
720                version,
721                rpc: rpc.split().1,
722                offers: Vec::new(),
723            };
724            tracing::info!(?version, "VmBus client connected, requesting offers");
725        } else {
726            let index = SUPPORTED_VERSIONS
727                .iter()
728                .position(|v| *v == version)
729                .unwrap();
730
731            if index == 0 {
732                rpc.complete(Err(ConnectError::NoSupportedVersions));
733                return;
734            }
735            let next_version = SUPPORTED_VERSIONS[index - 1];
736            tracing::debug!(
737                version = version as u32,
738                next_version = next_version as u32,
739                "Unsupported version, retrying"
740            );
741            self.handle_initiate_contact(rpc, next_version);
742        }
743    }
744
745    fn create_channel(&mut self, offer: protocol::OfferChannel) -> Result<OfferInfo> {
746        self.create_channel_core(offer, ChannelState::Offered)
747    }
748
749    fn create_channel_core(
750        &mut self,
751        offer: protocol::OfferChannel,
752        state: ChannelState,
753    ) -> Result<OfferInfo> {
754        if self.channels.0.contains_key(&offer.channel_id) {
755            anyhow::bail!("channel {:?} exists", offer.channel_id);
756        }
757        let (request_send, request_recv) = mesh::channel();
758        let (revoke_send, revoke_recv) = mesh::oneshot();
759
760        let connection_id = Arc::new(AtomicU32::new(0));
761        self.channels.0.insert(
762            offer.channel_id,
763            Channel {
764                revoke_send: Some(revoke_send),
765                offer,
766                state,
767                modify_response_send: None,
768                gpadls: HashMap::new(),
769                is_client_released: false,
770                connection_id: connection_id.clone(),
771            },
772        );
773
774        self.inner
775            .channel_requests
776            .push(TaggedStream::new(offer.channel_id, request_recv));
777
778        Ok(OfferInfo {
779            offer,
780            guest_to_host_interrupt: self.inner.synic.guest_to_host_interrupt(connection_id),
781            revoke_recv,
782            request_send,
783        })
784    }
785
786    fn handle_offer(&mut self, offer: protocol::OfferChannel) {
787        let offer_info = self
788            .create_channel(offer)
789            .expect("channel should not exist");
790
791        tracing::info!(
792                state = %self.state,
793                channel_id = offer.channel_id.0,
794                interface_id = %offer.interface_id,
795                instance_id = %offer.instance_id,
796                subchannel_index = offer.subchannel_index,
797                "received offer");
798
799        if let Some(offer) = self.hvsock_tracker.check_offer(&offer_info.offer) {
800            offer.complete(Some(offer_info));
801        } else {
802            match &mut self.state {
803                ClientState::Connected { offer_send, .. } => {
804                    offer_send.send(offer_info);
805                }
806                ClientState::RequestingOffers { offers, .. } => {
807                    offers.push(offer_info);
808                }
809                state => unreachable!("invalid client state for OfferChannel: {state}"),
810            }
811        }
812    }
813
814    fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) -> TriedRelease {
815        tracing::info!(state = %self.state, channel_id = rescind.channel_id.0, "received rescind");
816
817        let mut channel = self.channels.get_mut(rescind.channel_id);
818        let event_flag = match std::mem::replace(&mut channel.state, ChannelState::Revoked) {
819            ChannelState::Offered => None,
820            ChannelState::Opening {
821                redirected_event_flag,
822                redirected_event: _,
823                rpc,
824            } => {
825                rpc.fail(anyhow::anyhow!("channel revoked"));
826                redirected_event_flag
827            }
828            ChannelState::Restored => None,
829            ChannelState::Opened {
830                redirected_event_flag,
831                redirected_event: _,
832            } => redirected_event_flag,
833            ChannelState::Revoked => {
834                panic!("channel id {:?} already revoked", rescind.channel_id);
835            }
836        };
837        if let Some(event_flag) = event_flag {
838            self.inner.synic.free_event_flag(event_flag);
839        }
840
841        // Drop the channel and send the revoked message to the client.
842        channel.revoke_send.take().unwrap().send(());
843
844        channel.try_release(&mut self.inner.messages)
845    }
846
847    fn handle_offers_delivered(&mut self) {
848        match std::mem::replace(&mut self.state, ClientState::Disconnected) {
849            ClientState::RequestingOffers {
850                version,
851                rpc,
852                offers,
853            } => {
854                tracing::info!(version = ?version, "VmBus client connected, offers delivered");
855                let (offer_send, offer_recv) = mesh::channel();
856                self.state = ClientState::Connected {
857                    version,
858                    offer_send,
859                };
860                rpc.complete(Ok(ConnectResult {
861                    version,
862                    offers,
863                    offer_recv,
864                }));
865            }
866            state => {
867                tracing::warn!(client_state = %state, "invalid client state for OffersDelivered");
868                self.state = state;
869            }
870        }
871    }
872
873    fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) -> TriedRelease {
874        let mut channel = self.channels.get_mut(request.channel_id);
875        let Some(gpadl_state) = channel.gpadls.get_mut(&request.gpadl_id) else {
876            panic!("GpadlCreated for unknown gpadl {:#x}", request.gpadl_id.0);
877        };
878
879        let rpc = match std::mem::replace(gpadl_state, GpadlState::Created) {
880            GpadlState::Offered(rpc) => rpc,
881            old_state => {
882                panic!(
883                    "invalid state {old_state:?} for gpadl {:#x}:{:#x}",
884                    request.channel_id.0, request.gpadl_id.0
885                );
886            }
887        };
888
889        let gpadl_created = request.status == protocol::STATUS_SUCCESS;
890        if gpadl_created {
891            rpc.complete(Ok(()));
892        } else {
893            channel.gpadls.remove(&request.gpadl_id).unwrap();
894            rpc.fail(anyhow::anyhow!(
895                "gpadl creation failed: {:#x}",
896                request.status
897            ));
898        };
899        channel.try_release(&mut self.inner.messages)
900    }
901
902    fn handle_open_result(&mut self, result: protocol::OpenResult) {
903        tracing::debug!(
904            channel_id = result.channel_id.0,
905            result = result.status,
906            "received open result"
907        );
908
909        let mut channel = self.channels.get_mut(result.channel_id);
910
911        let channel_opened = result.status == protocol::STATUS_SUCCESS as u32;
912        let old_state = std::mem::replace(&mut channel.state, ChannelState::Offered);
913        let ChannelState::Opening {
914            redirected_event_flag,
915            redirected_event,
916            rpc,
917        } = old_state
918        else {
919            tracing::warn!(
920                old_state = ?channel.state,
921                channel_opened,
922                "invalid state for open result"
923            );
924            channel.state = old_state;
925            return;
926        };
927
928        if !channel_opened {
929            if let Some(event_flag) = redirected_event_flag {
930                self.inner.synic.free_event_flag(event_flag);
931            }
932            rpc.fail(anyhow::anyhow!("open failed: {:#x}", result.status));
933            return;
934        }
935
936        channel.state = ChannelState::Opened {
937            redirected_event_flag,
938            redirected_event,
939        };
940
941        rpc.complete(Ok(OpenOutput {
942            redirected_event_flag,
943        }));
944    }
945
946    fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) -> TriedRelease {
947        let Some(channel_id) = self.inner.teardown_gpadls.remove(&request.gpadl_id) else {
948            panic!("gpadl {:#x} not in teardown list", request.gpadl_id.0);
949        };
950
951        tracing::debug!(
952            gpadl_id = request.gpadl_id.0,
953            channel_id = channel_id.0,
954            "Received GpadlTorndown"
955        );
956
957        let mut channel = self.channels.get_mut(channel_id);
958        let gpadl_state = channel
959            .gpadls
960            .remove(&request.gpadl_id)
961            .expect("gpadl validated above");
962
963        let GpadlState::TearingDown { rpcs } = gpadl_state else {
964            panic!("gpadl should be tearing down if in teardown list, state = {gpadl_state:?}");
965        };
966
967        for rpc in rpcs {
968            rpc.complete(());
969        }
970        channel.try_release(&mut self.inner.messages)
971    }
972
973    fn handle_unload_complete(&mut self) {
974        match std::mem::replace(&mut self.state, ClientState::Disconnected) {
975            ClientState::Disconnecting { version: _, rpc } => {
976                tracing::info!("VmBus client disconnected");
977                rpc.complete(());
978            }
979            state => {
980                tracing::warn!(client_state = %state, "invalid client state for UnloadComplete");
981            }
982        }
983    }
984
985    fn handle_modify_complete(&mut self, response: protocol::ModifyConnectionResponse) {
986        if let Some(request) = self.modify_request.take() {
987            request.complete(response.connection_state)
988        } else {
989            tracing::warn!("Unexpected modify complete request");
990        }
991    }
992
993    fn handle_modify_channel_response(
994        &mut self,
995        response: protocol::ModifyChannelResponse,
996    ) -> TriedRelease {
997        let mut channel = self.channels.get_mut(response.channel_id);
998        let Some(sender) = channel.modify_response_send.take() else {
999            panic!(
1000                "unexpected modify channel response for channel {:#x}",
1001                response.channel_id.0
1002            );
1003        };
1004
1005        sender.complete(response.status);
1006        channel.try_release(&mut self.inner.messages)
1007    }
1008
1009    fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) {
1010        if let Some(rpc) = self.hvsock_tracker.check_result(&response) {
1011            rpc.complete(None);
1012        }
1013    }
1014
    /// Parses one message received from the host over the synic and dispatches
    /// it to the matching handler.
    ///
    /// Returns `false` if the message was a `PauseResponse` (the host's reply
    /// to the `Pause` message sent while stopping), and `true` for every other
    /// message.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be parsed for the current protocol
    /// version. Also panics on `CloseReservedChannelResponse`, which this client
    /// does not support, and on any message that only a VmBus server should
    /// receive.
    fn handle_synic_message(&mut self, data: &[u8]) -> bool {
        let msg = Message::parse(data, self.state.get_version()).unwrap();
        tracing::trace!(?msg, "received client message from synic");

        match msg {
            Message::VersionResponse3(version_response, ..) => {
                // The client never sends the server-specified monitor pages feature flag, but
                // since version response messages are distinguished only by size, the response can
                // still look like `VersionResponse3` if the size was not set exactly by the server.
                // Since the feature flag can't be set, the extra data can be ignored.
                self.handle_version_response(version_response.version_response2);
            }
            Message::VersionResponse2(version_response, ..) => {
                self.handle_version_response(version_response);
            }
            Message::VersionResponse(version_response, ..) => {
                self.handle_version_response(version_response.into());
            }
            Message::OfferChannel(offer, ..) => {
                self.handle_offer(offer);
            }
            Message::AllOffersDelivered(..) => {
                self.handle_offers_delivered();
            }
            Message::UnloadComplete(..) => {
                self.handle_unload_complete();
            }
            Message::ModifyConnectionResponse(response, ..) => {
                self.handle_modify_complete(response);
            }
            Message::GpadlCreated(gpadl, ..) => {
                self.handle_gpadl_created(gpadl);
            }
            Message::OpenResult(result, ..) => {
                self.handle_open_result(result);
            }
            Message::GpadlTorndown(gpadl, ..) => {
                self.handle_gpadl_torndown(gpadl);
            }
            Message::RescindChannelOffer(rescind, ..) => {
                self.handle_rescind(rescind);
            }
            Message::ModifyChannelResponse(response, ..) => {
                self.handle_modify_channel_response(response);
            }
            Message::TlConnectResult(response, ..) => self.handle_tl_connect_result(response),
            // Unsupported messages.
            Message::CloseReservedChannelResponse(..) => {
                todo!("Unsupported message {msg:?}")
            }
            // Signals the caller (see `process_next_message`) that the host
            // has acknowledged the pause.
            Message::PauseResponse(..) => {
                return false;
            }
            // Messages that should only be received by a vmbus server.
            Message::RequestOffers(..)
            | Message::OpenChannel2(..)
            | Message::OpenChannel(..)
            | Message::CloseChannel(..)
            | Message::GpadlHeader(..)
            | Message::GpadlBody(..)
            | Message::GpadlTeardown(..)
            | Message::RelIdReleased(..)
            | Message::InitiateContact(..)
            | Message::InitiateContact2(..)
            | Message::Unload(..)
            | Message::OpenReservedChannel(..)
            | Message::CloseReservedChannel(..)
            | Message::TlConnectRequest2(..)
            | Message::TlConnectRequest(..)
            | Message::ModifyChannel(..)
            | Message::ModifyConnection(..)
            | Message::Pause(..)
            | Message::Resume(..) => {
                unreachable!("Client received server message {msg:?}");
            }
        }
        true
    }
1094
1095    fn handle_open_channel(
1096        &mut self,
1097        channel_id: ChannelId,
1098        rpc: FailableRpc<OpenRequest, OpenOutput>,
1099    ) {
1100        let mut channel = self.channels.get_mut(channel_id);
1101        match &channel.state {
1102            ChannelState::Offered => {}
1103            ChannelState::Revoked => {
1104                rpc.fail(anyhow::anyhow!("channel revoked"));
1105                return;
1106            }
1107            state => {
1108                rpc.fail(anyhow::anyhow!("invalid channel state: {}", state));
1109                return;
1110            }
1111        }
1112
1113        tracing::info!(channel_id = channel_id.0, "opening channel on host");
1114        let (request, rpc) = rpc.split();
1115        let open_data = &request.open_data;
1116
1117        let supports_interrupt_redirection =
1118            if let ClientState::Connected { version, .. } = self.state {
1119                version.feature_flags.guest_specified_signal_parameters()
1120                    || version.feature_flags.channel_interrupt_redirection()
1121            } else {
1122                false
1123            };
1124
1125        if !supports_interrupt_redirection && open_data.event_flag != channel_id.0 as u16 {
1126            rpc.fail(anyhow::anyhow!(
1127                "host does not support specifying the event flag"
1128            ));
1129            return;
1130        }
1131
1132        let open_channel = protocol::OpenChannel {
1133            channel_id,
1134            open_id: 0,
1135            ring_buffer_gpadl_id: open_data.ring_gpadl_id,
1136            target_vp: open_data.target_vp,
1137            downstream_ring_buffer_page_offset: open_data.ring_offset,
1138            user_data: open_data.user_data,
1139        };
1140
1141        let connection_id = if request.use_vtl2_connection_id {
1142            if !supports_interrupt_redirection {
1143                rpc.fail(anyhow::anyhow!(
1144                    "host does not support specfiying the connection ID"
1145                ));
1146                return;
1147            }
1148            protocol::ConnectionId::new(channel_id.0, 2.try_into().unwrap(), 7).0
1149        } else {
1150            open_data.connection_id
1151        };
1152
1153        // No failure paths after the one for allocating the event flag, since
1154        // otherwise we would need to free the event flag.
1155        let mut flags = OpenChannelFlags::new();
1156        let event_flag = if let Some(event) = &request.incoming_event {
1157            if !supports_interrupt_redirection {
1158                rpc.fail(anyhow::anyhow!(
1159                    "host does not support redirecting interrupts"
1160                ));
1161                return;
1162            }
1163
1164            flags.set_redirect_interrupt(true);
1165            match self.inner.synic.allocate_event_flag(event) {
1166                Ok(flag) => flag,
1167                Err(err) => {
1168                    rpc.fail(err.context("failed to allocate event flag"));
1169                    return;
1170                }
1171            }
1172        } else {
1173            open_data.event_flag
1174        };
1175
1176        if supports_interrupt_redirection {
1177            self.inner.messages.send(&protocol::OpenChannel2 {
1178                open_channel,
1179                connection_id,
1180                event_flag,
1181                flags,
1182            });
1183        } else {
1184            self.inner.messages.send(&open_channel);
1185        }
1186
1187        channel
1188            .connection_id
1189            .store(connection_id, Ordering::Release);
1190        channel.state = ChannelState::Opening {
1191            redirected_event_flag: (request.incoming_event.is_some()).then_some(event_flag),
1192            redirected_event: request.incoming_event,
1193            rpc,
1194        }
1195    }
1196
1197    fn handle_restore_channel(
1198        &mut self,
1199        channel_id: ChannelId,
1200        request: RestoreRequest,
1201    ) -> Result<OpenOutput> {
1202        let mut channel = self.channels.get_mut(channel_id);
1203        if !matches!(channel.state, ChannelState::Restored) {
1204            anyhow::bail!("invalid channel state: {}", channel.state);
1205        }
1206
1207        if request.incoming_event.is_some() != request.redirected_event_flag.is_some() {
1208            anyhow::bail!("incoming event and redirected event flag must both be set or unset");
1209        }
1210
1211        if let Some((flag, event)) = request
1212            .redirected_event_flag
1213            .zip(request.incoming_event.as_ref())
1214        {
1215            self.inner.synic.restore_event_flag(flag, event)?;
1216        }
1217
1218        channel
1219            .connection_id
1220            .store(request.connection_id, Ordering::Release);
1221        channel.state = ChannelState::Opened {
1222            redirected_event_flag: request.redirected_event_flag,
1223            redirected_event: request.incoming_event,
1224        };
1225        Ok(OpenOutput {
1226            redirected_event_flag: request.redirected_event_flag,
1227        })
1228    }
1229
1230    fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: FailableRpc<GpadlRequest, ()>) {
1231        let (request, rpc) = rpc.split();
1232        let mut channel = self.channels.get_mut(channel_id);
1233        if channel
1234            .gpadls
1235            .insert(request.id, GpadlState::Offered(rpc))
1236            .is_some()
1237        {
1238            panic!(
1239                "duplicate gpadl ID {:?} for channel {:?}.",
1240                request.id, channel_id
1241            );
1242        }
1243
1244        tracing::trace!(
1245            channel_id = channel_id.0,
1246            gpadl_id = request.id.0,
1247            count = request.count,
1248            len = request.buf.len(),
1249            "received gpadl request"
1250        );
1251
1252        // Split off the values that fit in the header.
1253        let (first, remaining) = if request.buf.len() > protocol::GpadlHeader::MAX_DATA_VALUES {
1254            request.buf.split_at(protocol::GpadlHeader::MAX_DATA_VALUES)
1255        } else {
1256            (request.buf.as_slice(), [].as_slice())
1257        };
1258
1259        let message = protocol::GpadlHeader {
1260            channel_id,
1261            gpadl_id: request.id,
1262            len: (request.buf.len() * size_of::<u64>())
1263                .try_into()
1264                .expect("Too many GPA values"),
1265            count: request.count,
1266        };
1267
1268        self.inner
1269            .messages
1270            .send_with_data(&message, first.as_bytes());
1271
1272        // Send GpadlBody messages for the remaining values.
1273        let message = protocol::GpadlBody {
1274            rsvd: 0,
1275            gpadl_id: request.id,
1276        };
1277        for chunk in remaining.chunks(protocol::GpadlBody::MAX_DATA_VALUES) {
1278            self.inner
1279                .messages
1280                .send_with_data(&message, chunk.as_bytes());
1281        }
1282    }
1283
1284    fn handle_gpadl_teardown(&mut self, channel_id: ChannelId, rpc: Rpc<GpadlId, ()>) {
1285        let (gpadl_id, rpc) = rpc.split();
1286        let mut channel = self.channels.get_mut(channel_id);
1287        let Some(gpadl_state) = channel.gpadls.get_mut(&gpadl_id) else {
1288            tracing::warn!(
1289                gpadl_id = gpadl_id.0,
1290                channel_id = channel_id.0,
1291                "Gpadl teardown for unknown gpadl or revoked channel"
1292            );
1293            return;
1294        };
1295
1296        match gpadl_state {
1297            GpadlState::Offered(_) => {
1298                tracing::warn!(
1299                    gpadl_id = gpadl_id.0,
1300                    channel_id = channel_id.0,
1301                    "gpadl teardown for offered gpadl"
1302                );
1303            }
1304            GpadlState::Created => {
1305                *gpadl_state = GpadlState::TearingDown { rpcs: vec![rpc] };
1306                // The caller must guarantee that GPADL teardown requests are only made
1307                // for unique GPADL IDs. This is currently enforced in vmbus_server by
1308                // blocking GPADL teardown messages for reserved channels.
1309                assert!(
1310                    self.inner
1311                        .teardown_gpadls
1312                        .insert(gpadl_id, channel_id)
1313                        .is_none(),
1314                    "Gpadl state validated above"
1315                );
1316
1317                self.inner.messages.send(&protocol::GpadlTeardown {
1318                    channel_id,
1319                    gpadl_id,
1320                });
1321            }
1322            GpadlState::TearingDown { rpcs } => {
1323                rpcs.push(rpc);
1324            }
1325        }
1326    }
1327
1328    fn handle_close_channel(&mut self, channel_id: ChannelId) {
1329        let mut channel = self.channels.get_mut(channel_id);
1330        self.inner.close_channel(channel_id, &mut channel);
1331    }
1332
1333    fn handle_modify_channel(&mut self, channel_id: ChannelId, rpc: Rpc<ModifyRequest, i32>) {
1334        // The client doesn't support versions below Iron, so we always expect the host to send a
1335        // ModifyChannelResponse. This means we don't need to worry about sending a ChannelResponse
1336        // if that weren't supported.
1337        assert!(self.check_version(Version::Iron));
1338        let mut channel = self.channels.get_mut(channel_id);
1339        if channel.modify_response_send.is_some() {
1340            panic!("duplicate channel modify request {channel_id:?}");
1341        }
1342
1343        let (request, response) = rpc.split();
1344        channel.modify_response_send = Some(response);
1345        let payload = match request {
1346            ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel {
1347                channel_id,
1348                target_vp,
1349            },
1350        };
1351
1352        self.inner.messages.send(&payload);
1353    }
1354
1355    fn handle_channel_request(&mut self, channel_id: ChannelId, request: ChannelRequest) {
1356        match request {
1357            ChannelRequest::Open(rpc) => self.handle_open_channel(channel_id, rpc),
1358            ChannelRequest::Restore(rpc) => {
1359                rpc.handle_failable_sync(|request| self.handle_restore_channel(channel_id, request))
1360            }
1361            ChannelRequest::Gpadl(req) => self.handle_gpadl(channel_id, req),
1362            ChannelRequest::TeardownGpadl(req) => self.handle_gpadl_teardown(channel_id, req),
1363            ChannelRequest::Close(req) => {
1364                req.handle_sync(|()| self.handle_close_channel(channel_id))
1365            }
1366            ChannelRequest::Modify(req) => self.handle_modify_channel(channel_id, req),
1367        }
1368    }
1369
1370    async fn handle_task(&mut self, task: TaskRequest) {
1371        match task {
1372            TaskRequest::Inspect(deferred) => {
1373                deferred.inspect(&*self);
1374            }
1375            TaskRequest::Save(rpc) => rpc.handle_sync(|()| self.handle_save()),
1376            TaskRequest::Restore(rpc) => {
1377                rpc.handle_sync(|saved_state| self.handle_restore(saved_state))
1378            }
1379            TaskRequest::PostRestore(rpc) => rpc.handle_sync(|()| self.handle_post_restore()),
1380            TaskRequest::Start => self.handle_start(),
1381            TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
1382        }
1383    }
1384
1385    /// Makes sure a channel is closed if the channel request stream was dropped.
1386    fn handle_device_removal(&mut self, channel_id: ChannelId) -> TriedRelease {
1387        let mut channel = self.channels.get_mut(channel_id);
1388        channel.is_client_released = true;
1389        // Close the channel if it is still open.
1390        if let ChannelState::Opened { .. } = channel.state {
1391            tracing::warn!(
1392                channel_id = channel_id.0,
1393                "Channel dropped without closing first"
1394            );
1395            self.inner.close_channel(channel_id, &mut channel);
1396        }
1397        channel.try_release(&mut self.inner.messages)
1398    }
1399
1400    /// Determines if the client is connected with at least the specified version.
1401    fn check_version(&self, version: Version) -> bool {
1402        matches!(self.state, ClientState::Connected { version: v, .. } if v.version >= version)
1403    }
1404
    /// Starts processing messages and requests, including after a stop.
    ///
    /// Resumes the incoming message stream, returns outgoing messages to the
    /// `Running` state, and sets `running`. Setting `running` lets `run` poll
    /// the synic message source, client requests, and channel requests.
    ///
    /// # Panics
    ///
    /// Panics if the task is already running, or (via
    /// `OutgoingMessages::resume`) if outgoing messages are not currently
    /// paused.
    fn handle_start(&mut self) {
        assert!(!self.running);
        self.msg_source.resume_message_stream();
        self.inner.messages.resume();
        self.running = true;
    }
1411
    /// Stops the client task and drains the host message stream.
    ///
    /// 1. Processes messages until no revoked channel is still waiting on a
    ///    host response.
    /// 2. Pauses the stream. If the host supports in-band pause/resume, a
    ///    `Pause` message is sent. Otherwise the SINT is masked and outgoing
    ///    messages are force-paused.
    /// 3. Processes the remaining messages until the stream reports EOF or a
    ///    pause response arrives.
    ///
    /// If a revoked channel still has a pending request after that, the stream
    /// is resumed and the sequence repeats. On return, `running` is `false`.
    ///
    /// # Panics
    ///
    /// Panics if the task is not running, or if the stream reports EOF or a
    /// pause response while still waiting on responses for a revoked channel.
    async fn handle_stop(&mut self) {
        assert!(self.running);

        loop {
            // Process messages until there are no more channels waiting for
            // responses. This is necessary to ensure that the saved state does
            // not have to support encoding revoked channels for which we are
            // waiting for GPADL or modify responses.
            while let Some((id, request)) = self.channels.revoked_channel_with_pending_request() {
                tracelimit::info_ratelimited!(
                    channel_id = id.0,
                    request,
                    "waiting for responses for channel"
                );
                assert!(self.process_next_message().await);
            }

            if self.can_pause_resume() {
                self.inner.messages.pause();
            } else {
                // Mask the sint to pause the message stream. The host will
                // retry any queued messages after the sint is unmasked.
                self.msg_source.pause_message_stream();
                self.inner.messages.force_pause();
            }

            // Continue processing messages until we hit EOF or get a pause
            // response.
            while self.process_next_message().await {}

            // Ensure there are still no pending requests. If there are, resume
            // and go around again.
            if self
                .channels
                .revoked_channel_with_pending_request()
                .is_none()
            {
                break;
            }
            if !self.can_pause_resume() {
                self.msg_source.resume_message_stream();
            }
            self.inner.messages.resume();
        }

        tracing::debug!("messages drained");
        // Because the run loop awaits all async operations, there is no need for rundown.
        self.running = false;
    }
1461
1462    async fn process_next_message(&mut self) -> bool {
1463        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
1464        let recv = self.msg_source.recv(&mut buf);
1465        // Concurrently flush until there is no more work to do, since pending
1466        // messages may be blocking responses from the host.
1467        let flush = async {
1468            self.inner.messages.flush_messages().await;
1469            std::future::pending().await
1470        };
1471        let size = (recv, flush)
1472            .race()
1473            .await
1474            .expect("Fatal error reading messages from synic");
1475        if size == 0 {
1476            return false;
1477        }
1478        self.handle_synic_message(&buf[..size])
1479    }
1480
1481    /// Returns whether the server supports in-band messages to pause/resume the
1482    /// message stream.
1483    ///
1484    /// For hosts where this is not supported, we mask the sint to pause new
1485    /// messages being queued to the sint, then drain the messages. This does
1486    /// not work with some host implementations, which cannot support draining
1487    /// the message queue while the sint is masked (due to the use of
1488    /// HvPostMessageDirect).
1489    fn can_pause_resume(&self) -> bool {
1490        if let ClientState::Connected { version, .. } = self.state {
1491            version.feature_flags.pause_resume()
1492        } else {
1493            false
1494        }
1495    }
1496
    /// Main event loop of the client task.
    ///
    /// Each iteration uses `futures::select!` to wait on whichever of these
    /// sources is enabled:
    /// - flushing queued outgoing messages (only while running and the queue
    ///   is non-empty);
    /// - task control requests (inspect, save/restore, start/stop), which are
    ///   always polled;
    /// - client requests and per-channel requests (only while running and the
    ///   outgoing queue is empty);
    /// - incoming host messages from the synic (only while running).
    ///
    /// A `None` item from a channel's request stream means the stream was
    /// dropped; that case is handled by `handle_device_removal`. The loop exits
    /// when the task request stream or the client request stream ends, or when
    /// every branch of the `select!` has terminated.
    ///
    /// # Panics
    ///
    /// Panics if reading from the synic fails or returns EOF.
    async fn run(&mut self) {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        loop {
            let mut message_recv =
                OptionFuture::from(self.running.then(|| self.msg_source.recv(&mut buf).fuse()));

            // If there are pending outgoing messages, the host is backed up.
            // Try to flush the queue, and in the meantime, stop generating new
            // messages by stopping processing client requests, so as to avoid
            // the outgoing message queue growing without bound.
            //
            // We still need to process incoming messages when in this state,
            // even though they may generate additional outgoing messages, to
            // avoid a deadlock with the host. The host can always DoS the
            // guest, so this is not an attack vector.
            let host_backed_up = !self.inner.messages.is_empty();
            let flush_messages = OptionFuture::from(
                (self.running && host_backed_up)
                    .then(|| self.inner.messages.flush_messages().fuse()),
            );

            let mut client_request_recv = OptionFuture::from(
                (self.running && !host_backed_up).then(|| self.client_request_recv.next()),
            );

            let mut channel_requests = OptionFuture::from(
                (self.running && !host_backed_up)
                    .then(|| self.inner.channel_requests.select_next_some()),
            );

            futures::select! { // merge semantics
                // Nothing to do on flush completion; the next iteration
                // re-evaluates `host_backed_up`.
                _r = pin!(flush_messages) => {}
                r = self.task_recv.next() => {
                    if let Some(task) = r {
                        self.handle_task(task).await;
                    } else {
                        break;
                    }
                }
                r = client_request_recv => {
                    if let Some(Some(request)) = r {
                        self.handle_client_request(request);
                    } else {
                        break;
                    }
                }
                r = channel_requests => {
                    match r.unwrap() {
                        (id, Some(request)) => self.handle_channel_request(id, request),
                        (id, _) => {
                            self.handle_device_removal(id);
                        }
                    }
                }
                r = message_recv => {
                    match r.unwrap() {
                        Ok(size) => {
                            if size == 0 {
                                panic!("Unexpected end of file reading messages from synic.");
                            }

                            self.handle_synic_message(&buf[..size]);
                        }
                        Err(err) => {
                            panic!("Error reading messages from synic: {err:?}");
                        }
                    }
                }
                complete => break,
            }
        }
    }
1569}
1570
1571impl ClientTaskInner {
1572    fn close_channel(&mut self, channel_id: ChannelId, channel: &mut Channel) {
1573        if let ChannelState::Opened {
1574            redirected_event_flag,
1575            ..
1576        } = channel.state
1577        {
1578            if let Some(flag) = redirected_event_flag {
1579                self.synic.free_event_flag(flag);
1580            }
1581            tracing::info!(channel_id = channel_id.0, "closing channel on host");
1582            self.messages.send(&protocol::CloseChannel { channel_id });
1583            channel.state = ChannelState::Offered;
1584            channel.connection_id.store(0, Ordering::Release);
1585        } else {
1586            tracing::warn!(
1587                id = %channel_id.0,
1588                channel_state = %channel.state,
1589                "invalid channel state for close channel"
1590            );
1591        }
1592    }
1593}
1594
/// Lifecycle state of a GPADL, as stored in a channel's `gpadls` map.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum GpadlState {
    /// GpadlHeader has been sent to the host.
    ///
    /// Holds the creation RPC, which is completed or failed when the host's
    /// `GpadlCreated` arrives.
    Offered(#[inspect(skip)] FailableRpc<(), ()>),
    /// Host has responded with GpadlCreated.
    Created,
    /// GpadlTeardown message has been sent to the host.
    TearingDown {
        /// Teardown RPCs, all completed when the host sends `GpadlTorndown`.
        #[inspect(skip)]
        rpcs: Vec<Rpc<(), ()>>,
    },
}
1608
/// Outgoing message path to the host.
///
/// Messages are posted immediately when possible. Otherwise they are queued in
/// order until `flush_messages` can deliver them.
#[derive(Inspect)]
struct OutgoingMessages {
    /// Posts messages to the host.
    #[inspect(skip)]
    poster: Box<dyn PollPostMessage>,
    /// Messages waiting to be posted, oldest first (inspected as a count).
    #[inspect(with = "|x| x.len()")]
    queued: VecDeque<OutgoingMessage>,
    /// Whether outgoing messages are flowing, pausing, or paused.
    state: OutgoingMessageState,
}
1617
/// State of the outgoing message path, used to implement pause/resume.
#[derive(Inspect, PartialEq, Eq, Debug)]
enum OutgoingMessageState {
    /// Messages are posted normally.
    Running,
    /// A `Pause` message must be posted before moving to `Paused`. Queued
    /// messages are held until then.
    SendingPauseMessage,
    /// Nothing is posted; new messages are only queued.
    Paused,
}
1624
impl OutgoingMessages {
    /// Sends `msg` to the host with no trailing data (see `send_with_data`).
    fn send<T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout>(
        &mut self,
        msg: &T,
    ) {
        self.send_with_data(msg, &[])
    }

    /// Sends `msg` followed by `data` to the host.
    ///
    /// The message is posted immediately only when nothing is queued and the
    /// state is `Running`. That attempt uses a no-op waker, so it never
    /// registers for wakeup. If the post is not ready, or if messages are
    /// already queued, or if outgoing messages are paused, the message is
    /// appended to the queue instead. This preserves ordering, and
    /// `flush_messages` delivers the queue later.
    fn send_with_data<
        T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
    >(
        &mut self,
        msg: &T,
        data: &[u8],
    ) {
        tracing::trace!(typ = ?T::MESSAGE_TYPE, "Sending message to host");
        let msg = OutgoingMessage::with_data(msg, data);
        if self.queued.is_empty() && self.state == OutgoingMessageState::Running {
            let r = self.poster.poll_post_message(
                &mut Context::from_waker(std::task::Waker::noop()),
                protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                1,
                msg.data(),
            );
            if let Poll::Ready(()) = r {
                return;
            }
        }
        tracing::trace!("queueing message");
        self.queued.push_back(msg);
    }

    /// Makes progress delivering outgoing messages, according to the state:
    ///
    /// - `Running`: posts queued messages in order until the queue is empty.
    /// - `SendingPauseMessage`: posts a `Pause` message, then moves to
    ///   `Paused`. Queued messages are left in place.
    /// - `Paused`: does nothing.
    ///
    /// Each message is removed from the queue only after its post completes.
    /// If this future is dropped mid-post, the message therefore stays at the
    /// front of the queue (and the state stays unchanged) for the next flush.
    async fn flush_messages(&mut self) {
        let mut send = async |msg: &OutgoingMessage| {
            poll_fn(|cx| {
                self.poster.poll_post_message(
                    cx,
                    protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                    1,
                    msg.data(),
                )
            })
            .await
        };
        match self.state {
            OutgoingMessageState::Running => {
                while let Some(msg) = self.queued.front() {
                    send(msg).await;
                    tracing::trace!("sent queued message");
                    // Pop only after the post succeeded (see the cancellation note above).
                    self.queued.pop_front();
                }
            }
            OutgoingMessageState::SendingPauseMessage => {
                send(&OutgoingMessage::new(&protocol::Pause)).await;
                tracing::trace!("sent pause message");
                self.state = OutgoingMessageState::Paused;
            }
            OutgoingMessageState::Paused => {}
        }
    }

    /// Pause by sending a pause message to the host. This will cause the host
    /// to stop sending messages after sending a pause response.
    ///
    /// A `Resume` message is placed at the front of the queue, so it is the
    /// first message flushed after `resume` returns the state to `Running`.
    ///
    /// # Panics
    ///
    /// Panics if the state is not `Running`.
    fn pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::SendingPauseMessage;
        // Queue a resume message to be sent later.
        self.queued
            .push_front(OutgoingMessage::new(&protocol::Resume));
    }

    /// Force a pause by setting the state to Paused. This is used when the
    /// host does not support in-band pause/resume messages, in which case
    /// the SINT is masked to force the host to stop sending messages.
    ///
    /// # Panics
    ///
    /// Panics if the state is not `Running`.
    fn force_pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::Paused;
    }

    /// Returns to the `Running` state so that queued messages are flushed
    /// again.
    ///
    /// # Panics
    ///
    /// Panics if the state is not `Paused`.
    fn resume(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Paused);
        self.state = OutgoingMessageState::Running;
    }

    /// Returns `true` if no messages are waiting in the queue.
    fn is_empty(&self) -> bool {
        self.queued.is_empty()
    }
}
1713
/// State grouped together for the client task: outgoing messages, GPADL
/// teardown bookkeeping, per-channel request streams, and SynIC event state.
#[derive(Inspect)]
struct ClientTaskInner {
    /// Outgoing message queue and its pause/resume state.
    messages: OutgoingMessages,
    /// GPADL ID -> owning channel. NOTE(review): presumably GPADLs whose
    /// teardown has been requested but not yet completed; confirm at call sites.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")]
    teardown_gpadls: HashMap<GpadlId, ChannelId>,
    /// Merged stream of per-channel request receivers, each tagged with its
    /// channel ID.
    #[inspect(skip)]
    channel_requests: SelectAll<TaggedStream<ChannelId, mesh::Receiver<ChannelRequest>>>,
    /// Event-flag allocation and guest-to-host signaling.
    synic: SynicState,
}
1723
/// SynIC event bookkeeping for the client.
#[derive(Inspect)]
struct SynicState {
    /// Used to map/unmap event flags and to signal the host.
    #[inspect(skip)]
    event_client: Arc<dyn SynicEventClient>,
    /// Slot `i` records whether event flag `i + 1` is currently in use
    /// (flag 0 is never handed out).
    #[inspect(iter_by_index)]
    event_flag_state: Vec<bool>,
}
1731
/// All channels known to the client, keyed by channel ID.
///
/// Inspects transparently as the underlying map, keyed by the raw ID value.
#[derive(Inspect, Default)]
#[inspect(transparent)]
struct ChannelList(
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] HashMap<ChannelId, Channel>,
);
1737
/// A reference to a channel that can be used to remove the channel from the map
/// as well.
///
/// Holds the occupied map entry, so it derefs to the [`Channel`] and can also
/// remove it (see [`ChannelRef::try_release`]). Obtained from
/// [`ChannelList::get_mut`].
struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>);
1741
/// A tag value used to indicate that [`ChannelRef::try_release`] has been called.
/// This is useful as a return value for methods that might transition a channel
/// into a fully released state.
///
/// The private `()` field prevents code outside this module from constructing
/// the tag directly.
struct TriedRelease(());
1746
1747impl ChannelRef<'_> {
1748    /// If the channel has been fully released (revoked, released by the client,
1749    /// no pending requests), notifes the server and removes this channel from
1750    /// the map.
1751    fn try_release(self, messages: &mut OutgoingMessages) -> TriedRelease {
1752        if self.is_client_released
1753            && matches!(self.state, ChannelState::Revoked)
1754            && self.pending_request().is_none()
1755        {
1756            let channel_id = *self.0.key();
1757            tracelimit::info_ratelimited!(channel_id = channel_id.0, "releasing channel");
1758            messages.send(&protocol::RelIdReleased { channel_id });
1759            self.0.remove();
1760        }
1761        TriedRelease(())
1762    }
1763}
1764
1765impl Deref for ChannelRef<'_> {
1766    type Target = Channel;
1767
1768    fn deref(&self) -> &Self::Target {
1769        self.0.get()
1770    }
1771}
1772
1773impl DerefMut for ChannelRef<'_> {
1774    fn deref_mut(&mut self) -> &mut Self::Target {
1775        self.0.get_mut()
1776    }
1777}
1778
1779impl ChannelList {
1780    fn revoked_channel_with_pending_request(&self) -> Option<(ChannelId, &'static str)> {
1781        self.0.iter().find_map(|(&id, channel)| {
1782            if !matches!(channel.state, ChannelState::Revoked) {
1783                return None;
1784            }
1785            Some((id, channel.pending_request()?))
1786        })
1787    }
1788
1789    #[track_caller]
1790    fn get_mut(&mut self, channel_id: ChannelId) -> ChannelRef<'_> {
1791        match self.0.entry(channel_id) {
1792            hash_map::Entry::Occupied(entry) => ChannelRef(entry),
1793            hash_map::Entry::Vacant(_) => {
1794                panic!("channel {:?} not found", channel_id);
1795            }
1796        }
1797    }
1798}
1799
1800impl SynicState {
1801    fn guest_to_host_interrupt(&self, connection_id: Arc<AtomicU32>) -> Interrupt {
1802        Interrupt::from_fn({
1803            let event_client = self.event_client.clone();
1804            move || {
1805                let connection_id = connection_id.load(Ordering::Acquire);
1806                if connection_id == 0 {
1807                    tracing::debug!("interrupt signal after close");
1808                    return;
1809                }
1810
1811                if let Err(err) = event_client.signal_event(connection_id, 0) {
1812                    tracelimit::warn_ratelimited!(
1813                        error = &err as &dyn std::error::Error,
1814                        "failed to signal event"
1815                    );
1816                }
1817            }
1818        })
1819    }
1820
1821    const MAX_EVENT_FLAGS: u16 = 2047;
1822
1823    fn allocate_event_flag(&mut self, event: &Event) -> Result<u16> {
1824        let i = self
1825            .event_flag_state
1826            .iter()
1827            .position(|&used| !used)
1828            .ok_or(())
1829            .or_else(|()| {
1830                if self.event_flag_state.len() >= Self::MAX_EVENT_FLAGS as usize {
1831                    anyhow::bail!("out of event flags");
1832                }
1833                self.event_flag_state.push(false);
1834                Ok(self.event_flag_state.len() - 1)
1835            })?;
1836
1837        let event_flag = (i + 1) as u16;
1838        self.event_client
1839            .map_event(event_flag, event)
1840            .context("failed to map event")?;
1841        self.event_flag_state[i] = true;
1842        Ok(event_flag)
1843    }
1844
1845    fn restore_event_flag(&mut self, flag: u16, event: &Event) -> Result<()> {
1846        let i = (flag as usize)
1847            .checked_sub(1)
1848            .context("invalid event flag")?;
1849        if i >= Self::MAX_EVENT_FLAGS as usize {
1850            anyhow::bail!("invalid event flag");
1851        }
1852        if self.event_flag_state.len() <= i {
1853            self.event_flag_state.resize(i + 1, false);
1854        }
1855        if self.event_flag_state[i] {
1856            anyhow::bail!("event flag already in use");
1857        }
1858        self.event_client
1859            .map_event(flag, event)
1860            .context("failed to map event")?;
1861        self.event_flag_state[i] = true;
1862        Ok(())
1863    }
1864
1865    fn free_event_flag(&mut self, flag: u16) {
1866        let i = flag as usize - 1;
1867        assert!(i < self.event_flag_state.len());
1868        self.event_flag_state[i] = false;
1869    }
1870}
1871
1872#[cfg(test)]
1873mod tests {
1874    use super::*;
1875    use futures_concurrency::future::Join;
1876    use guid::Guid;
1877    use pal_async::DefaultDriver;
1878    use pal_async::async_test;
1879    use pal_async::timer::PolledTimer;
1880    use protocol::TargetInfo;
1881    use std::fmt::Debug;
1882    use std::task::ready;
1883    use std::time::Duration;
1884    use test_with_tracing::test;
1885    use vmbus_core::protocol::MessageHeader;
1886    use vmbus_core::protocol::MessageType;
1887    use vmbus_core::protocol::OfferFlags;
1888    use vmbus_core::protocol::UserDefinedData;
1889    use vmbus_core::protocol::VmbusMessage;
1890    use zerocopy::FromBytes;
1891    use zerocopy::FromZeros;
1892    use zerocopy::Immutable;
1893    use zerocopy::IntoBytes;
1894    use zerocopy::KnownLayout;
1895
1896    const VMBUS_TEST_CLIENT_ID: Guid = guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6");
1897
1898    fn in_msg<T: IntoBytes + Immutable + KnownLayout>(message_type: MessageType, t: T) -> Vec<u8> {
1899        let mut data = Vec::new();
1900        data.extend_from_slice(&message_type.0.to_ne_bytes());
1901        data.extend_from_slice(&0u32.to_ne_bytes());
1902        data.extend_from_slice(t.as_bytes());
1903        data
1904    }
1905
1906    #[track_caller]
1907    fn check_message<T>(msg: OutgoingMessage, chk: T)
1908    where
1909        T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1910    {
1911        check_message_with_data(msg, chk, &[]);
1912    }
1913
1914    #[track_caller]
1915    fn check_message_with_data<T>(msg: OutgoingMessage, chk: T, data: &[u8])
1916    where
1917        T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1918    {
1919        let chk_data = OutgoingMessage::with_data(&chk, data);
1920        if msg.data() != chk_data.data() {
1921            let (header, rest) = MessageHeader::read_from_prefix(msg.data()).unwrap();
1922            assert_eq!(header.message_type(), <T as VmbusMessage>::MESSAGE_TYPE);
1923            let (msg, rest) = T::read_from_prefix(rest).expect("incorrect message size");
1924            if msg.as_bytes() != chk.as_bytes() {
1925                panic!("mismatched messages, expected {:#?}, got {:#?}", chk, msg);
1926            }
1927            if rest != data {
1928                panic!("mismatched data, expected {:#?}, got {:#?}", data, rest);
1929            }
1930        }
1931    }
1932
1933    struct TestServer {
1934        messages: mesh::Receiver<OutgoingMessage>,
1935        send: mesh::Sender<Vec<u8>>,
1936    }
1937
1938    impl TestServer {
1939        async fn next(&mut self) -> Option<OutgoingMessage> {
1940            self.messages.next().await
1941        }
1942
1943        fn send(&self, msg: Vec<u8>) {
1944            self.send.send(msg);
1945        }
1946
1947        async fn connect(&mut self, client: &mut VmbusClient) -> ConnectResult {
1948            self.connect_with_channels(client, |_| {}).await
1949        }
1950
1951        async fn connect_with_channels(
1952            &mut self,
1953            client: &mut VmbusClient,
1954            send_offers: impl FnOnce(&mut Self),
1955        ) -> ConnectResult {
1956            let client_connect = client.connect(0, None, Guid::ZERO);
1957
1958            let server_connect = async {
1959                let _ = self.next().await.unwrap();
1960
1961                self.send(in_msg(
1962                    MessageType::VERSION_RESPONSE,
1963                    protocol::VersionResponse2 {
1964                        version_response: protocol::VersionResponse {
1965                            version_supported: 1,
1966                            connection_state: ConnectionState::SUCCESSFUL,
1967                            padding: 0,
1968                            selected_version_or_connection_id: 0,
1969                        },
1970                        supported_features: SUPPORTED_FEATURE_FLAGS.into(),
1971                    },
1972                ));
1973
1974                check_message(self.next().await.unwrap(), protocol::RequestOffers {});
1975
1976                send_offers(self);
1977                self.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
1978            };
1979
1980            let (connection, ()) = (client_connect, server_connect).join().await;
1981
1982            let connection = connection.unwrap();
1983            assert_eq!(connection.version.version, Version::Copper);
1984            assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
1985            connection
1986        }
1987
1988        async fn get_channel(&mut self, client: &mut VmbusClient) -> OfferInfo {
1989            let [channel] = self
1990                .get_channels(client, 1)
1991                .await
1992                .offers
1993                .try_into()
1994                .unwrap();
1995            channel
1996        }
1997
1998        async fn get_channels(&mut self, client: &mut VmbusClient, count: usize) -> ConnectResult {
1999            self.connect_with_channels(client, |this| {
2000                for i in 0..count {
2001                    let offer = protocol::OfferChannel {
2002                        interface_id: Guid::new_random(),
2003                        instance_id: Guid::new_random(),
2004                        rsvd: [0; 4],
2005                        flags: OfferFlags::new(),
2006                        mmio_megabytes: 0,
2007                        user_defined: UserDefinedData::new_zeroed(),
2008                        subchannel_index: 0,
2009                        mmio_megabytes_optional: 0,
2010                        channel_id: ChannelId(i as u32),
2011                        monitor_id: 0,
2012                        monitor_allocated: 0,
2013                        is_dedicated: 0,
2014                        connection_id: 0,
2015                    };
2016
2017                    this.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2018                }
2019            })
2020            .await
2021        }
2022
2023        async fn stop_client(&mut self, client: &mut VmbusClient) {
2024            let client_stop = client.stop();
2025            let server_stop = async {
2026                check_message(self.next().await.unwrap(), protocol::Pause);
2027                self.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
2028            };
2029            (client_stop, server_stop).join().await;
2030        }
2031
2032        async fn start_client(&mut self, client: &mut VmbusClient) {
2033            client.start();
2034            check_message(self.next().await.unwrap(), protocol::Resume);
2035        }
2036    }
2037
2038    struct TestServerClient {
2039        sender: mesh::Sender<OutgoingMessage>,
2040        timer: PolledTimer,
2041        deadline: Option<pal_async::timer::Instant>,
2042    }
2043
2044    impl PollPostMessage for TestServerClient {
2045        fn poll_post_message(
2046            &mut self,
2047            cx: &mut Context<'_>,
2048            _connection_id: u32,
2049            _typ: u32,
2050            msg: &[u8],
2051        ) -> Poll<()> {
2052            loop {
2053                if let Some(deadline) = self.deadline {
2054                    ready!(self.timer.poll_until(cx, deadline));
2055                    self.deadline = None;
2056                }
2057                // Randomly choose whether to delay the message.
2058                //
2059                // FUTURE: use some kind of deterministic test framework for this to
2060                // allow for reproducible tests.
2061                let mut b = [0];
2062                getrandom::fill(&mut b).unwrap();
2063                if b[0] % 4 == 0 {
2064                    self.deadline =
2065                        Some(pal_async::timer::Instant::now() + Duration::from_millis(10));
2066                } else {
2067                    let msg = OutgoingMessage::from_message(msg).unwrap();
2068                    tracing::info!(
2069                        msg = ?MessageHeader::read_from_prefix(msg.data()),
2070                        "sending message"
2071                    );
2072                    self.sender.send(msg);
2073                    break Poll::Ready(());
2074                }
2075            }
2076        }
2077    }
2078
2079    struct NoopSynicEvents;
2080
2081    impl SynicEventClient for NoopSynicEvents {
2082        fn map_event(&self, _event_flag: u16, _event: &Event) -> std::io::Result<()> {
2083            Ok(())
2084        }
2085
2086        fn unmap_event(&self, _event_flag: u16) {}
2087
2088        fn signal_event(&self, _connection_id: u32, _event_flag: u16) -> std::io::Result<()> {
2089            Err(std::io::ErrorKind::Unsupported.into())
2090        }
2091    }
2092
2093    struct TestMessageSource {
2094        msg_recv: mesh::Receiver<Vec<u8>>,
2095        paused: bool,
2096    }
2097
2098    impl AsyncRecv for TestMessageSource {
2099        fn poll_recv(
2100            &mut self,
2101            cx: &mut Context<'_>,
2102            mut bufs: &mut [std::io::IoSliceMut<'_>],
2103        ) -> Poll<std::io::Result<usize>> {
2104            let value = match self.msg_recv.poll_recv(cx) {
2105                Poll::Ready(v) => v.unwrap(),
2106                Poll::Pending => {
2107                    if self.paused {
2108                        return Poll::Ready(Ok(0));
2109                    } else {
2110                        return Poll::Pending;
2111                    }
2112                }
2113            };
2114            let mut remaining = value.as_slice();
2115            let mut total_size = 0;
2116            while !remaining.is_empty() && !bufs.is_empty() {
2117                let size = bufs[0].len().min(remaining.len());
2118                bufs[0][..size].copy_from_slice(&remaining[..size]);
2119                remaining = &remaining[size..];
2120                bufs = &mut bufs[1..];
2121                total_size += size;
2122            }
2123
2124            Ok(total_size).into()
2125        }
2126    }
2127
2128    impl VmbusMessageSource for TestMessageSource {
2129        fn pause_message_stream(&mut self) {
2130            self.paused = true;
2131        }
2132
2133        fn resume_message_stream(&mut self) {
2134            self.paused = false;
2135        }
2136    }
2137
2138    fn test_init(driver: &DefaultDriver) -> (TestServer, VmbusClient) {
2139        let (msg_send, msg_recv) = mesh::channel();
2140        let (synic_send, synic_recv) = mesh::channel();
2141        let server = TestServer {
2142            messages: synic_recv,
2143            send: msg_send,
2144        };
2145        let mut client = VmbusClientBuilder::new(
2146            NoopSynicEvents,
2147            TestMessageSource {
2148                msg_recv,
2149                paused: false,
2150            },
2151            TestServerClient {
2152                sender: synic_send,
2153                deadline: None,
2154                timer: PolledTimer::new(driver),
2155            },
2156        )
2157        .build(driver);
2158        client.start();
2159        (server, client)
2160    }
2161
2162    #[async_test]
2163    async fn test_initiate_contact_success(driver: DefaultDriver) {
2164        let (mut server, client) = test_init(&driver);
2165        let _recv = client
2166            .access
2167            .client_request_send
2168            .call(ClientRequest::Connect, ConnectRequest::default());
2169        check_message(
2170            server.next().await.unwrap(),
2171            protocol::InitiateContact2 {
2172                initiate_contact: protocol::InitiateContact {
2173                    version_requested: Version::Copper as u32,
2174                    target_message_vp: 0,
2175                    interrupt_page_or_target_info: TargetInfo::new()
2176                        .with_sint(2)
2177                        .with_vtl(0)
2178                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2179                        .into(),
2180                    parent_to_child_monitor_page_gpa: 0,
2181                    child_to_parent_monitor_page_gpa: 0,
2182                },
2183                ..FromZeros::new_zeroed()
2184            },
2185        );
2186    }
2187
2188    #[async_test]
2189    async fn test_connect_success(driver: DefaultDriver) {
2190        let (mut server, mut client) = test_init(&driver);
2191        let client_connect = client.connect(0, None, Guid::ZERO);
2192
2193        let server_connect = async {
2194            check_message(
2195                server.next().await.unwrap(),
2196                protocol::InitiateContact2 {
2197                    initiate_contact: protocol::InitiateContact {
2198                        version_requested: Version::Copper as u32,
2199                        target_message_vp: 0,
2200                        interrupt_page_or_target_info: TargetInfo::new()
2201                            .with_sint(2)
2202                            .with_vtl(0)
2203                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2204                            .into(),
2205                        parent_to_child_monitor_page_gpa: 0,
2206                        child_to_parent_monitor_page_gpa: 0,
2207                    },
2208                    ..FromZeros::new_zeroed()
2209                },
2210            );
2211
2212            server.send(in_msg(
2213                MessageType::VERSION_RESPONSE,
2214                protocol::VersionResponse2 {
2215                    version_response: protocol::VersionResponse {
2216                        version_supported: 1,
2217                        connection_state: ConnectionState::SUCCESSFUL,
2218                        padding: 0,
2219                        selected_version_or_connection_id: 0,
2220                    },
2221                    supported_features: SUPPORTED_FEATURE_FLAGS.into_bits(),
2222                },
2223            ));
2224
2225            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2226            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2227        };
2228
2229        let (connection, ()) = (client_connect, server_connect).join().await;
2230        let connection = connection.unwrap();
2231
2232        assert_eq!(connection.version.version, Version::Copper);
2233        assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
2234    }
2235
2236    #[async_test]
2237    async fn test_feature_flags(driver: DefaultDriver) {
2238        let (mut server, mut client) = test_init(&driver);
2239        let client_connect = client.connect(0, None, Guid::ZERO);
2240
2241        let server_connect = async {
2242            check_message(
2243                server.next().await.unwrap(),
2244                protocol::InitiateContact2 {
2245                    initiate_contact: protocol::InitiateContact {
2246                        version_requested: Version::Copper as u32,
2247                        target_message_vp: 0,
2248                        interrupt_page_or_target_info: TargetInfo::new()
2249                            .with_sint(2)
2250                            .with_vtl(0)
2251                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2252                            .into(),
2253                        parent_to_child_monitor_page_gpa: 0,
2254                        child_to_parent_monitor_page_gpa: 0,
2255                    },
2256                    ..FromZeros::new_zeroed()
2257                },
2258            );
2259
2260            // Report the server doesn't support some of the feature flags, and make
2261            // sure this is reflected in the returned version.
2262            server.send(in_msg(
2263                MessageType::VERSION_RESPONSE,
2264                protocol::VersionResponse2 {
2265                    version_response: protocol::VersionResponse {
2266                        version_supported: 1,
2267                        connection_state: ConnectionState::SUCCESSFUL,
2268                        padding: 0,
2269                        selected_version_or_connection_id: 0,
2270                    },
2271                    supported_features: 2,
2272                },
2273            ));
2274
2275            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2276            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2277        };
2278
2279        let (connection, ()) = (client_connect, server_connect).join().await;
2280        let connection = connection.unwrap();
2281
2282        assert_eq!(connection.version.version, Version::Copper);
2283        assert_eq!(
2284            connection.version.feature_flags,
2285            FeatureFlags::new().with_channel_interrupt_redirection(true)
2286        );
2287    }
2288
2289    #[async_test]
2290    async fn test_client_id(driver: DefaultDriver) {
2291        let (mut server, client) = test_init(&driver);
2292        let initiate_contact = ConnectRequest {
2293            client_id: VMBUS_TEST_CLIENT_ID,
2294            ..Default::default()
2295        };
2296        let _recv = client
2297            .access
2298            .client_request_send
2299            .call(ClientRequest::Connect, initiate_contact);
2300
2301        check_message(
2302            server.next().await.unwrap(),
2303            protocol::InitiateContact2 {
2304                initiate_contact: protocol::InitiateContact {
2305                    version_requested: Version::Copper as u32,
2306                    target_message_vp: 0,
2307                    interrupt_page_or_target_info: TargetInfo::new()
2308                        .with_sint(2)
2309                        .with_vtl(0)
2310                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2311                        .into(),
2312                    parent_to_child_monitor_page_gpa: 0,
2313                    child_to_parent_monitor_page_gpa: 0,
2314                },
2315                client_id: VMBUS_TEST_CLIENT_ID,
2316            },
2317        );
2318    }
2319
2320    #[async_test]
2321    async fn test_version_negotiation(driver: DefaultDriver) {
2322        let (mut server, mut client) = test_init(&driver);
2323        let client_connect = client.connect(0, None, Guid::ZERO);
2324
2325        let server_connect = async {
2326            check_message(
2327                server.next().await.unwrap(),
2328                protocol::InitiateContact2 {
2329                    initiate_contact: protocol::InitiateContact {
2330                        version_requested: Version::Copper as u32,
2331                        target_message_vp: 0,
2332                        interrupt_page_or_target_info: TargetInfo::new()
2333                            .with_sint(2)
2334                            .with_vtl(0)
2335                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2336                            .into(),
2337                        parent_to_child_monitor_page_gpa: 0,
2338                        child_to_parent_monitor_page_gpa: 0,
2339                    },
2340                    ..FromZeros::new_zeroed()
2341                },
2342            );
2343
2344            server.send(in_msg(
2345                MessageType::VERSION_RESPONSE,
2346                protocol::VersionResponse {
2347                    version_supported: 0,
2348                    connection_state: ConnectionState::SUCCESSFUL,
2349                    padding: 0,
2350                    selected_version_or_connection_id: 0,
2351                },
2352            ));
2353
2354            check_message(
2355                server.next().await.unwrap(),
2356                protocol::InitiateContact {
2357                    version_requested: Version::Iron as u32,
2358                    target_message_vp: 0,
2359                    interrupt_page_or_target_info: TargetInfo::new()
2360                        .with_sint(2)
2361                        .with_vtl(0)
2362                        .with_feature_flags(FeatureFlags::new().into())
2363                        .into(),
2364                    parent_to_child_monitor_page_gpa: 0,
2365                    child_to_parent_monitor_page_gpa: 0,
2366                },
2367            );
2368
2369            server.send(in_msg(
2370                MessageType::VERSION_RESPONSE,
2371                protocol::VersionResponse {
2372                    version_supported: 1,
2373                    connection_state: ConnectionState::SUCCESSFUL,
2374                    padding: 0,
2375                    selected_version_or_connection_id: 0,
2376                },
2377            ));
2378
2379            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2380            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2381        };
2382
2383        let (connection, ()) = (client_connect, server_connect).join().await;
2384        let connection = connection.unwrap();
2385
2386        assert_eq!(connection.version.version, Version::Iron);
2387        assert_eq!(connection.version.feature_flags, FeatureFlags::new());
2388    }
2389
2390    #[async_test]
2391    async fn test_open_channel_success(driver: DefaultDriver) {
2392        let (mut server, mut client) = test_init(&driver);
2393        let channel = server.get_channel(&mut client).await;
2394
2395        let recv = channel.request_send.call(
2396            ChannelRequest::Open,
2397            OpenRequest {
2398                open_data: OpenData {
2399                    target_vp: 0,
2400                    ring_offset: 0,
2401                    ring_gpadl_id: GpadlId(0),
2402                    event_flag: 0,
2403                    connection_id: 0,
2404                    user_data: UserDefinedData::new_zeroed(),
2405                },
2406                incoming_event: None,
2407                use_vtl2_connection_id: false,
2408            },
2409        );
2410
2411        check_message(
2412            server.next().await.unwrap(),
2413            protocol::OpenChannel2 {
2414                open_channel: protocol::OpenChannel {
2415                    channel_id: ChannelId(0),
2416                    open_id: 0,
2417                    ring_buffer_gpadl_id: GpadlId(0),
2418                    target_vp: 0,
2419                    downstream_ring_buffer_page_offset: 0,
2420                    user_data: UserDefinedData::new_zeroed(),
2421                },
2422                connection_id: 0,
2423                event_flag: 0,
2424                flags: Default::default(),
2425            },
2426        );
2427
2428        server.send(in_msg(
2429            MessageType::OPEN_CHANNEL_RESULT,
2430            protocol::OpenResult {
2431                channel_id: ChannelId(0),
2432                open_id: 0,
2433                status: protocol::STATUS_SUCCESS as u32,
2434            },
2435        ));
2436
2437        recv.await.unwrap().unwrap();
2438    }
2439
2440    #[async_test]
2441    async fn test_open_channel_fail(driver: DefaultDriver) {
2442        let (mut server, mut client) = test_init(&driver);
2443        let channel = server.get_channel(&mut client).await;
2444
2445        let recv = channel.request_send.call(
2446            ChannelRequest::Open,
2447            OpenRequest {
2448                open_data: OpenData {
2449                    target_vp: 0,
2450                    ring_offset: 0,
2451                    ring_gpadl_id: GpadlId(0),
2452                    event_flag: 0,
2453                    connection_id: 0,
2454                    user_data: UserDefinedData::new_zeroed(),
2455                },
2456                incoming_event: None,
2457                use_vtl2_connection_id: false,
2458            },
2459        );
2460
2461        check_message(
2462            server.next().await.unwrap(),
2463            protocol::OpenChannel2 {
2464                open_channel: protocol::OpenChannel {
2465                    channel_id: ChannelId(0),
2466                    open_id: 0,
2467                    ring_buffer_gpadl_id: GpadlId(0),
2468                    target_vp: 0,
2469                    downstream_ring_buffer_page_offset: 0,
2470                    user_data: UserDefinedData::new_zeroed(),
2471                },
2472                connection_id: 0,
2473                event_flag: 0,
2474                flags: Default::default(),
2475            },
2476        );
2477
2478        server.send(in_msg(
2479            MessageType::OPEN_CHANNEL_RESULT,
2480            protocol::OpenResult {
2481                channel_id: ChannelId(0),
2482                open_id: 0,
2483                status: protocol::STATUS_UNSUCCESSFUL as u32,
2484            },
2485        ));
2486
2487        recv.await.unwrap().unwrap_err();
2488    }
2489
2490    #[async_test]
2491    async fn test_modify_channel(driver: DefaultDriver) {
2492        let (mut server, mut client) = test_init(&driver);
2493        let channel = server.get_channel(&mut client).await;
2494
2495        // N.B. A real server requires the channel to be open before sending this, but the test
2496        //      server doesn't care.
2497        let recv = channel.request_send.call(
2498            ChannelRequest::Modify,
2499            ModifyRequest::TargetVp { target_vp: 1 },
2500        );
2501
2502        check_message(
2503            server.next().await.unwrap(),
2504            protocol::ModifyChannel {
2505                channel_id: ChannelId(0),
2506                target_vp: 1,
2507            },
2508        );
2509
2510        server.send(in_msg(
2511            MessageType::MODIFY_CHANNEL_RESPONSE,
2512            protocol::ModifyChannelResponse {
2513                channel_id: ChannelId(0),
2514                status: protocol::STATUS_SUCCESS,
2515            },
2516        ));
2517
2518        let status = recv.await.unwrap();
2519        assert_eq!(status, protocol::STATUS_SUCCESS);
2520    }
2521
2522    #[async_test]
2523    async fn test_save_restore_connected(driver: DefaultDriver) {
2524        let (mut server, mut client) = test_init(&driver);
2525        server.connect(&mut client).await;
2526        server.stop_client(&mut client).await;
2527        let s0 = client.save().await;
2528        let builder = client.sever().await;
2529        let mut client = builder.build(&driver);
2530        client.restore(s0.clone()).await.unwrap();
2531
2532        let s1 = client.save().await;
2533
2534        assert_eq!(s0, s1);
2535    }
2536
2537    #[async_test]
2538    async fn test_save_restore_connected_with_channel(driver: DefaultDriver) {
2539        let (mut server, mut client) = test_init(&driver);
2540        let c0 = server.get_channel(&mut client).await;
2541        server.stop_client(&mut client).await;
2542        let s0 = client.save().await;
2543        let builder = client.sever().await;
2544        let mut client = builder.build(&driver);
2545        let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2546        let s1 = client.save().await;
2547        assert_eq!(s0, s1);
2548        assert_eq!(connection.offers[0].offer, c0.offer);
2549    }
2550
    #[async_test]
    async fn test_save_restore_connected_with_revoked_channel(driver: DefaultDriver) {
        // Save/restore while a rescinded channel has not yet been released: the saved state must
        // round-trip, the revoked channel must not be among the restored offers, and its relid
        // is expected to be released (RelIdReleased) once the restored client is started.
        let (mut server, mut client) = test_init(&driver);
        let c0 = server.get_channel(&mut client).await;
        server.send(in_msg(
            MessageType::RESCIND_CHANNEL_OFFER,
            protocol::RescindChannelOffer {
                channel_id: ChannelId(0),
            },
        ));
        c0.revoke_recv.await.unwrap();
        // A request issued on the already-revoked channel is still forwarded to the server.
        let rpc = c0.request_send.call(
            ChannelRequest::Modify,
            ModifyRequest::TargetVp { target_vp: 1 },
        );

        check_message(
            server.next().await.unwrap(),
            protocol::ModifyChannel {
                channel_id: ChannelId(0),
                target_vp: 1,
            },
        );

        // Stop the client while the server side concurrently answers the outstanding modify
        // and acknowledges the Pause message that stopping produces. Both futures are driven
        // together via `join` so each side can make progress.
        let client_stop = client.stop();
        let server_stop = async {
            server.send(in_msg(
                MessageType::MODIFY_CHANNEL_RESPONSE,
                protocol::ModifyChannelResponse {
                    channel_id: ChannelId(0),
                    status: protocol::STATUS_SUCCESS,
                },
            ));
            check_message(server.next().await.unwrap(), protocol::Pause);
            server.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
        };
        (client_stop, server_stop).join().await;

        // The modify issued before the stop still completes.
        rpc.await.unwrap();

        // Round-trip the stopped client's state through a freshly built client.
        let s0 = client.save().await;
        let builder = client.sever().await;
        let mut client = builder.build(&driver);
        let connection = client.restore(s0.clone()).await.unwrap().unwrap();
        let s1 = client.save().await;
        assert_eq!(s0, s1);
        // The revoked channel is not handed back as an offer.
        assert!(connection.offers.is_empty());
        // Starting the restored client is followed by the release of the revoked relid.
        server.start_client(&mut client).await;
        check_message(
            server.next().await.unwrap(),
            protocol::RelIdReleased {
                channel_id: ChannelId(0),
            },
        );
    }
2606
2607    #[async_test]
2608    async fn test_connect_fails_on_incorrect_state(driver: DefaultDriver) {
2609        let (mut server, mut client) = test_init(&driver);
2610        server.connect(&mut client).await;
2611        let err = client.connect(0, None, Guid::ZERO).await.unwrap_err();
2612        assert!(matches!(err, ConnectError::InvalidState), "{:?}", err);
2613    }
2614
2615    #[async_test]
2616    async fn test_hot_add_remove(driver: DefaultDriver) {
2617        let (mut server, mut client) = test_init(&driver);
2618
2619        let mut connection = server.connect(&mut client).await;
2620        let offer = protocol::OfferChannel {
2621            interface_id: Guid::new_random(),
2622            instance_id: Guid::new_random(),
2623            rsvd: [0; 4],
2624            flags: OfferFlags::new(),
2625            mmio_megabytes: 0,
2626            user_defined: UserDefinedData::new_zeroed(),
2627            subchannel_index: 0,
2628            mmio_megabytes_optional: 0,
2629            channel_id: ChannelId(5),
2630            monitor_id: 0,
2631            monitor_allocated: 0,
2632            is_dedicated: 0,
2633            connection_id: 0,
2634        };
2635
2636        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2637        let info = connection.offer_recv.next().await.unwrap();
2638
2639        assert_eq!(offer, info.offer);
2640
2641        server.send(in_msg(
2642            MessageType::RESCIND_CHANNEL_OFFER,
2643            protocol::RescindChannelOffer {
2644                channel_id: ChannelId(5),
2645            },
2646        ));
2647
2648        info.revoke_recv.await.unwrap();
2649        drop(info.request_send);
2650
2651        check_message(
2652            server.next().await.unwrap(),
2653            protocol::RelIdReleased {
2654                channel_id: ChannelId(5),
2655            },
2656        );
2657    }
2658
2659    #[async_test]
2660    async fn test_gpadl_success(driver: DefaultDriver) {
2661        let (mut server, mut client) = test_init(&driver);
2662        let channel = server.get_channel(&mut client).await;
2663        let recv = channel.request_send.call(
2664            ChannelRequest::Gpadl,
2665            GpadlRequest {
2666                id: GpadlId(1),
2667                count: 1,
2668                buf: vec![5],
2669            },
2670        );
2671
2672        check_message_with_data(
2673            server.next().await.unwrap(),
2674            protocol::GpadlHeader {
2675                channel_id: ChannelId(0),
2676                gpadl_id: GpadlId(1),
2677                len: 8,
2678                count: 1,
2679            },
2680            0x5u64.as_bytes(),
2681        );
2682
2683        server.send(in_msg(
2684            MessageType::GPADL_CREATED,
2685            protocol::GpadlCreated {
2686                channel_id: ChannelId(0),
2687                gpadl_id: GpadlId(1),
2688                status: protocol::STATUS_SUCCESS,
2689            },
2690        ));
2691
2692        recv.await.unwrap().unwrap();
2693
2694        let rpc = channel
2695            .request_send
2696            .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2697
2698        check_message(
2699            server.next().await.unwrap(),
2700            protocol::GpadlTeardown {
2701                channel_id: ChannelId(0),
2702                gpadl_id: GpadlId(1),
2703            },
2704        );
2705
2706        server.send(in_msg(
2707            MessageType::GPADL_TORNDOWN,
2708            protocol::GpadlTorndown {
2709                gpadl_id: GpadlId(1),
2710            },
2711        ));
2712
2713        rpc.await.unwrap();
2714    }
2715
2716    #[async_test]
2717    async fn test_gpadl_fail(driver: DefaultDriver) {
2718        let (mut server, mut client) = test_init(&driver);
2719        let channel = server.get_channel(&mut client).await;
2720        let recv = channel.request_send.call(
2721            ChannelRequest::Gpadl,
2722            GpadlRequest {
2723                id: GpadlId(1),
2724                count: 1,
2725                buf: vec![7],
2726            },
2727        );
2728
2729        check_message_with_data(
2730            server.next().await.unwrap(),
2731            protocol::GpadlHeader {
2732                channel_id: ChannelId(0),
2733                gpadl_id: GpadlId(1),
2734                len: 8,
2735                count: 1,
2736            },
2737            0x7u64.as_bytes(),
2738        );
2739
2740        server.send(in_msg(
2741            MessageType::GPADL_CREATED,
2742            protocol::GpadlCreated {
2743                channel_id: ChannelId(0),
2744                gpadl_id: GpadlId(1),
2745                status: protocol::STATUS_UNSUCCESSFUL,
2746            },
2747        ));
2748
2749        recv.await.unwrap().unwrap_err();
2750    }
2751
    #[async_test]
    async fn test_gpadl_with_revoke(driver: DefaultDriver) {
        // Exercises GPADL handling across a channel revoke. It covers:
        // - a teardown started before the rescind and completed after it;
        // - a create issued after the rescind, which the server rejects;
        // - a teardown issued after the revoke has been observed.
        // The relid release is expected only at the very end.
        let (mut server, mut client) = test_init(&driver);
        let channel = server.get_channel(&mut client).await;
        let channel_id = ChannelId(0);
        // Create GPADLs 1, 2 and 3; the server accepts each one.
        for gpadl_id in [1, 2, 3].map(GpadlId) {
            let recv = channel.request_send.call(
                ChannelRequest::Gpadl,
                GpadlRequest {
                    id: gpadl_id,
                    count: 1,
                    buf: vec![3],
                },
            );

            check_message_with_data(
                server.next().await.unwrap(),
                protocol::GpadlHeader {
                    channel_id,
                    gpadl_id,
                    len: 8,
                    count: 1,
                },
                0x3u64.as_bytes(),
            );

            server.send(in_msg(
                MessageType::GPADL_CREATED,
                protocol::GpadlCreated {
                    channel_id,
                    gpadl_id,
                    status: protocol::STATUS_SUCCESS,
                },
            ));

            recv.await.unwrap().unwrap();
        }

        // Begin tearing down GPADL 1. The server sees the teardown request but does not
        // answer it until after the rescind below.
        let rpc = channel
            .request_send
            .call(ChannelRequest::TeardownGpadl, GpadlId(1));

        check_message(
            server.next().await.unwrap(),
            protocol::GpadlTeardown {
                channel_id,
                gpadl_id: GpadlId(1),
            },
        );

        // Rescind the channel while the teardown of GPADL 1 is still outstanding.
        server.send(in_msg(
            MessageType::RESCIND_CHANNEL_OFFER,
            protocol::RescindChannelOffer { channel_id },
        ));

        // A new GPADL create (id 4) is still forwarded to the server after the rescind. It is
        // sent with `call_failable`, so the server's rejection below is observed directly as
        // an `Err`.
        let recv = channel.request_send.call_failable(
            ChannelRequest::Gpadl,
            GpadlRequest {
                id: GpadlId(4),
                count: 1,
                buf: vec![3],
            },
        );

        check_message_with_data(
            server.next().await.unwrap(),
            protocol::GpadlHeader {
                channel_id,
                gpadl_id: GpadlId(4),
                len: 8,
                count: 1,
            },
            0x3u64.as_bytes(),
        );

        // Reject the create of GPADL 4, then complete the teardown of GPADL 1.
        server.send(in_msg(
            MessageType::GPADL_CREATED,
            protocol::GpadlCreated {
                channel_id,
                gpadl_id: GpadlId(4),
                status: protocol::STATUS_UNSUCCESSFUL,
            },
        ));

        server.send(in_msg(
            MessageType::GPADL_TORNDOWN,
            protocol::GpadlTorndown {
                gpadl_id: GpadlId(1),
            },
        ));

        rpc.await.unwrap();
        recv.await.unwrap_err();

        // Wait for the client to report the revoke.
        channel.revoke_recv.await.unwrap();

        // Tear down GPADL 2 after the revoke, then drop the request sender.
        let rpc = channel
            .request_send
            .call(ChannelRequest::TeardownGpadl, GpadlId(2));
        drop(channel.request_send);

        check_message(
            server.next().await.unwrap(),
            protocol::GpadlTeardown {
                channel_id,
                gpadl_id: GpadlId(2),
            },
        );

        server.send(in_msg(
            MessageType::GPADL_TORNDOWN,
            protocol::GpadlTorndown {
                gpadl_id: GpadlId(2),
            },
        ));

        rpc.await.unwrap();

        // The relid release follows the completed teardown of GPADL 2.
        // NOTE(review): GPADL 3 is never torn down and no teardown for it is expected before
        // RelIdReleased. This presumably relies on the client not waiting on remaining GPADLs
        // after a revoke; confirm that this is intended.
        check_message(
            server.next().await.unwrap(),
            protocol::RelIdReleased { channel_id },
        );
    }
2875
2876    #[async_test]
2877    async fn test_modify_connection(driver: DefaultDriver) {
2878        let (mut server, mut client) = test_init(&driver);
2879        server.connect(&mut client).await;
2880        let call = client.access.client_request_send.call(
2881            ClientRequest::Modify,
2882            ModifyConnectionRequest {
2883                monitor_page: Some(MonitorPageGpas {
2884                    child_to_parent: 5,
2885                    parent_to_child: 6,
2886                }),
2887            },
2888        );
2889
2890        check_message(
2891            server.next().await.unwrap(),
2892            protocol::ModifyConnection {
2893                child_to_parent_monitor_page_gpa: 5,
2894                parent_to_child_monitor_page_gpa: 6,
2895            },
2896        );
2897
2898        server.send(in_msg(
2899            MessageType::MODIFY_CONNECTION_RESPONSE,
2900            protocol::ModifyConnectionResponse {
2901                connection_state: ConnectionState::FAILED_LOW_RESOURCES,
2902            },
2903        ));
2904
2905        let result = call.await.unwrap();
2906        assert_eq!(ConnectionState::FAILED_LOW_RESOURCES, result);
2907    }
2908
2909    #[async_test]
2910    async fn test_hvsock(driver: DefaultDriver) {
2911        let (mut server, mut client) = test_init(&driver);
2912        server.connect(&mut client).await;
2913        let request = HvsockConnectRequest {
2914            service_id: Guid::new_random(),
2915            endpoint_id: Guid::new_random(),
2916            silo_id: Guid::new_random(),
2917            hosted_silo_unaware: false,
2918        };
2919
2920        let resp = client.access().connect_hvsock(request);
2921        check_message(
2922            server.next().await.unwrap(),
2923            protocol::TlConnectRequest2 {
2924                base: protocol::TlConnectRequest {
2925                    service_id: request.service_id,
2926                    endpoint_id: request.endpoint_id,
2927                },
2928                silo_id: request.silo_id,
2929            },
2930        );
2931
2932        // Now send a failure result.
2933        server.send(in_msg(
2934            MessageType::TL_CONNECT_REQUEST_RESULT,
2935            protocol::TlConnectResult {
2936                service_id: request.service_id,
2937                endpoint_id: request.endpoint_id,
2938                status: protocol::STATUS_CONNECTION_REFUSED,
2939            },
2940        ));
2941
2942        let result = resp.await;
2943        assert!(result.is_none());
2944    }
2945
2946    #[async_test]
2947    async fn test_synic_event_flags(driver: DefaultDriver) {
2948        let (mut server, mut client) = test_init(&driver);
2949        let connection = server.get_channels(&mut client, 5).await;
2950        let event = Event::new();
2951
2952        for _ in 0..5 {
2953            for (i, channel) in connection.offers.iter().enumerate() {
2954                let recv = channel.request_send.call(
2955                    ChannelRequest::Open,
2956                    OpenRequest {
2957                        open_data: OpenData {
2958                            target_vp: 0,
2959                            ring_offset: 0,
2960                            ring_gpadl_id: GpadlId(0),
2961                            event_flag: 0,
2962                            connection_id: 0,
2963                            user_data: UserDefinedData::new_zeroed(),
2964                        },
2965                        incoming_event: Some(event.clone()),
2966                        use_vtl2_connection_id: false,
2967                    },
2968                );
2969
2970                let expected_event_flag = i as u16 + 1;
2971
2972                check_message(
2973                    server.next().await.unwrap(),
2974                    protocol::OpenChannel2 {
2975                        open_channel: protocol::OpenChannel {
2976                            channel_id: channel.offer.channel_id,
2977                            open_id: 0,
2978                            ring_buffer_gpadl_id: GpadlId(0),
2979                            target_vp: 0,
2980                            downstream_ring_buffer_page_offset: 0,
2981                            user_data: UserDefinedData::new_zeroed(),
2982                        },
2983                        connection_id: 0,
2984                        event_flag: expected_event_flag,
2985                        flags: OpenChannelFlags::new().with_redirect_interrupt(true),
2986                    },
2987                );
2988
2989                server.send(in_msg(
2990                    MessageType::OPEN_CHANNEL_RESULT,
2991                    protocol::OpenResult {
2992                        channel_id: channel.offer.channel_id,
2993                        open_id: 0,
2994                        status: protocol::STATUS_SUCCESS as u32,
2995                    },
2996                ));
2997
2998                let output = recv.await.unwrap().unwrap();
2999                assert_eq!(output.redirected_event_flag, Some(expected_event_flag));
3000            }
3001
3002            for (i, channel) in connection.offers.iter().enumerate() {
3003                // Close the channel to prepare for the next iteration of the loop.
3004                // The event flag should be the same each time.
3005                channel
3006                    .request_send
3007                    .call(ChannelRequest::Close, ())
3008                    .await
3009                    .unwrap();
3010
3011                check_message(
3012                    server.next().await.unwrap(),
3013                    protocol::CloseChannel {
3014                        channel_id: ChannelId(i as u32),
3015                    },
3016                );
3017            }
3018        }
3019    }
3020
3021    #[async_test]
3022    async fn test_revoke(driver: DefaultDriver) {
3023        let (mut server, mut client) = test_init(&driver);
3024        let channel = server.get_channel(&mut client).await;
3025
3026        server.send(in_msg(
3027            MessageType::RESCIND_CHANNEL_OFFER,
3028            protocol::RescindChannelOffer {
3029                channel_id: ChannelId(0),
3030            },
3031        ));
3032
3033        channel.revoke_recv.await.unwrap();
3034
3035        channel
3036            .request_send
3037            .call_failable(
3038                ChannelRequest::Open,
3039                OpenRequest {
3040                    open_data: OpenData {
3041                        target_vp: 0,
3042                        ring_offset: 0,
3043                        ring_gpadl_id: GpadlId(0),
3044                        event_flag: 0,
3045                        connection_id: 0,
3046                        user_data: UserDefinedData::new_zeroed(),
3047                    },
3048                    incoming_event: None,
3049                    use_vtl2_connection_id: false,
3050                },
3051            )
3052            .await
3053            .unwrap_err();
3054    }
3055
3056    #[async_test]
3057    #[should_panic(expected = "channel should not exist")]
3058    async fn test_reoffer_in_use_rel_id(driver: DefaultDriver) {
3059        let (mut server, mut client) = test_init(&driver);
3060        let mut connection = server.get_channels(&mut client, 1).await;
3061        let [channel] = connection.offers.try_into().unwrap();
3062
3063        server.send(in_msg(
3064            MessageType::RESCIND_CHANNEL_OFFER,
3065            protocol::RescindChannelOffer {
3066                channel_id: ChannelId(0),
3067            },
3068        ));
3069
3070        channel.revoke_recv.await.unwrap();
3071
3072        // This offer will cause a panic since the rel id is still in use.
3073        let offer = protocol::OfferChannel {
3074            interface_id: Guid::new_random(),
3075            instance_id: Guid::new_random(),
3076            rsvd: [0; 4],
3077            flags: OfferFlags::new(),
3078            mmio_megabytes: 0,
3079            user_defined: UserDefinedData::new_zeroed(),
3080            subchannel_index: 0,
3081            mmio_megabytes_optional: 0,
3082            channel_id: ChannelId(0),
3083            monitor_id: 0,
3084            monitor_allocated: 0,
3085            is_dedicated: 0,
3086            connection_id: 0,
3087        };
3088
3089        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3090
3091        connection.offer_recv.next().await;
3092    }
3093
3094    #[async_test]
3095    async fn test_revoke_release_and_reoffer(driver: DefaultDriver) {
3096        let (mut server, mut client) = test_init(&driver);
3097        let mut connection = server.get_channels(&mut client, 1).await;
3098        let [channel] = connection.offers.try_into().unwrap();
3099
3100        server.send(in_msg(
3101            MessageType::RESCIND_CHANNEL_OFFER,
3102            protocol::RescindChannelOffer {
3103                channel_id: ChannelId(0),
3104            },
3105        ));
3106
3107        channel.revoke_recv.await.unwrap();
3108        drop(channel.request_send);
3109
3110        check_message(
3111            server.next().await.unwrap(),
3112            protocol::RelIdReleased {
3113                channel_id: ChannelId(0),
3114            },
3115        );
3116
3117        let offer = protocol::OfferChannel {
3118            interface_id: Guid::new_random(),
3119            instance_id: Guid::new_random(),
3120            rsvd: [0; 4],
3121            flags: OfferFlags::new(),
3122            mmio_megabytes: 0,
3123            user_defined: UserDefinedData::new_zeroed(),
3124            subchannel_index: 0,
3125            mmio_megabytes_optional: 0,
3126            channel_id: ChannelId(0),
3127            monitor_id: 0,
3128            monitor_allocated: 0,
3129            is_dedicated: 0,
3130            connection_id: 0,
3131        };
3132
3133        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3134
3135        connection.offer_recv.next().await.unwrap();
3136    }
3137
    #[async_test]
    async fn test_release_revoke_and_reoffer(driver: DefaultDriver) {
        // The order here is the reverse of test_revoke_release_and_reoffer. An open channel is
        // dropped first, which closes it but does not release the relid. The later rescind
        // then releases it, after which the same relid can be offered again.
        let (mut server, mut client) = test_init(&driver);
        let mut connection = server.get_channels(&mut client, 1).await;
        let [channel] = connection.offers.try_into().unwrap();

        let open = channel.request_send.call_failable(
            ChannelRequest::Open,
            OpenRequest {
                open_data: OpenData {
                    target_vp: 0,
                    ring_offset: 0,
                    ring_gpadl_id: GpadlId(0),
                    event_flag: 0,
                    connection_id: 0,
                    user_data: UserDefinedData::new_zeroed(),
                },
                incoming_event: None,
                use_vtl2_connection_id: false,
            },
        );

        // Server side of the open: expect the OpenChannel2 request and answer it with success.
        let server_open = async {
            check_message(
                server.next().await.unwrap(),
                protocol::OpenChannel2 {
                    open_channel: protocol::OpenChannel {
                        channel_id: ChannelId(0),
                        open_id: 0,
                        ring_buffer_gpadl_id: GpadlId(0),
                        target_vp: 0,
                        downstream_ring_buffer_page_offset: 0,
                        user_data: UserDefinedData::new_zeroed(),
                    },
                    connection_id: 0,
                    event_flag: 0,
                    flags: Default::default(),
                },
            );
            server.send(in_msg(
                MessageType::OPEN_CHANNEL_RESULT,
                protocol::OpenResult {
                    channel_id: ChannelId(0),
                    open_id: 0,
                    status: protocol::STATUS_SUCCESS as u32,
                },
            ));
        };

        // Drive the client's open and the server's response concurrently. Element `.0` of the
        // joined output is the client's open result, which must succeed.
        (open, server_open).join().await.0.unwrap();

        // This will close the channel but won't release it yet.
        drop(channel);

        check_message(
            server.next().await.unwrap(),
            protocol::CloseChannel {
                channel_id: ChannelId(0),
            },
        );

        server.send(in_msg(
            MessageType::RESCIND_CHANNEL_OFFER,
            protocol::RescindChannelOffer {
                channel_id: ChannelId(0),
            },
        ));

        // Should be released.
        check_message(
            server.next().await.unwrap(),
            protocol::RelIdReleased {
                channel_id: ChannelId(0),
            },
        );

        // Reoffer the now-released relid.
        let offer = protocol::OfferChannel {
            interface_id: Guid::new_random(),
            instance_id: Guid::new_random(),
            rsvd: [0; 4],
            flags: OfferFlags::new(),
            mmio_megabytes: 0,
            user_defined: UserDefinedData::new_zeroed(),
            subchannel_index: 0,
            mmio_megabytes_optional: 0,
            channel_id: ChannelId(0),
            monitor_id: 0,
            monitor_allocated: 0,
            is_dedicated: 0,
            connection_id: 0,
        };

        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));

        // New offer should come through.
        connection.offer_recv.next().await.unwrap();
    }
3235}