vmbus_client/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Client driver for the Hyper-V Virtual Machine Bus (VmBus).
5
6#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9pub mod driver;
10pub mod filter;
11mod hvsock;
12pub mod saved_state;
13
14pub use self::saved_state::SavedState;
15use anyhow::Context as _;
16use anyhow::Result;
17use futures::FutureExt;
18use futures::StreamExt;
19use futures::future::OptionFuture;
20use futures::stream::SelectAll;
21use futures_concurrency::future::Race;
22use guid::Guid;
23use inspect::Inspect;
24use mesh::rpc::FailableRpc;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use pal_async::task::Spawn;
28use pal_async::task::Task;
29use pal_event::Event;
30use std::collections::HashMap;
31use std::collections::VecDeque;
32use std::collections::hash_map;
33use std::convert::TryInto;
34use std::future::Future;
35use std::future::poll_fn;
36use std::ops::Deref;
37use std::ops::DerefMut;
38use std::pin::pin;
39use std::sync::Arc;
40use std::sync::atomic::AtomicU32;
41use std::sync::atomic::Ordering;
42use std::task::Context;
43use std::task::Poll;
44use thiserror::Error;
45use vmbus_async::async_dgram::AsyncRecv;
46use vmbus_async::async_dgram::AsyncRecvExt;
47use vmbus_channel::bus::GpadlRequest;
48use vmbus_channel::bus::ModifyRequest;
49use vmbus_channel::bus::OfferKey;
50use vmbus_channel::bus::OpenData;
51use vmbus_channel::gpadl::GpadlId;
52use vmbus_core::HvsockConnectRequest;
53use vmbus_core::OutgoingMessage;
54use vmbus_core::TaggedStream;
55use vmbus_core::VersionInfo;
56use vmbus_core::protocol;
57use vmbus_core::protocol::ChannelId;
58use vmbus_core::protocol::ConnectionState;
59use vmbus_core::protocol::FeatureFlags;
60use vmbus_core::protocol::Message;
61use vmbus_core::protocol::OpenChannelFlags;
62use vmbus_core::protocol::Version;
63use vmcore::interrupt::Interrupt;
64use vmcore::synic::MonitorPageGpas;
65use zerocopy::Immutable;
66use zerocopy::IntoBytes;
67use zerocopy::KnownLayout;
68
/// SINT value placed in the `TargetInfo` sent with InitiateContact.
const SINT: u8 = 2;
/// VTL value placed in the `TargetInfo` sent with InitiateContact.
const VTL: u8 = 0;
/// Protocol versions the client can negotiate.
///
/// Order matters: a connect attempt starts with the last entry, and each time
/// the server rejects a version the client retries with the entry before it.
/// Once the first entry is rejected, the connect fails with
/// `ConnectError::NoSupportedVersions`.
const SUPPORTED_VERSIONS: &[Version] = &[Version::Iron, Version::Copper];
/// Feature flags requested during InitiateContact when the requested version
/// is `Version::Copper` or newer; older versions request no feature flags.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true);
78
/// The client interface for synic events.
///
/// The builder stores the implementation in an `Arc` shared with the client
/// task, so it must be `Send + Sync`.
pub trait SynicEventClient: Send + Sync {
    /// Maps an incoming event signal on SINT7 to `event`.
    ///
    /// NOTE(review): this file requests SINT 2 (`SINT`) in InitiateContact;
    /// confirm that SINT7 is the intended SINT for event signals.
    fn map_event(&self, event_flag: u16, event: &Event) -> std::io::Result<()>;

    /// Unmaps an event previously mapped with `map_event`.
    fn unmap_event(&self, event_flag: u16);

    /// Signals an event on the synic, addressed by `connection_id` and
    /// `event_flag`.
    fn signal_event(&self, connection_id: u32, event_flag: u16) -> std::io::Result<()>;
}
90
/// A stream of vmbus messages that can be paused and resumed.
///
/// Messages themselves are received through the `AsyncRecv` supertrait.
pub trait VmbusMessageSource: AsyncRecv + Send {
    /// Stop accepting new messages from the synic. After this is called, the message source must
    /// return any pending messages already in the queue, and then return EOF.
    ///
    /// NOTE(review): the default implementation does nothing, so a source that
    /// does not override it never reaches EOF because of a pause.
    fn pause_message_stream(&mut self) {}

    /// Resume accepting new messages from the synic.
    ///
    /// The default implementation does nothing.
    fn resume_message_stream(&mut self) {}
}
100
/// Posts outgoing vmbus messages; the client task uses it as the poster for
/// its outgoing message queue.
pub trait PollPostMessage: Send {
    /// Attempts to post a message of type `typ` with payload `msg` to
    /// `connection_id`.
    ///
    /// NOTE(review): presumably returns `Poll::Pending` (arranging a wakeup via
    /// `cx`) when the message cannot be posted yet, and `Poll::Ready(())` once
    /// it has been posted; confirm against implementations.
    fn poll_post_message(
        &mut self,
        cx: &mut Context<'_>,
        connection_id: u32,
        typ: u32,
        msg: &[u8],
    ) -> Poll<()>;
}
110
/// A handle to a running vmbus client, created by [`VmbusClientBuilder::build`].
///
/// Inspection requests are forwarded to the background client task.
#[derive(Inspect)]
pub struct VmbusClient {
    /// Sends control requests (start, stop, save, restore, inspect) to the
    /// client task.
    #[inspect(flatten, send = "TaskRequest::Inspect")]
    task_send: mesh::Sender<TaskRequest>,
    /// Cloneable handle for connection-level requests.
    #[inspect(skip)]
    access: VmbusClientAccess,
    /// The spawned client task; awaiting it yields the task's state.
    #[inspect(skip)]
    task: Task<ClientTask>,
}
120
/// Errors returned by [`VmbusClient::connect`].
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// The client was not disconnected when the connect request was handled.
    #[error("invalid state to connect to the server")]
    InvalidState,
    /// The server rejected every version in the client's supported list.
    #[error("no supported protocol versions")]
    NoSupportedVersions,
    /// The server accepted the protocol version but reported a non-success
    /// connection state.
    #[error("failed to connect to the server: {0:?}")]
    FailedToConnect(ConnectionState),
}
130
/// A cloneable handle for sending connection-level requests (such as
/// ModifyConnection and hvsock connect) to the client task.
#[derive(Clone)]
pub struct VmbusClientAccess {
    client_request_send: mesh::Sender<ClientRequest>,
}
135
/// A builder for creating a [`VmbusClient`].
///
/// Holds the synic resources the client task needs. A [`VmbusClient`] can
/// also hand these back when its task is severed.
pub struct VmbusClientBuilder {
    /// Maps, unmaps, and signals synic events.
    event_client: Arc<dyn SynicEventClient>,
    /// Source of incoming messages from the server.
    msg_source: Box<dyn VmbusMessageSource>,
    /// Posts outgoing messages to the server.
    msg_client: Box<dyn PollPostMessage>,
}
142
143impl VmbusClientBuilder {
144    /// Creates a new instance of the builder with the given synic input.
145    pub fn new(
146        event_client: impl SynicEventClient + 'static,
147        msg_source: impl VmbusMessageSource + 'static,
148        msg_client: impl PollPostMessage + 'static,
149    ) -> Self {
150        Self {
151            event_client: Arc::new(event_client),
152            msg_source: Box::new(msg_source),
153            msg_client: Box::new(msg_client),
154        }
155    }
156
157    /// Creates a new instance with a receiver for incoming synic messages.
158    pub fn build(self, spawner: &impl Spawn) -> VmbusClient {
159        let (task_send, task_recv) = mesh::channel();
160        let (client_request_send, client_request_recv) = mesh::channel();
161
162        let inner = ClientTaskInner {
163            messages: OutgoingMessages {
164                poster: self.msg_client,
165                queued: VecDeque::new(),
166                state: OutgoingMessageState::Paused,
167            },
168            teardown_gpadls: HashMap::new(),
169            channel_requests: SelectAll::new(),
170            synic: SynicState {
171                event_flag_state: Vec::new(),
172                event_client: self.event_client,
173            },
174        };
175
176        let mut task = ClientTask {
177            inner,
178            channels: ChannelList::default(),
179            task_recv,
180            running: false,
181            msg_source: self.msg_source,
182            client_request_recv,
183            state: ClientState::Disconnected,
184            modify_request: None,
185            hvsock_tracker: hvsock::HvsockRequestTracker::new(),
186        };
187
188        let task = spawner.spawn("vmbus client", async move {
189            task.run().await;
190            task
191        });
192
193        VmbusClient {
194            access: VmbusClientAccess {
195                client_request_send,
196            },
197            task_send,
198            task,
199        }
200    }
201}
202
203impl VmbusClient {
204    /// Connects to the server, negotiating the protocol version and retrieving
205    /// the initial list of channel offers.
206    pub async fn connect(
207        &mut self,
208        target_message_vp: u32,
209        monitor_page: Option<MonitorPageGpas>,
210        client_id: Guid,
211    ) -> Result<ConnectResult, ConnectError> {
212        let request = ConnectRequest {
213            target_message_vp,
214            monitor_page,
215            client_id,
216        };
217
218        self.access
219            .client_request_send
220            .call(ClientRequest::Connect, request)
221            .await
222            .unwrap()
223    }
224
225    pub async fn unload(self) {
226        self.access
227            .client_request_send
228            .call(ClientRequest::Unload, ())
229            .await
230            .unwrap();
231
232        self.sever().await;
233    }
234
235    pub fn access(&self) -> &VmbusClientAccess {
236        &self.access
237    }
238
239    pub fn start(&mut self) {
240        self.task_send.send(TaskRequest::Start);
241    }
242
243    pub async fn stop(&mut self) {
244        self.task_send
245            .call(TaskRequest::Stop, ())
246            .await
247            .expect("Failed to send stop request");
248    }
249
250    pub async fn save(&self) -> SavedState {
251        self.task_send
252            .call(TaskRequest::Save, ())
253            .await
254            .expect("Failed to send save request")
255    }
256
257    pub async fn restore(
258        &mut self,
259        state: SavedState,
260    ) -> Result<Option<ConnectResult>, RestoreError> {
261        self.task_send
262            .call(TaskRequest::Restore, state)
263            .await
264            .expect("Failed to send restore request")
265    }
266
267    pub async fn post_restore(&mut self) {
268        self.task_send
269            .call(TaskRequest::PostRestore, ())
270            .await
271            .expect("Failed to send post-restore request");
272    }
273
274    async fn sever(self) -> VmbusClientBuilder {
275        drop(self.task_send);
276        let task = self.task.await;
277        VmbusClientBuilder {
278            event_client: task.inner.synic.event_client,
279            msg_source: task.msg_source,
280            msg_client: task.inner.messages.poster,
281        }
282    }
283}
284
/// The result of a successful connection to the server.
#[derive(Debug)]
pub struct ConnectResult {
    /// The negotiated protocol version and feature flags.
    pub version: VersionInfo,
    /// Offers received before the server finished delivering its initial
    /// offers. Offers claimed by pending hvsock connect requests are excluded.
    pub offers: Vec<OfferInfo>,
    /// Receives offers that arrive after connecting. Offers claimed by pending
    /// hvsock connect requests are excluded.
    pub offer_recv: mesh::Receiver<OfferInfo>,
}
291
292impl VmbusClientAccess {
293    pub async fn modify(&self, request: ModifyConnectionRequest) -> ConnectionState {
294        self.client_request_send
295            .call(ClientRequest::Modify, request)
296            .await
297            .expect("Failed to send modify request")
298    }
299
300    pub fn connect_hvsock(
301        &self,
302        request: HvsockConnectRequest,
303    ) -> impl Future<Output = Option<OfferInfo>> + use<> {
304        self.client_request_send
305            .call(ClientRequest::HvsockConnect, request)
306            .map(|r| r.ok().flatten())
307    }
308}
309
/// Parameters for opening a channel, sent with [`ChannelRequest::Open`].
#[derive(Debug)]
pub struct OpenRequest {
    /// The open parameters for the channel.
    pub open_data: OpenData,
    /// Optional event for incoming signals on this channel.
    /// NOTE(review): presumably mapped to a redirected event flag via
    /// `SynicEventClient::map_event`; confirm against the open handler.
    pub incoming_event: Option<Event>,
    /// Presumably selects the VTL2 connection ID for guest-to-host signals;
    /// TODO confirm against the open handler.
    pub use_vtl2_connection_id: bool,
}
316
/// Parameters for a [`ChannelRequest::Restore`] request, which claims a
/// channel that is in the restored state.
#[derive(Debug)]
pub struct RestoreRequest {
    /// Optional event for incoming signals on this channel (presumably handled
    /// like the open path's incoming event; TODO confirm).
    pub incoming_event: Option<Event>,
    /// The redirected event flag the channel was using, supplied by the caller
    /// for now.
    // FUTURE: move to saved state, don't rely on the caller.
    pub redirected_event_flag: Option<u16>,
    /// The channel's connection ID, also supplied by the caller for now.
    // FUTURE: ditto
    pub connection_id: u32,
}
325
/// Expresses an operation requested of the client.
///
/// A channel's user sends these through [`OfferInfo::request_send`].
pub enum ChannelRequest {
    /// Opens the channel.
    Open(FailableRpc<OpenRequest, OpenOutput>),
    /// Claims a channel restored from saved state.
    Restore(FailableRpc<RestoreRequest, OpenOutput>),
    /// Closes the channel.
    Close(Rpc<(), ()>),
    /// Creates a GPADL for the channel.
    Gpadl(FailableRpc<GpadlRequest, ()>),
    /// Tears down a GPADL previously created for the channel.
    TeardownGpadl(Rpc<GpadlId, ()>),
    /// Modifies the channel; presumably the `i32` is the server's status
    /// (TODO confirm).
    Modify(Rpc<ModifyRequest, i32>),
}
335
/// The result of a successful open or restore.
#[derive(Debug)]
pub struct OpenOutput {
    /// The event flag used for the channel's redirected incoming signals, if
    /// any.
    // FUTURE: remove this once it's part of the saved state.
    pub redirected_event_flag: Option<u16>,
}
341
342impl std::fmt::Display for ChannelRequest {
343    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        let s = match self {
345            ChannelRequest::Open(_) => "Open",
346            ChannelRequest::Close(_) => "Close",
347            ChannelRequest::Restore(_) => "Restore",
348            ChannelRequest::Gpadl(_) => "Gpadl",
349            ChannelRequest::TeardownGpadl(_) => "TeardownGpadl",
350            ChannelRequest::Modify(_) => "Modify",
351        };
352        fmt.pad(s)
353    }
354}
355
/// Errors returned by [`VmbusClient::restore`] when a [`SavedState`] cannot be
/// applied.
#[derive(Debug, Error)]
pub enum RestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    #[error("duplicate channel id {0}")]
    DuplicateChannelId(u32),

    #[error("duplicate gpadl id {0}")]
    DuplicateGpadlId(u32),

    #[error("gpadl for unknown channel id {0}")]
    GpadlForUnknownChannelId(u32),

    #[error("invalid pending message")]
    InvalidPendingMessage(#[source] vmbus_core::MessageTooLarge),

    #[error("failed to offer restored channel")]
    OfferFailed(#[source] anyhow::Error),
}
379
/// Provides the offer details from the server in addition to both a channel
/// to request client actions and a channel to receive server responses.
#[derive(Debug, Inspect)]
pub struct OfferInfo {
    /// The offer as received from the server.
    pub offer: protocol::OfferChannel,
    /// Guest-to-host interrupt for this channel, bound to the channel's shared
    /// connection ID.
    #[inspect(skip)]
    pub guest_to_host_interrupt: Interrupt,
    /// Sends [`ChannelRequest`]s for this channel to the client task.
    #[inspect(skip)]
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Completes when the server rescinds the channel.
    #[inspect(skip)]
    pub revoke_recv: mesh::OneshotReceiver<()>,
}
392
/// Connection-level requests sent to the client task by [`VmbusClient`] and
/// [`VmbusClientAccess`].
#[derive(Debug)]
enum ClientRequest {
    /// Negotiate a protocol version and retrieve the initial offers.
    Connect(Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>),
    /// Unload the connection with the server.
    Unload(Rpc<(), ()>),
    /// Send a ModifyConnection request.
    Modify(Rpc<ModifyConnectionRequest, ConnectionState>),
    /// Send an hvsock connect request.
    HvsockConnect(Rpc<HvsockConnectRequest, Option<OfferInfo>>),
}
400
401impl std::fmt::Display for ClientRequest {
402    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403        let s = match self {
404            ClientRequest::Connect(..) => "Connect",
405            ClientRequest::Unload { .. } => "Unload",
406            ClientRequest::Modify(..) => "Modify",
407            ClientRequest::HvsockConnect(..) => "HvsockConnect",
408        };
409        fmt.pad(s)
410    }
411}
412
/// Control requests sent to the client task by [`VmbusClient`].
enum TaskRequest {
    /// Inspects the task's state.
    Inspect(inspect::Deferred),
    /// Saves the client's state (see [`VmbusClient::save`]).
    Save(Rpc<(), SavedState>),
    /// Restores the client's state (see [`VmbusClient::restore`]).
    Restore(Rpc<SavedState, Result<Option<ConnectResult>, RestoreError>>),
    /// Sent by [`VmbusClient::post_restore`].
    PostRestore(Rpc<(), ()>),
    /// Sent by [`VmbusClient::start`]; no reply is expected.
    Start,
    /// Sent by [`VmbusClient::stop`], which waits for the reply.
    Stop(Rpc<(), ()>),
}
421
/// The overall state machine used to drive which actions the client can legally
/// take. This primarily pertains to overall client activity but has a
/// side-effect of limiting whether or not channels can perform actions.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ClientState {
    /// The client has yet to connect to the server.
    Disconnected,
    /// The client has initiated contact with the server.
    Connecting {
        /// The protocol version requested in the outstanding InitiateContact.
        version: Version,
        /// The pending connect request.
        #[inspect(skip)]
        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
    },
    /// The client has negotiated the protocol version with the server and
    /// received the initial offers.
    Connected {
        /// The negotiated version and feature flags.
        version: VersionInfo,
        /// Forwards later offers to the `offer_recv` returned in the connect
        /// result.
        #[inspect(skip)]
        offer_send: mesh::Sender<OfferInfo>,
    },
    /// The client has negotiated the protocol version and requested offers
    /// from the server.
    RequestingOffers {
        /// The negotiated version and feature flags.
        version: VersionInfo,
        /// The pending connect request, completed once offers are delivered.
        #[inspect(skip)]
        rpc: Rpc<(), Result<ConnectResult, ConnectError>>,
        /// Offers received so far; returned in the connect result.
        #[inspect(skip)]
        offers: Vec<OfferInfo>,
    },
    /// The client has initiated an unload from the server.
    Disconnecting {
        /// The negotiated version and feature flags.
        version: VersionInfo,
        /// The pending unload request.
        #[inspect(skip)]
        rpc: Rpc<(), ()>,
    },
}
457
458impl ClientState {
459    fn get_version(&self) -> Option<VersionInfo> {
460        match self {
461            ClientState::Connected { version, .. } => Some(*version),
462            ClientState::RequestingOffers { version, .. } => Some(*version),
463            ClientState::Disconnecting { version, .. } => Some(*version),
464            ClientState::Disconnected | ClientState::Connecting { .. } => None,
465        }
466    }
467}
468
469impl std::fmt::Display for ClientState {
470    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471        let s = match self {
472            ClientState::Disconnected => "Disconnected",
473            ClientState::Connecting { .. } => "Connecting",
474            ClientState::Connected { .. } => "Connected",
475            ClientState::RequestingOffers { .. } => "RequestingOffers",
476            ClientState::Disconnecting { .. } => "Disconnecting",
477        };
478        fmt.pad(s)
479    }
480}
481
/// Parameters of a connect request, built by [`VmbusClient::connect`] and used
/// to fill in the InitiateContact message.
#[derive(Copy, Clone, Debug, Default)]
struct ConnectRequest {
    /// Sent as `target_message_vp` in InitiateContact.
    target_message_vp: u32,
    /// Monitor page GPAs to send; `None` sends `MonitorPageGpas::default()`.
    monitor_page: Option<MonitorPageGpas>,
    /// Client ID; only sent when requesting `Version::Copper` or newer.
    client_id: Guid,
}
488
/// Parameters for [`VmbusClientAccess::modify`].
#[derive(Copy, Clone, Debug, Default)]
pub struct ModifyConnectionRequest {
    /// New monitor page GPAs; `None` is sent as `MonitorPageGpas::default()`.
    pub monitor_page: Option<MonitorPageGpas>,
}
493
494impl From<ModifyConnectionRequest> for protocol::ModifyConnection {
495    fn from(value: ModifyConnectionRequest) -> Self {
496        let monitor_page = value.monitor_page.unwrap_or_default();
497
498        Self {
499            parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
500            child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
501        }
502    }
503}
504
/// The per-channel state, which dictates whether or not a channel can request
/// an Open/Close. GPADLs are created and torn down independently of this
/// state, so no state here is tied to GPADL actions; those are tracked per
/// GPADL in `Channel::gpadls`.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum ChannelState {
    /// The channel has been offered to the client.
    Offered,
    /// The client has asked the server to open the channel and is awaiting the
    /// result.
    Opening {
        /// Event flag reserved for redirected incoming signals, if any. It is
        /// freed if the open fails or the channel is rescinded.
        redirected_event_flag: Option<u16>,
        /// Event associated with the redirected event flag, if any.
        #[inspect(skip)]
        redirected_event: Option<Event>,
        /// The pending open request.
        #[inspect(skip)]
        rpc: FailableRpc<(), OpenOutput>,
    },
    /// The channel has been restored but not claimed.
    Restored,
    /// The channel has been successfully opened.
    Opened {
        /// Event flag used for redirected incoming signals, if any; freed when
        /// the channel is rescinded.
        redirected_event_flag: Option<u16>,
        /// Event associated with the redirected event flag, if any.
        #[inspect(skip)]
        redirected_event: Option<Event>,
    },
    /// The channel has been revoked by the server.
    Revoked,
}
532
533impl std::fmt::Display for ChannelState {
534    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535        let s = match self {
536            ChannelState::Opening { .. } => "Opening",
537            ChannelState::Offered => "Offered",
538            ChannelState::Opened { .. } => "Opened",
539            ChannelState::Restored => "Restored",
540            ChannelState::Revoked => "Revoked",
541        };
542        fmt.pad(s)
543    }
544}
545
/// Client-side state for a single offered channel.
#[derive(Debug, Inspect)]
struct Channel {
    /// The offer as received from the server.
    offer: protocol::OfferChannel,
    /// Sending on it (or dropping it) notifies the channel's user that the
    /// channel has been revoked.
    #[inspect(skip)]
    revoke_send: Option<mesh::OneshotSender<()>>,
    /// The channel's open/close state.
    state: ChannelState,
    /// Pending modify request awaiting its result, if any (presumably from
    /// `ChannelRequest::Modify`; TODO confirm).
    #[inspect(with = "|x| x.is_some()")]
    modify_response_send: Option<Rpc<(), i32>>,
    /// GPADLs belonging to this channel, keyed by GPADL ID.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    gpadls: HashMap<GpadlId, GpadlState>,
    /// Presumably set once the channel's user has released the channel; TODO
    /// confirm against `try_release`.
    is_client_released: bool,
    /// Connection ID shared with the channel's guest-to-host interrupt; it
    /// starts at 0 when the channel is created.
    connection_id: Arc<AtomicU32>,
}
560
561impl Channel {
562    fn pending_request(&self) -> Option<&'static str> {
563        if self.modify_response_send.is_some() {
564            return Some("modify");
565        }
566        self.gpadls.iter().find_map(|(_, gpadl)| match gpadl {
567            GpadlState::Offered(_) => Some("creating gpadl"),
568            GpadlState::Created => None,
569            GpadlState::TearingDown { .. } => Some("tearing down gpadl"),
570        })
571    }
572}
573
/// State owned by the background client task spawned in
/// [`VmbusClientBuilder::build`].
#[derive(Inspect)]
struct ClientTask {
    /// Outgoing messages, GPADL teardown tracking, channel request streams, and
    /// synic state.
    #[inspect(flatten)]
    inner: ClientTaskInner,
    /// Channels known to the client, keyed by channel ID.
    channels: ChannelList,
    /// The connection state machine.
    state: ClientState,
    /// Pending hvsock connect requests waiting for a matching offer.
    hvsock_tracker: hvsock::HvsockRequestTracker,
    /// Starts `false`; presumably toggled by `TaskRequest::Start`/`Stop`
    /// (TODO confirm).
    running: bool,
    /// The outstanding ModifyConnection request, if any; only one is allowed at
    /// a time.
    #[inspect(with = "|x| x.is_some()")]
    modify_request: Option<Rpc<ModifyConnectionRequest, ConnectionState>>,
    /// Source of incoming messages from the server.
    #[inspect(skip)]
    msg_source: Box<dyn VmbusMessageSource>,
    /// Receives control requests from [`VmbusClient`].
    #[inspect(skip)]
    task_recv: mesh::Receiver<TaskRequest>,
    /// Receives connection-level requests from [`VmbusClient`] and
    /// [`VmbusClientAccess`].
    #[inspect(skip)]
    client_request_recv: mesh::Receiver<ClientRequest>,
}
591
592impl ClientTask {
593    fn handle_initiate_contact(
594        &mut self,
595        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
596        version: Version,
597    ) {
598        let ClientState::Disconnected = self.state else {
599            tracing::warn!(client_state = %self.state, "invalid client state for InitiateContact");
600            rpc.complete(Err(ConnectError::InvalidState));
601            return;
602        };
603        let feature_flags = if version >= Version::Copper {
604            SUPPORTED_FEATURE_FLAGS
605        } else {
606            FeatureFlags::new()
607        };
608
609        let request = rpc.input();
610
611        tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting");
612        let target_info = protocol::TargetInfo::new()
613            .with_sint(SINT)
614            .with_vtl(VTL)
615            .with_feature_flags(feature_flags.into());
616        let monitor_page = request.monitor_page.unwrap_or_default();
617        let msg = protocol::InitiateContact2 {
618            initiate_contact: protocol::InitiateContact {
619                version_requested: version as u32,
620                target_message_vp: request.target_message_vp,
621                interrupt_page_or_target_info: target_info.into(),
622                parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
623                child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
624            },
625            client_id: request.client_id,
626        };
627
628        self.state = ClientState::Connecting { version, rpc };
629        if version < Version::Copper {
630            self.inner.messages.send(&msg.initiate_contact)
631        } else {
632            self.inner.messages.send(&msg);
633        }
634    }
635
636    fn handle_unload(&mut self, rpc: Rpc<(), ()>) {
637        tracing::debug!(%self.state, "VmBus client disconnecting");
638        self.state = ClientState::Disconnecting {
639            version: self.state.get_version().expect("invalid state for unload"),
640            rpc,
641        };
642
643        self.inner.messages.send(&protocol::Unload {});
644    }
645
646    fn handle_modify(&mut self, request: Rpc<ModifyConnectionRequest, ConnectionState>) {
647        if !matches!(self.state, ClientState::Connected { version, .. }
648            if version.feature_flags.modify_connection())
649        {
650            tracing::warn!("ModifyConnection not supported");
651            request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
652            return;
653        }
654
655        if self.modify_request.is_some() {
656            tracing::warn!("Duplicate ModifyConnection request");
657            request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
658            return;
659        }
660
661        let message = protocol::ModifyConnection::from(*request.input());
662        self.modify_request = Some(request);
663        self.inner.messages.send(&message);
664    }
665
666    fn handle_tl_connect(&mut self, rpc: Rpc<HvsockConnectRequest, Option<OfferInfo>>) {
667        // The client only supports protocol versions which use the newer message format.
668        // The host will not send a TlConnectRequestResult message on success, so a response to this
669        // message is not guaranteed.
670        let message = protocol::TlConnectRequest2::from(*rpc.input());
671        self.hvsock_tracker.add_request(rpc);
672        self.inner.messages.send(&message);
673    }
674
675    fn handle_client_request(&mut self, request: ClientRequest) {
676        match request {
677            ClientRequest::Connect(rpc) => {
678                self.handle_initiate_contact(rpc, *SUPPORTED_VERSIONS.last().unwrap());
679            }
680            ClientRequest::Unload(rpc) => {
681                self.handle_unload(rpc);
682            }
683            ClientRequest::Modify(request) => self.handle_modify(request),
684            ClientRequest::HvsockConnect(request) => self.handle_tl_connect(request),
685        }
686    }
687
688    fn handle_version_response(&mut self, msg: protocol::VersionResponse2) {
689        let old_state = std::mem::replace(&mut self.state, ClientState::Disconnected);
690        let ClientState::Connecting { version, rpc } = old_state else {
691            self.state = old_state;
692            tracing::warn!(
693                client_state = %self.state,
694                "invalid client state to handle VersionResponse"
695            );
696            return;
697        };
698        if msg.version_response.version_supported > 0 {
699            if msg.version_response.connection_state != ConnectionState::SUCCESSFUL {
700                rpc.complete(Err(ConnectError::FailedToConnect(
701                    msg.version_response.connection_state,
702                )));
703                return;
704            }
705
706            let feature_flags = if version >= Version::Copper {
707                FeatureFlags::from(msg.supported_features)
708            } else {
709                FeatureFlags::new()
710            };
711
712            let version = VersionInfo {
713                version,
714                feature_flags,
715            };
716
717            self.inner.messages.send(&protocol::RequestOffers {});
718            self.state = ClientState::RequestingOffers {
719                version,
720                rpc: rpc.split().1,
721                offers: Vec::new(),
722            };
723            tracing::info!(?version, "VmBus client connected, requesting offers");
724        } else {
725            let index = SUPPORTED_VERSIONS
726                .iter()
727                .position(|v| *v == version)
728                .unwrap();
729
730            if index == 0 {
731                rpc.complete(Err(ConnectError::NoSupportedVersions));
732                return;
733            }
734            let next_version = SUPPORTED_VERSIONS[index - 1];
735            tracing::debug!(
736                version = version as u32,
737                next_version = next_version as u32,
738                "Unsupported version, retrying"
739            );
740            self.handle_initiate_contact(rpc, next_version);
741        }
742    }
743
744    fn create_channel(&mut self, offer: protocol::OfferChannel) -> Result<OfferInfo> {
745        self.create_channel_core(offer, ChannelState::Offered)
746    }
747
748    fn create_channel_core(
749        &mut self,
750        offer: protocol::OfferChannel,
751        state: ChannelState,
752    ) -> Result<OfferInfo> {
753        if self.channels.0.contains_key(&offer.channel_id) {
754            anyhow::bail!("channel {:?} exists", offer.channel_id);
755        }
756        let (request_send, request_recv) = mesh::channel();
757        let (revoke_send, revoke_recv) = mesh::oneshot();
758
759        let connection_id = Arc::new(AtomicU32::new(0));
760        self.channels.0.insert(
761            offer.channel_id,
762            Channel {
763                revoke_send: Some(revoke_send),
764                offer,
765                state,
766                modify_response_send: None,
767                gpadls: HashMap::new(),
768                is_client_released: false,
769                connection_id: connection_id.clone(),
770            },
771        );
772
773        self.inner
774            .channel_requests
775            .push(TaggedStream::new(offer.channel_id, request_recv));
776
777        Ok(OfferInfo {
778            offer,
779            guest_to_host_interrupt: self.inner.synic.guest_to_host_interrupt(connection_id),
780            revoke_recv,
781            request_send,
782        })
783    }
784
785    fn handle_offer(&mut self, offer: protocol::OfferChannel) {
786        let offer_info = self
787            .create_channel(offer)
788            .expect("channel should not exist");
789
790        tracing::info!(
791                state = %self.state,
792                channel_id = offer.channel_id.0,
793                interface_id = %offer.interface_id,
794                instance_id = %offer.instance_id,
795                subchannel_index = offer.subchannel_index,
796                "received offer");
797
798        if let Some(offer) = self.hvsock_tracker.check_offer(&offer_info.offer) {
799            offer.complete(Some(offer_info));
800        } else {
801            match &mut self.state {
802                ClientState::Connected { offer_send, .. } => {
803                    offer_send.send(offer_info);
804                }
805                ClientState::RequestingOffers { offers, .. } => {
806                    offers.push(offer_info);
807                }
808                state => unreachable!("invalid client state for OfferChannel: {state}"),
809            }
810        }
811    }
812
813    fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) -> TriedRelease {
814        let mut channel = self.channels.get_mut(rescind.channel_id);
815        tracing::info!(
816            state = %self.state,
817            channel_id = rescind.channel_id.0,
818            key = %OfferKey::from(&channel.offer),
819            "received rescind"
820        );
821        let event_flag = match std::mem::replace(&mut channel.state, ChannelState::Revoked) {
822            ChannelState::Offered => None,
823            ChannelState::Opening {
824                redirected_event_flag,
825                redirected_event: _,
826                rpc,
827            } => {
828                rpc.fail(anyhow::anyhow!("channel revoked"));
829                redirected_event_flag
830            }
831            ChannelState::Restored => None,
832            ChannelState::Opened {
833                redirected_event_flag,
834                redirected_event: _,
835            } => redirected_event_flag,
836            ChannelState::Revoked => {
837                panic!("channel id {:?} already revoked", rescind.channel_id);
838            }
839        };
840        if let Some(event_flag) = event_flag {
841            self.inner.synic.free_event_flag(event_flag);
842        }
843
844        // Drop the channel and send the revoked message to the client.
845        channel.revoke_send.take().unwrap().send(());
846
847        channel.try_release(&mut self.inner.messages)
848    }
849
850    fn handle_offers_delivered(&mut self) {
851        match std::mem::replace(&mut self.state, ClientState::Disconnected) {
852            ClientState::RequestingOffers {
853                version,
854                rpc,
855                offers,
856            } => {
857                tracing::info!(version = ?version, "VmBus client connected, offers delivered");
858                let (offer_send, offer_recv) = mesh::channel();
859                self.state = ClientState::Connected {
860                    version,
861                    offer_send,
862                };
863                rpc.complete(Ok(ConnectResult {
864                    version,
865                    offers,
866                    offer_recv,
867                }));
868            }
869            state => {
870                tracing::warn!(client_state = %state, "invalid client state for OffersDelivered");
871                self.state = state;
872            }
873        }
874    }
875
876    fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) -> TriedRelease {
877        let mut channel = self.channels.get_mut(request.channel_id);
878        let Some(gpadl_state) = channel.gpadls.get_mut(&request.gpadl_id) else {
879            panic!("GpadlCreated for unknown gpadl {:#x}", request.gpadl_id.0);
880        };
881
882        let rpc = match std::mem::replace(gpadl_state, GpadlState::Created) {
883            GpadlState::Offered(rpc) => rpc,
884            old_state => {
885                panic!(
886                    "invalid state {old_state:?} for gpadl {:#x}:{:#x}",
887                    request.channel_id.0, request.gpadl_id.0
888                );
889            }
890        };
891
892        let gpadl_created = request.status == protocol::STATUS_SUCCESS;
893        if gpadl_created {
894            rpc.complete(Ok(()));
895        } else {
896            channel.gpadls.remove(&request.gpadl_id).unwrap();
897            rpc.fail(anyhow::anyhow!(
898                "gpadl creation failed: {:#x}",
899                request.status
900            ));
901        };
902        channel.try_release(&mut self.inner.messages)
903    }
904
905    fn handle_open_result(&mut self, result: protocol::OpenResult) {
906        let mut channel = self.channels.get_mut(result.channel_id);
907        tracing::debug!(
908            channel_id = result.channel_id.0,
909            key = %OfferKey::from(&channel.offer),
910            result = result.status,
911            "received open result"
912        );
913
914        let channel_opened = result.status == protocol::STATUS_SUCCESS as u32;
915        let old_state = std::mem::replace(&mut channel.state, ChannelState::Offered);
916        let ChannelState::Opening {
917            redirected_event_flag,
918            redirected_event,
919            rpc,
920        } = old_state
921        else {
922            tracing::warn!(
923                key = %OfferKey::from(&channel.offer),
924                old_state = ?channel.state,
925                channel_opened,
926                "invalid state for open result"
927            );
928            channel.state = old_state;
929            return;
930        };
931
932        if !channel_opened {
933            if let Some(event_flag) = redirected_event_flag {
934                self.inner.synic.free_event_flag(event_flag);
935            }
936            rpc.fail(anyhow::anyhow!("open failed: {:#x}", result.status));
937            return;
938        }
939
940        channel.state = ChannelState::Opened {
941            redirected_event_flag,
942            redirected_event,
943        };
944
945        rpc.complete(Ok(OpenOutput {
946            redirected_event_flag,
947        }));
948    }
949
950    fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) -> TriedRelease {
951        let Some(channel_id) = self.inner.teardown_gpadls.remove(&request.gpadl_id) else {
952            panic!("gpadl {:#x} not in teardown list", request.gpadl_id.0);
953        };
954
955        let mut channel = self.channels.get_mut(channel_id);
956        tracing::debug!(
957            gpadl_id = request.gpadl_id.0,
958            channel_id = channel_id.0,
959            key = %OfferKey::from(&channel.offer),
960            "Received GpadlTorndown"
961        );
962
963        let gpadl_state = channel
964            .gpadls
965            .remove(&request.gpadl_id)
966            .expect("gpadl validated above");
967
968        let GpadlState::TearingDown { rpcs } = gpadl_state else {
969            panic!("gpadl should be tearing down if in teardown list, state = {gpadl_state:?}");
970        };
971
972        for rpc in rpcs {
973            rpc.complete(());
974        }
975        channel.try_release(&mut self.inner.messages)
976    }
977
978    fn handle_unload_complete(&mut self) {
979        match std::mem::replace(&mut self.state, ClientState::Disconnected) {
980            ClientState::Disconnecting { version: _, rpc } => {
981                tracing::info!("VmBus client disconnected");
982                rpc.complete(());
983            }
984            state => {
985                tracing::warn!(client_state = %state, "invalid client state for UnloadComplete");
986            }
987        }
988    }
989
990    fn handle_modify_complete(&mut self, response: protocol::ModifyConnectionResponse) {
991        if let Some(request) = self.modify_request.take() {
992            request.complete(response.connection_state)
993        } else {
994            tracing::warn!("Unexpected modify complete request");
995        }
996    }
997
998    fn handle_modify_channel_response(
999        &mut self,
1000        response: protocol::ModifyChannelResponse,
1001    ) -> TriedRelease {
1002        let mut channel = self.channels.get_mut(response.channel_id);
1003        let Some(sender) = channel.modify_response_send.take() else {
1004            panic!(
1005                "unexpected modify channel response for channel {:#x}",
1006                response.channel_id.0
1007            );
1008        };
1009
1010        sender.complete(response.status);
1011        channel.try_release(&mut self.inner.messages)
1012    }
1013
1014    fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) {
1015        if let Some(rpc) = self.hvsock_tracker.check_result(&response) {
1016            rpc.complete(None);
1017        }
1018    }
1019
1020    /// Returns false if the message was a pause complete message.
1021    fn handle_synic_message(&mut self, data: &[u8]) -> bool {
1022        let msg = Message::parse(data, self.state.get_version()).unwrap();
1023        tracing::trace!(?msg, "received client message from synic");
1024
1025        match msg {
1026            Message::VersionResponse3(version_response, ..) => {
1027                // The client never sends the server-specified monitor pages feature flag, but
1028                // since version response messages are distinguished only by size, the response can
1029                // still look like `VersionResponse3` if the size was not set exactly by the server.
1030                // Since the feature flag can't be set, the extra data can be ignored.
1031                self.handle_version_response(version_response.version_response2);
1032            }
1033            Message::VersionResponse2(version_response, ..) => {
1034                self.handle_version_response(version_response);
1035            }
1036            Message::VersionResponse(version_response, ..) => {
1037                self.handle_version_response(version_response.into());
1038            }
1039            Message::OfferChannel(offer, ..) => {
1040                self.handle_offer(offer);
1041            }
1042            Message::AllOffersDelivered(..) => {
1043                self.handle_offers_delivered();
1044            }
1045            Message::UnloadComplete(..) => {
1046                self.handle_unload_complete();
1047            }
1048            Message::ModifyConnectionResponse(response, ..) => {
1049                self.handle_modify_complete(response);
1050            }
1051            Message::GpadlCreated(gpadl, ..) => {
1052                self.handle_gpadl_created(gpadl);
1053            }
1054            Message::OpenResult(result, ..) => {
1055                self.handle_open_result(result);
1056            }
1057            Message::GpadlTorndown(gpadl, ..) => {
1058                self.handle_gpadl_torndown(gpadl);
1059            }
1060            Message::RescindChannelOffer(rescind, ..) => {
1061                self.handle_rescind(rescind);
1062            }
1063            Message::ModifyChannelResponse(response, ..) => {
1064                self.handle_modify_channel_response(response);
1065            }
1066            Message::TlConnectResult(response, ..) => self.handle_tl_connect_result(response),
1067            // Unsupported messages.
1068            Message::CloseReservedChannelResponse(..) => {
1069                todo!("Unsupported message {msg:?}")
1070            }
1071            Message::PauseResponse(..) => {
1072                return false;
1073            }
1074            // Messages that should only be received by a vmbus server.
1075            Message::RequestOffers(..)
1076            | Message::OpenChannel2(..)
1077            | Message::OpenChannel(..)
1078            | Message::CloseChannel(..)
1079            | Message::GpadlHeader(..)
1080            | Message::GpadlBody(..)
1081            | Message::GpadlTeardown(..)
1082            | Message::RelIdReleased(..)
1083            | Message::InitiateContact(..)
1084            | Message::InitiateContact2(..)
1085            | Message::Unload(..)
1086            | Message::OpenReservedChannel(..)
1087            | Message::CloseReservedChannel(..)
1088            | Message::TlConnectRequest2(..)
1089            | Message::TlConnectRequest(..)
1090            | Message::ModifyChannel(..)
1091            | Message::ModifyConnection(..)
1092            | Message::Pause(..)
1093            | Message::Resume(..) => {
1094                unreachable!("Client received server message {msg:?}");
1095            }
1096        }
1097        true
1098    }
1099
1100    fn handle_open_channel(
1101        &mut self,
1102        channel_id: ChannelId,
1103        rpc: FailableRpc<OpenRequest, OpenOutput>,
1104    ) {
1105        let mut channel = self.channels.get_mut(channel_id);
1106        match &channel.state {
1107            ChannelState::Offered => {}
1108            ChannelState::Revoked => {
1109                rpc.fail(anyhow::anyhow!("channel revoked"));
1110                return;
1111            }
1112            state => {
1113                rpc.fail(anyhow::anyhow!("invalid channel state: {}", state));
1114                return;
1115            }
1116        }
1117
1118        tracing::info!(
1119            channel_id = channel_id.0,
1120            key = %OfferKey::from(&channel.offer),
1121            "opening channel on host"
1122        );
1123
1124        let (request, rpc) = rpc.split();
1125        let open_data = &request.open_data;
1126
1127        let supports_interrupt_redirection =
1128            if let ClientState::Connected { version, .. } = self.state {
1129                version.feature_flags.guest_specified_signal_parameters()
1130                    || version.feature_flags.channel_interrupt_redirection()
1131            } else {
1132                false
1133            };
1134
1135        if !supports_interrupt_redirection && open_data.event_flag != channel_id.0 as u16 {
1136            rpc.fail(anyhow::anyhow!(
1137                "host does not support specifying the event flag"
1138            ));
1139            return;
1140        }
1141
1142        let open_channel = protocol::OpenChannel {
1143            channel_id,
1144            open_id: 0,
1145            ring_buffer_gpadl_id: open_data.ring_gpadl_id,
1146            target_vp: open_data
1147                .target_vp
1148                .unwrap_or(protocol::VP_INDEX_DISABLE_INTERRUPT),
1149            downstream_ring_buffer_page_offset: open_data.ring_offset,
1150            user_data: open_data.user_data,
1151        };
1152
1153        let connection_id = if request.use_vtl2_connection_id {
1154            if !supports_interrupt_redirection {
1155                rpc.fail(anyhow::anyhow!(
1156                    "host does not support specfiying the connection ID"
1157                ));
1158                return;
1159            }
1160            protocol::ConnectionId::new(channel_id.0, 2.try_into().unwrap(), 7).0
1161        } else {
1162            open_data.connection_id
1163        };
1164
1165        // No failure paths after the one for allocating the event flag, since
1166        // otherwise we would need to free the event flag.
1167        let mut flags = OpenChannelFlags::new();
1168        let event_flag = if let Some(event) = &request.incoming_event {
1169            if !supports_interrupt_redirection {
1170                rpc.fail(anyhow::anyhow!(
1171                    "host does not support redirecting interrupts"
1172                ));
1173                return;
1174            }
1175
1176            flags.set_redirect_interrupt(true);
1177            match self.inner.synic.allocate_event_flag(event) {
1178                Ok(flag) => flag,
1179                Err(err) => {
1180                    rpc.fail(err.context("failed to allocate event flag"));
1181                    return;
1182                }
1183            }
1184        } else {
1185            open_data.event_flag
1186        };
1187
1188        if supports_interrupt_redirection {
1189            self.inner.messages.send(&protocol::OpenChannel2 {
1190                open_channel,
1191                connection_id,
1192                event_flag,
1193                flags,
1194            });
1195        } else {
1196            self.inner.messages.send(&open_channel);
1197        }
1198
1199        channel
1200            .connection_id
1201            .store(connection_id, Ordering::Release);
1202        channel.state = ChannelState::Opening {
1203            redirected_event_flag: (request.incoming_event.is_some()).then_some(event_flag),
1204            redirected_event: request.incoming_event,
1205            rpc,
1206        }
1207    }
1208
1209    fn handle_restore_channel(
1210        &mut self,
1211        channel_id: ChannelId,
1212        request: RestoreRequest,
1213    ) -> Result<OpenOutput> {
1214        let mut channel = self.channels.get_mut(channel_id);
1215        if !matches!(channel.state, ChannelState::Restored) {
1216            anyhow::bail!("invalid channel state: {}", channel.state);
1217        }
1218
1219        if request.incoming_event.is_some() != request.redirected_event_flag.is_some() {
1220            anyhow::bail!("incoming event and redirected event flag must both be set or unset");
1221        }
1222
1223        if let Some((flag, event)) = request
1224            .redirected_event_flag
1225            .zip(request.incoming_event.as_ref())
1226        {
1227            self.inner.synic.restore_event_flag(flag, event)?;
1228        }
1229
1230        channel
1231            .connection_id
1232            .store(request.connection_id, Ordering::Release);
1233        channel.state = ChannelState::Opened {
1234            redirected_event_flag: request.redirected_event_flag,
1235            redirected_event: request.incoming_event,
1236        };
1237        Ok(OpenOutput {
1238            redirected_event_flag: request.redirected_event_flag,
1239        })
1240    }
1241
1242    fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: FailableRpc<GpadlRequest, ()>) {
1243        let (request, rpc) = rpc.split();
1244        let mut channel = self.channels.get_mut(channel_id);
1245        if channel
1246            .gpadls
1247            .insert(request.id, GpadlState::Offered(rpc))
1248            .is_some()
1249        {
1250            panic!(
1251                "duplicate gpadl ID {:?} for channel {:?}.",
1252                request.id, channel_id
1253            );
1254        }
1255
1256        tracing::trace!(
1257            channel_id = channel_id.0,
1258            key = %OfferKey::from(&channel.offer),
1259            gpadl_id = request.id.0,
1260            count = request.count,
1261            len = request.buf.len(),
1262            "received gpadl request"
1263        );
1264
1265        // Split off the values that fit in the header.
1266        let (first, remaining) = if request.buf.len() > protocol::GpadlHeader::MAX_DATA_VALUES {
1267            request.buf.split_at(protocol::GpadlHeader::MAX_DATA_VALUES)
1268        } else {
1269            (request.buf.as_slice(), [].as_slice())
1270        };
1271
1272        let message = protocol::GpadlHeader {
1273            channel_id,
1274            gpadl_id: request.id,
1275            len: (request.buf.len() * size_of::<u64>())
1276                .try_into()
1277                .expect("Too many GPA values"),
1278            count: request.count,
1279        };
1280
1281        self.inner
1282            .messages
1283            .send_with_data(&message, first.as_bytes());
1284
1285        // Send GpadlBody messages for the remaining values.
1286        let message = protocol::GpadlBody {
1287            rsvd: 0,
1288            gpadl_id: request.id,
1289        };
1290        for chunk in remaining.chunks(protocol::GpadlBody::MAX_DATA_VALUES) {
1291            self.inner
1292                .messages
1293                .send_with_data(&message, chunk.as_bytes());
1294        }
1295    }
1296
1297    fn handle_gpadl_teardown(&mut self, channel_id: ChannelId, rpc: Rpc<GpadlId, ()>) {
1298        let (gpadl_id, rpc) = rpc.split();
1299        let mut channel = self.channels.get_mut(channel_id);
1300        let Some(gpadl_state) = channel.gpadls.get_mut(&gpadl_id) else {
1301            tracing::warn!(
1302                gpadl_id = gpadl_id.0,
1303                channel_id = channel_id.0,
1304                key = %OfferKey::from(&channel.offer),
1305                "Gpadl teardown for unknown gpadl or revoked channel"
1306            );
1307            return;
1308        };
1309
1310        match gpadl_state {
1311            GpadlState::Offered(_) => {
1312                tracing::warn!(
1313                    gpadl_id = gpadl_id.0,
1314                    channel_id = channel_id.0,
1315                    key = %OfferKey::from(&channel.offer),
1316                    "gpadl teardown for offered gpadl"
1317                );
1318            }
1319            GpadlState::Created => {
1320                *gpadl_state = GpadlState::TearingDown { rpcs: vec![rpc] };
1321                // The caller must guarantee that GPADL teardown requests are only made
1322                // for unique GPADL IDs. This is currently enforced in vmbus_server by
1323                // blocking GPADL teardown messages for reserved channels.
1324                assert!(
1325                    self.inner
1326                        .teardown_gpadls
1327                        .insert(gpadl_id, channel_id)
1328                        .is_none(),
1329                    "Gpadl state validated above"
1330                );
1331
1332                self.inner.messages.send(&protocol::GpadlTeardown {
1333                    channel_id,
1334                    gpadl_id,
1335                });
1336            }
1337            GpadlState::TearingDown { rpcs } => {
1338                rpcs.push(rpc);
1339            }
1340        }
1341    }
1342
1343    fn handle_close_channel(&mut self, channel_id: ChannelId) {
1344        let mut channel = self.channels.get_mut(channel_id);
1345        self.inner.close_channel(channel_id, &mut channel);
1346    }
1347
1348    fn handle_modify_channel(&mut self, channel_id: ChannelId, rpc: Rpc<ModifyRequest, i32>) {
1349        // The client doesn't support versions below Iron, so we always expect the host to send a
1350        // ModifyChannelResponse. This means we don't need to worry about sending a ChannelResponse
1351        // if that weren't supported.
1352        assert!(self.check_version(Version::Iron));
1353        let mut channel = self.channels.get_mut(channel_id);
1354        if channel.modify_response_send.is_some() {
1355            panic!("duplicate channel modify request {channel_id:?}");
1356        }
1357
1358        let (request, response) = rpc.split();
1359        channel.modify_response_send = Some(response);
1360        let payload = match request {
1361            ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel {
1362                channel_id,
1363                target_vp,
1364            },
1365        };
1366
1367        self.inner.messages.send(&payload);
1368    }
1369
1370    fn handle_channel_request(&mut self, channel_id: ChannelId, request: ChannelRequest) {
1371        match request {
1372            ChannelRequest::Open(rpc) => self.handle_open_channel(channel_id, rpc),
1373            ChannelRequest::Restore(rpc) => {
1374                rpc.handle_failable_sync(|request| self.handle_restore_channel(channel_id, request))
1375            }
1376            ChannelRequest::Gpadl(req) => self.handle_gpadl(channel_id, req),
1377            ChannelRequest::TeardownGpadl(req) => self.handle_gpadl_teardown(channel_id, req),
1378            ChannelRequest::Close(req) => {
1379                req.handle_sync(|()| self.handle_close_channel(channel_id))
1380            }
1381            ChannelRequest::Modify(req) => self.handle_modify_channel(channel_id, req),
1382        }
1383    }
1384
1385    async fn handle_task(&mut self, task: TaskRequest) {
1386        match task {
1387            TaskRequest::Inspect(deferred) => {
1388                deferred.inspect(&*self);
1389            }
1390            TaskRequest::Save(rpc) => rpc.handle_sync(|()| self.handle_save()),
1391            TaskRequest::Restore(rpc) => {
1392                rpc.handle_sync(|saved_state| self.handle_restore(saved_state))
1393            }
1394            TaskRequest::PostRestore(rpc) => rpc.handle_sync(|()| self.handle_post_restore()),
1395            TaskRequest::Start => self.handle_start(),
1396            TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
1397        }
1398    }
1399
1400    /// Makes sure a channel is closed if the channel request stream was dropped.
1401    fn handle_device_removal(&mut self, channel_id: ChannelId) -> TriedRelease {
1402        let mut channel = self.channels.get_mut(channel_id);
1403        channel.is_client_released = true;
1404        // Close the channel if it is still open.
1405        if let ChannelState::Opened { .. } = channel.state {
1406            tracing::warn!(
1407                channel_id = channel_id.0,
1408                key = %OfferKey::from(&channel.offer),
1409                "Channel dropped without closing first"
1410            );
1411            self.inner.close_channel(channel_id, &mut channel);
1412        }
1413        channel.try_release(&mut self.inner.messages)
1414    }
1415
1416    /// Determines if the client is connected with at least the specified version.
1417    fn check_version(&self, version: Version) -> bool {
1418        matches!(self.state, ClientState::Connected { version: v, .. } if v.version >= version)
1419    }
1420
    /// Starts message processing.
    ///
    /// Resumes the incoming synic message stream and the outgoing message queue,
    /// then marks the client as running.
    ///
    /// # Panics
    ///
    /// Panics if the client is already running.
    fn handle_start(&mut self) {
        assert!(!self.running);
        self.msg_source.resume_message_stream();
        // `resume` asserts the outgoing queue is currently `Paused`.
        self.inner.messages.resume();
        self.running = true;
    }
1427
    /// Stops message processing, first draining responses the saved state cannot
    /// represent.
    ///
    /// The loop runs as follows:
    /// 1. Process host messages until no revoked channel has a pending request.
    /// 2. Pause the message stream. If the host supports pause/resume, an
    ///    in-band `Pause` is sent. Otherwise the SINT is masked and outgoing
    ///    messages are force-paused.
    /// 3. Keep processing until a pause response or end of stream.
    /// 4. If a pending request appeared meanwhile, resume and start over.
    ///
    /// Finally clears `self.running`.
    ///
    /// # Panics
    ///
    /// Panics if the client is not running. Also panics if, while waiting for a
    /// revoked channel's responses, `process_next_message` returns `false`
    /// (end of stream or a pause response).
    async fn handle_stop(&mut self) {
        assert!(self.running);

        loop {
            // Process messages until there are no more channels waiting for
            // responses. This is necessary to ensure that the saved state does
            // not have to support encoding revoked channels for which we are
            // waiting for GPADL or modify responses.
            while let Some((id, request)) = self.channels.revoked_channel_with_pending_request() {
                tracelimit::info_ratelimited!(
                    channel_id = id.0,
                    request,
                    "waiting for responses for channel"
                );
                assert!(self.process_next_message().await);
            }

            if self.can_pause_resume() {
                // The next flush sends an in-band Pause; a Resume is queued
                // ahead of any other pending messages for when we restart.
                self.inner.messages.pause();
            } else {
                // Mask the sint to pause the message stream. The host will
                // retry any queued messages after the sint is unmasked.
                self.msg_source.pause_message_stream();
                self.inner.messages.force_pause();
            }

            // Continue processing messages until we hit EOF or get a pause
            // response.
            while self.process_next_message().await {}

            // Ensure there are still no pending requests. If there are, resume
            // and go around again.
            if self
                .channels
                .revoked_channel_with_pending_request()
                .is_none()
            {
                break;
            }
            // Undo whichever pause mechanism was used above.
            if !self.can_pause_resume() {
                self.msg_source.resume_message_stream();
            }
            self.inner.messages.resume();
        }

        tracing::debug!("messages drained");
        // Because the run loop awaits all async operations, there is no need for rundown.
        self.running = false;
    }
1477
1478    async fn process_next_message(&mut self) -> bool {
1479        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
1480        let recv = self.msg_source.recv(&mut buf);
1481        // Concurrently flush until there is no more work to do, since pending
1482        // messages may be blocking responses from the host.
1483        let flush = async {
1484            self.inner.messages.flush_messages().await;
1485            std::future::pending().await
1486        };
1487        let size = (recv, flush)
1488            .race()
1489            .await
1490            .expect("Fatal error reading messages from synic");
1491        if size == 0 {
1492            return false;
1493        }
1494        self.handle_synic_message(&buf[..size])
1495    }
1496
1497    /// Returns whether the server supports in-band messages to pause/resume the
1498    /// message stream.
1499    ///
1500    /// For hosts where this is not supported, we mask the sint to pause new
1501    /// messages being queued to the sint, then drain the messages. This does
1502    /// not work with some host implementations, which cannot support draining
1503    /// the message queue while the sint is masked (due to the use of
1504    /// HvPostMessageDirect).
1505    fn can_pause_resume(&self) -> bool {
1506        if let ClientState::Connected { version, .. } = self.state {
1507            version.feature_flags.pause_resume()
1508        } else {
1509            false
1510        }
1511    }
1512
    /// Main event loop of the client task.
    ///
    /// Each iteration waits on whichever of these becomes ready first:
    /// - flushing queued outgoing messages (only while running and backed up);
    /// - a task request (inspect/save/restore/start/stop). This is always
    ///   serviced, even when stopped. The loop exits when the task stream ends.
    /// - a client request (only while running and not backed up). The loop exits
    ///   when that stream ends.
    /// - a channel request, or the end of a channel's request stream (only while
    ///   running and not backed up);
    /// - an incoming synic message (only while running).
    ///
    /// The loop also exits once every source has completed.
    ///
    /// # Panics
    ///
    /// Panics on end-of-file or an error while reading from the synic.
    async fn run(&mut self) {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        loop {
            // Only listen for host messages while running; a `None`
            // `OptionFuture` is terminated and skipped by `select!`.
            let mut message_recv =
                OptionFuture::from(self.running.then(|| self.msg_source.recv(&mut buf).fuse()));

            // If there are pending outgoing messages, the host is backed up.
            // Try to flush the queue, and in the meantime, stop generating new
            // messages by stopping processing client requests, so as to avoid
            // the outgoing message queue growing without bound.
            //
            // We still need to process incoming messages when in this state,
            // even though they may generate additional outgoing messages, to
            // avoid a deadlock with the host. The host can always DoS the
            // guest, so this is not an attack vector.
            let host_backed_up = !self.inner.messages.is_empty();
            let flush_messages = OptionFuture::from(
                (self.running && host_backed_up)
                    .then(|| self.inner.messages.flush_messages().fuse()),
            );

            let mut client_request_recv = OptionFuture::from(
                (self.running && !host_backed_up).then(|| self.client_request_recv.next()),
            );

            let mut channel_requests = OptionFuture::from(
                (self.running && !host_backed_up)
                    .then(|| self.inner.channel_requests.select_next_some()),
            );

            futures::select! { // merge semantics
                _r = pin!(flush_messages) => {}
                r = self.task_recv.next() => {
                    if let Some(task) = r {
                        self.handle_task(task).await;
                    } else {
                        break;
                    }
                }
                r = client_request_recv => {
                    // `Some(None)` means the client request stream ended.
                    if let Some(Some(request)) = r {
                        self.handle_client_request(request);
                    } else {
                        break;
                    }
                }
                r = channel_requests => {
                    match r.unwrap() {
                        (id, Some(request)) => self.handle_channel_request(id, request),
                        // The channel's request stream ended.
                        (id, _) => {
                            self.handle_device_removal(id);
                        }
                    }
                }
                r = message_recv => {
                    match r.unwrap() {
                        Ok(size) => {
                            if size == 0 {
                                panic!("Unexpected end of file reading messages from synic.");
                            }

                            self.handle_synic_message(&buf[..size]);
                        }
                        Err(err) => {
                            panic!("Error reading messages from synic: {err:?}");
                        }
                    }
                }
                complete => break,
            }
        }
    }
1585}
1586
1587impl ClientTaskInner {
1588    fn close_channel(&mut self, channel_id: ChannelId, channel: &mut Channel) {
1589        if let ChannelState::Opened {
1590            redirected_event_flag,
1591            ..
1592        } = channel.state
1593        {
1594            if let Some(flag) = redirected_event_flag {
1595                self.synic.free_event_flag(flag);
1596            }
1597            tracing::info!(
1598                channel_id = channel_id.0,
1599                key = %OfferKey::from(&channel.offer),
1600                "closing channel on host"
1601            );
1602
1603            self.messages.send(&protocol::CloseChannel { channel_id });
1604            channel.state = ChannelState::Offered;
1605            channel.connection_id.store(0, Ordering::Release);
1606        } else {
1607            tracing::warn!(
1608                channel_id = channel_id.0,
1609                key = %OfferKey::from(&channel.offer),
1610                channel_state = %channel.state,
1611                "invalid channel state for close channel"
1612            );
1613        }
1614    }
1615}
1616
/// Lifecycle of one GPADL on a channel, from the `GpadlHeader` being sent until
/// the host acknowledges its teardown.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum GpadlState {
    /// GpadlHeader has been sent to the host.
    ///
    /// Holds the pending request, which is completed or failed when the host's
    /// `GpadlCreated` arrives.
    Offered(#[inspect(skip)] FailableRpc<(), ()>),
    /// Host has responded with GpadlCreated.
    Created,
    /// GpadlTeardown message has been sent to the host.
    TearingDown {
        /// Requests waiting for the host's `GpadlTorndown`; all of them are
        /// completed when it arrives.
        #[inspect(skip)]
        rpcs: Vec<Rpc<(), ()>>,
    },
}
1630
/// Messages destined for the host, plus the pause state of the outgoing stream.
#[derive(Inspect)]
struct OutgoingMessages {
    /// Posts messages to the host.
    #[inspect(skip)]
    poster: Box<dyn PollPostMessage>,
    /// Messages not yet posted, in send order (inspected as a count).
    #[inspect(with = "|x| x.len()")]
    queued: VecDeque<OutgoingMessage>,
    /// Whether the stream is running, sending a pause message, or paused.
    state: OutgoingMessageState,
}
1639
/// Pause state of the outgoing message stream.
#[derive(Inspect, PartialEq, Eq, Debug)]
enum OutgoingMessageState {
    /// Messages are posted directly when possible, or flushed from the queue.
    Running,
    /// A pause was requested; the next flush posts a `Pause` message.
    SendingPauseMessage,
    /// Nothing is posted; new messages are queued until resumed.
    Paused,
}
1646
1647impl OutgoingMessages {
1648    fn send<T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout>(
1649        &mut self,
1650        msg: &T,
1651    ) {
1652        self.send_with_data(msg, &[])
1653    }
1654
1655    fn send_with_data<
1656        T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
1657    >(
1658        &mut self,
1659        msg: &T,
1660        data: &[u8],
1661    ) {
1662        tracing::trace!(typ = ?T::MESSAGE_TYPE, "Sending message to host");
1663        let msg = OutgoingMessage::with_data(msg, data);
1664        if self.queued.is_empty() && self.state == OutgoingMessageState::Running {
1665            let r = self.poster.poll_post_message(
1666                &mut Context::from_waker(std::task::Waker::noop()),
1667                protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
1668                1,
1669                msg.data(),
1670            );
1671            if let Poll::Ready(()) = r {
1672                return;
1673            }
1674        }
1675        tracing::trace!("queueing message");
1676        self.queued.push_back(msg);
1677    }
1678
1679    async fn flush_messages(&mut self) {
1680        let mut send = async |msg: &OutgoingMessage| {
1681            poll_fn(|cx| {
1682                self.poster.poll_post_message(
1683                    cx,
1684                    protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
1685                    1,
1686                    msg.data(),
1687                )
1688            })
1689            .await
1690        };
1691        match self.state {
1692            OutgoingMessageState::Running => {
1693                while let Some(msg) = self.queued.front() {
1694                    send(msg).await;
1695                    tracing::trace!("sent queued message");
1696                    self.queued.pop_front();
1697                }
1698            }
1699            OutgoingMessageState::SendingPauseMessage => {
1700                send(&OutgoingMessage::new(&protocol::Pause)).await;
1701                tracing::trace!("sent pause message");
1702                self.state = OutgoingMessageState::Paused;
1703            }
1704            OutgoingMessageState::Paused => {}
1705        }
1706    }
1707
1708    /// Pause by sending a pause message to the host. This will cause the host
1709    /// to stop sending messages after sending a pause response.
1710    fn pause(&mut self) {
1711        assert_eq!(self.state, OutgoingMessageState::Running);
1712        self.state = OutgoingMessageState::SendingPauseMessage;
1713        // Queue a resume message to be sent later.
1714        self.queued
1715            .push_front(OutgoingMessage::new(&protocol::Resume));
1716    }
1717
1718    /// Force a pause by setting the state to Paused. This is used when the
1719    /// host does not support in-band pause/resume messages, in which case
1720    /// the SINT is masked to force the host to stop sending messages.
1721    fn force_pause(&mut self) {
1722        assert_eq!(self.state, OutgoingMessageState::Running);
1723        self.state = OutgoingMessageState::Paused;
1724    }
1725
1726    fn resume(&mut self) {
1727        assert_eq!(self.state, OutgoingMessageState::Paused);
1728        self.state = OutgoingMessageState::Running;
1729    }
1730
1731    fn is_empty(&self) -> bool {
1732        self.queued.is_empty()
1733    }
1734}
1735
/// State owned by the client task: the outgoing message path, GPADL teardown
/// tracking, the merged per-channel request streams, and SynIC event-flag
/// bookkeeping.
#[derive(Inspect)]
struct ClientTaskInner {
    /// Messages to be posted to the host.
    messages: OutgoingMessages,
    /// Presumably GPADLs whose teardown is in flight, each mapped to its owning
    /// channel (inferred from the name; confirm against the call sites).
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")]
    teardown_gpadls: HashMap<GpadlId, ChannelId>,
    /// Request receivers for all channels merged into one stream; each stream
    /// is tagged with its `ChannelId`.
    #[inspect(skip)]
    channel_requests: SelectAll<TaggedStream<ChannelId, mesh::Receiver<ChannelRequest>>>,
    /// SynIC event client and event-flag allocation state.
    synic: SynicState,
}
1745
/// SynIC event state: the client used to map and signal events, plus the
/// allocation table for event flags.
#[derive(Inspect)]
struct SynicState {
    /// Used to map events to flags and to signal events to the host.
    #[inspect(skip)]
    event_client: Arc<dyn SynicEventClient>,
    /// Event-flag allocation table: entry `i` is `true` while flag `i + 1` is
    /// in use (flag 0 is never allocated).
    #[inspect(iter_by_index)]
    event_flag_state: Vec<bool>,
}
1753
/// All channels known to the client, keyed by channel ID.
///
/// Inspected transparently as a map from the raw channel ID to the channel.
#[derive(Inspect, Default)]
#[inspect(transparent)]
struct ChannelList(
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] HashMap<ChannelId, Channel>,
);
1759
/// A reference to a channel that can be used to remove the channel from the map
/// as well.
///
/// It wraps an occupied entry of the channel map, so the channel is known to be
/// present. It dereferences to [`Channel`], and `try_release` uses the entry to
/// remove the channel once it is fully released.
struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>);
1763
/// A tag value used to indicate that [`ChannelRef::try_release`] has been called.
/// This is useful as a return value for methods that might transition a channel
/// into a fully released state.
///
/// The tag carries no data. `try_release` consumes the `ChannelRef` and may
/// remove the channel from the map, so a method returning this tag signals that
/// the channel may no longer exist afterwards.
struct TriedRelease(());
1768
1769impl ChannelRef<'_> {
1770    /// If the channel has been fully released (revoked, released by the client,
1771    /// no pending requests), notifies the server and removes this channel from
1772    /// the map.
1773    fn try_release(self, messages: &mut OutgoingMessages) -> TriedRelease {
1774        if self.is_client_released
1775            && matches!(self.state, ChannelState::Revoked)
1776            && self.pending_request().is_none()
1777        {
1778            let channel_id = *self.0.key();
1779            tracelimit::info_ratelimited!(
1780                channel_id = channel_id.0,
1781                key = %OfferKey::from(&self.offer),
1782                "releasing channel"
1783            );
1784
1785            messages.send(&protocol::RelIdReleased { channel_id });
1786            self.0.remove();
1787        }
1788        TriedRelease(())
1789    }
1790}
1791
1792impl Deref for ChannelRef<'_> {
1793    type Target = Channel;
1794
1795    fn deref(&self) -> &Self::Target {
1796        self.0.get()
1797    }
1798}
1799
1800impl DerefMut for ChannelRef<'_> {
1801    fn deref_mut(&mut self) -> &mut Self::Target {
1802        self.0.get_mut()
1803    }
1804}
1805
1806impl ChannelList {
1807    fn revoked_channel_with_pending_request(&self) -> Option<(ChannelId, &'static str)> {
1808        self.0.iter().find_map(|(&id, channel)| {
1809            if !matches!(channel.state, ChannelState::Revoked) {
1810                return None;
1811            }
1812            Some((id, channel.pending_request()?))
1813        })
1814    }
1815
1816    #[track_caller]
1817    fn get_mut(&mut self, channel_id: ChannelId) -> ChannelRef<'_> {
1818        match self.0.entry(channel_id) {
1819            hash_map::Entry::Occupied(entry) => ChannelRef(entry),
1820            hash_map::Entry::Vacant(_) => {
1821                panic!("channel {:?} not found", channel_id);
1822            }
1823        }
1824    }
1825}
1826
1827impl SynicState {
1828    fn guest_to_host_interrupt(&self, connection_id: Arc<AtomicU32>) -> Interrupt {
1829        Interrupt::from_fn({
1830            let event_client = self.event_client.clone();
1831            move || {
1832                let connection_id = connection_id.load(Ordering::Acquire);
1833                if connection_id == 0 {
1834                    tracing::debug!("interrupt signal after close");
1835                    return;
1836                }
1837
1838                if let Err(err) = event_client.signal_event(connection_id, 0) {
1839                    tracelimit::warn_ratelimited!(
1840                        error = &err as &dyn std::error::Error,
1841                        "failed to signal event"
1842                    );
1843                }
1844            }
1845        })
1846    }
1847
1848    const MAX_EVENT_FLAGS: u16 = 2047;
1849
1850    fn allocate_event_flag(&mut self, event: &Event) -> Result<u16> {
1851        let i = self
1852            .event_flag_state
1853            .iter()
1854            .position(|&used| !used)
1855            .ok_or(())
1856            .or_else(|()| {
1857                if self.event_flag_state.len() >= Self::MAX_EVENT_FLAGS as usize {
1858                    anyhow::bail!("out of event flags");
1859                }
1860                self.event_flag_state.push(false);
1861                Ok(self.event_flag_state.len() - 1)
1862            })?;
1863
1864        let event_flag = (i + 1) as u16;
1865        self.event_client
1866            .map_event(event_flag, event)
1867            .context("failed to map event")?;
1868        self.event_flag_state[i] = true;
1869        Ok(event_flag)
1870    }
1871
1872    fn restore_event_flag(&mut self, flag: u16, event: &Event) -> Result<()> {
1873        let i = (flag as usize)
1874            .checked_sub(1)
1875            .context("invalid event flag")?;
1876        if i >= Self::MAX_EVENT_FLAGS as usize {
1877            anyhow::bail!("invalid event flag");
1878        }
1879        if self.event_flag_state.len() <= i {
1880            self.event_flag_state.resize(i + 1, false);
1881        }
1882        if self.event_flag_state[i] {
1883            anyhow::bail!("event flag already in use");
1884        }
1885        self.event_client
1886            .map_event(flag, event)
1887            .context("failed to map event")?;
1888        self.event_flag_state[i] = true;
1889        Ok(())
1890    }
1891
1892    fn free_event_flag(&mut self, flag: u16) {
1893        let i = flag as usize - 1;
1894        assert!(i < self.event_flag_state.len());
1895        self.event_flag_state[i] = false;
1896    }
1897}
1898
1899#[cfg(test)]
1900mod tests {
1901    use super::*;
1902    use futures_concurrency::future::Join;
1903    use guid::Guid;
1904    use pal_async::DefaultDriver;
1905    use pal_async::async_test;
1906    use pal_async::timer::PolledTimer;
1907    use protocol::TargetInfo;
1908    use std::fmt::Debug;
1909    use std::task::ready;
1910    use std::time::Duration;
1911    use test_with_tracing::test;
1912    use vmbus_core::protocol::MessageHeader;
1913    use vmbus_core::protocol::MessageType;
1914    use vmbus_core::protocol::OfferFlags;
1915    use vmbus_core::protocol::UserDefinedData;
1916    use vmbus_core::protocol::VmbusMessage;
1917    use zerocopy::FromBytes;
1918    use zerocopy::FromZeros;
1919    use zerocopy::Immutable;
1920    use zerocopy::IntoBytes;
1921    use zerocopy::KnownLayout;
1922
1923    const VMBUS_TEST_CLIENT_ID: Guid = guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6");
1924
1925    fn in_msg<T: IntoBytes + Immutable + KnownLayout>(message_type: MessageType, t: T) -> Vec<u8> {
1926        let mut data = Vec::new();
1927        data.extend_from_slice(&message_type.0.to_ne_bytes());
1928        data.extend_from_slice(&0u32.to_ne_bytes());
1929        data.extend_from_slice(t.as_bytes());
1930        data
1931    }
1932
1933    #[track_caller]
1934    fn check_message<T>(msg: OutgoingMessage, chk: T)
1935    where
1936        T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1937    {
1938        check_message_with_data(msg, chk, &[]);
1939    }
1940
1941    #[track_caller]
1942    fn check_message_with_data<T>(msg: OutgoingMessage, chk: T, data: &[u8])
1943    where
1944        T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1945    {
1946        let chk_data = OutgoingMessage::with_data(&chk, data);
1947        if msg.data() != chk_data.data() {
1948            let (header, rest) = MessageHeader::read_from_prefix(msg.data()).unwrap();
1949            assert_eq!(header.message_type(), <T as VmbusMessage>::MESSAGE_TYPE);
1950            let (msg, rest) = T::read_from_prefix(rest).expect("incorrect message size");
1951            if msg.as_bytes() != chk.as_bytes() {
1952                panic!("mismatched messages, expected {:#?}, got {:#?}", chk, msg);
1953            }
1954            if rest != data {
1955                panic!("mismatched data, expected {:#?}, got {:#?}", data, rest);
1956            }
1957        }
1958    }
1959
1960    struct TestServer {
1961        messages: mesh::Receiver<OutgoingMessage>,
1962        send: mesh::Sender<Vec<u8>>,
1963    }
1964
1965    impl TestServer {
1966        async fn next(&mut self) -> Option<OutgoingMessage> {
1967            self.messages.next().await
1968        }
1969
1970        fn send(&self, msg: Vec<u8>) {
1971            self.send.send(msg);
1972        }
1973
1974        async fn connect(&mut self, client: &mut VmbusClient) -> ConnectResult {
1975            self.connect_with_channels(client, |_| {}).await
1976        }
1977
1978        async fn connect_with_channels(
1979            &mut self,
1980            client: &mut VmbusClient,
1981            send_offers: impl FnOnce(&mut Self),
1982        ) -> ConnectResult {
1983            let client_connect = client.connect(0, None, Guid::ZERO);
1984
1985            let server_connect = async {
1986                let _ = self.next().await.unwrap();
1987
1988                self.send(in_msg(
1989                    MessageType::VERSION_RESPONSE,
1990                    protocol::VersionResponse2 {
1991                        version_response: protocol::VersionResponse {
1992                            version_supported: 1,
1993                            connection_state: ConnectionState::SUCCESSFUL,
1994                            padding: 0,
1995                            selected_version_or_connection_id: 0,
1996                        },
1997                        supported_features: SUPPORTED_FEATURE_FLAGS.into(),
1998                    },
1999                ));
2000
2001                check_message(self.next().await.unwrap(), protocol::RequestOffers {});
2002
2003                send_offers(self);
2004                self.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2005            };
2006
2007            let (connection, ()) = (client_connect, server_connect).join().await;
2008
2009            let connection = connection.unwrap();
2010            assert_eq!(connection.version.version, Version::Copper);
2011            assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
2012            connection
2013        }
2014
2015        async fn get_channel(&mut self, client: &mut VmbusClient) -> OfferInfo {
2016            let [channel] = self
2017                .get_channels(client, 1)
2018                .await
2019                .offers
2020                .try_into()
2021                .unwrap();
2022            channel
2023        }
2024
2025        async fn get_channels(&mut self, client: &mut VmbusClient, count: usize) -> ConnectResult {
2026            self.connect_with_channels(client, |this| {
2027                for i in 0..count {
2028                    let offer = protocol::OfferChannel {
2029                        interface_id: Guid::new_random(),
2030                        instance_id: Guid::new_random(),
2031                        rsvd: [0; 4],
2032                        flags: OfferFlags::new(),
2033                        mmio_megabytes: 0,
2034                        user_defined: UserDefinedData::new_zeroed(),
2035                        subchannel_index: 0,
2036                        mmio_megabytes_optional: 0,
2037                        channel_id: ChannelId(i as u32),
2038                        monitor_id: 0,
2039                        monitor_allocated: 0,
2040                        is_dedicated: 0,
2041                        connection_id: 0,
2042                    };
2043
2044                    this.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2045                }
2046            })
2047            .await
2048        }
2049
2050        async fn stop_client(&mut self, client: &mut VmbusClient) {
2051            let client_stop = client.stop();
2052            let server_stop = async {
2053                check_message(self.next().await.unwrap(), protocol::Pause);
2054                self.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
2055            };
2056            (client_stop, server_stop).join().await;
2057        }
2058
2059        async fn start_client(&mut self, client: &mut VmbusClient) {
2060            client.start();
2061            check_message(self.next().await.unwrap(), protocol::Resume);
2062        }
2063    }
2064
2065    struct TestServerClient {
2066        sender: mesh::Sender<OutgoingMessage>,
2067        timer: PolledTimer,
2068        deadline: Option<pal_async::timer::Instant>,
2069    }
2070
2071    impl PollPostMessage for TestServerClient {
2072        fn poll_post_message(
2073            &mut self,
2074            cx: &mut Context<'_>,
2075            _connection_id: u32,
2076            _typ: u32,
2077            msg: &[u8],
2078        ) -> Poll<()> {
2079            loop {
2080                if let Some(deadline) = self.deadline {
2081                    ready!(self.timer.poll_until(cx, deadline));
2082                    self.deadline = None;
2083                }
2084                // Randomly choose whether to delay the message.
2085                //
2086                // FUTURE: use some kind of deterministic test framework for this to
2087                // allow for reproducible tests.
2088                let mut b = [0];
2089                getrandom::fill(&mut b).unwrap();
2090                if b[0] % 4 == 0 {
2091                    self.deadline =
2092                        Some(pal_async::timer::Instant::now() + Duration::from_millis(10));
2093                } else {
2094                    let msg = OutgoingMessage::from_message(msg).unwrap();
2095                    tracing::info!(
2096                        msg = ?MessageHeader::read_from_prefix(msg.data()),
2097                        "sending message"
2098                    );
2099                    self.sender.send(msg);
2100                    break Poll::Ready(());
2101                }
2102            }
2103        }
2104    }
2105
2106    struct NoopSynicEvents;
2107
2108    impl SynicEventClient for NoopSynicEvents {
2109        fn map_event(&self, _event_flag: u16, _event: &Event) -> std::io::Result<()> {
2110            Ok(())
2111        }
2112
2113        fn unmap_event(&self, _event_flag: u16) {}
2114
2115        fn signal_event(&self, _connection_id: u32, _event_flag: u16) -> std::io::Result<()> {
2116            Err(std::io::ErrorKind::Unsupported.into())
2117        }
2118    }
2119
2120    struct TestMessageSource {
2121        msg_recv: mesh::Receiver<Vec<u8>>,
2122        paused: bool,
2123    }
2124
2125    impl AsyncRecv for TestMessageSource {
2126        fn poll_recv(
2127            &mut self,
2128            cx: &mut Context<'_>,
2129            mut bufs: &mut [std::io::IoSliceMut<'_>],
2130        ) -> Poll<std::io::Result<usize>> {
2131            let value = match self.msg_recv.poll_recv(cx) {
2132                Poll::Ready(v) => v.unwrap(),
2133                Poll::Pending => {
2134                    if self.paused {
2135                        return Poll::Ready(Ok(0));
2136                    } else {
2137                        return Poll::Pending;
2138                    }
2139                }
2140            };
2141            let mut remaining = value.as_slice();
2142            let mut total_size = 0;
2143            while !remaining.is_empty() && !bufs.is_empty() {
2144                let size = bufs[0].len().min(remaining.len());
2145                bufs[0][..size].copy_from_slice(&remaining[..size]);
2146                remaining = &remaining[size..];
2147                bufs = &mut bufs[1..];
2148                total_size += size;
2149            }
2150
2151            Ok(total_size).into()
2152        }
2153    }
2154
2155    impl VmbusMessageSource for TestMessageSource {
2156        fn pause_message_stream(&mut self) {
2157            self.paused = true;
2158        }
2159
2160        fn resume_message_stream(&mut self) {
2161            self.paused = false;
2162        }
2163    }
2164
2165    fn test_init(driver: &DefaultDriver) -> (TestServer, VmbusClient) {
2166        let (msg_send, msg_recv) = mesh::channel();
2167        let (synic_send, synic_recv) = mesh::channel();
2168        let server = TestServer {
2169            messages: synic_recv,
2170            send: msg_send,
2171        };
2172        let mut client = VmbusClientBuilder::new(
2173            NoopSynicEvents,
2174            TestMessageSource {
2175                msg_recv,
2176                paused: false,
2177            },
2178            TestServerClient {
2179                sender: synic_send,
2180                deadline: None,
2181                timer: PolledTimer::new(driver),
2182            },
2183        )
2184        .build(driver);
2185        client.start();
2186        (server, client)
2187    }
2188
2189    #[async_test]
2190    async fn test_initiate_contact_success(driver: DefaultDriver) {
2191        let (mut server, client) = test_init(&driver);
2192        let _recv = client
2193            .access
2194            .client_request_send
2195            .call(ClientRequest::Connect, ConnectRequest::default());
2196        check_message(
2197            server.next().await.unwrap(),
2198            protocol::InitiateContact2 {
2199                initiate_contact: protocol::InitiateContact {
2200                    version_requested: Version::Copper as u32,
2201                    target_message_vp: 0,
2202                    interrupt_page_or_target_info: TargetInfo::new()
2203                        .with_sint(2)
2204                        .with_vtl(0)
2205                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2206                        .into(),
2207                    parent_to_child_monitor_page_gpa: 0,
2208                    child_to_parent_monitor_page_gpa: 0,
2209                },
2210                ..FromZeros::new_zeroed()
2211            },
2212        );
2213    }
2214
2215    #[async_test]
2216    async fn test_connect_success(driver: DefaultDriver) {
2217        let (mut server, mut client) = test_init(&driver);
2218        let client_connect = client.connect(0, None, Guid::ZERO);
2219
2220        let server_connect = async {
2221            check_message(
2222                server.next().await.unwrap(),
2223                protocol::InitiateContact2 {
2224                    initiate_contact: protocol::InitiateContact {
2225                        version_requested: Version::Copper as u32,
2226                        target_message_vp: 0,
2227                        interrupt_page_or_target_info: TargetInfo::new()
2228                            .with_sint(2)
2229                            .with_vtl(0)
2230                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2231                            .into(),
2232                        parent_to_child_monitor_page_gpa: 0,
2233                        child_to_parent_monitor_page_gpa: 0,
2234                    },
2235                    ..FromZeros::new_zeroed()
2236                },
2237            );
2238
2239            server.send(in_msg(
2240                MessageType::VERSION_RESPONSE,
2241                protocol::VersionResponse2 {
2242                    version_response: protocol::VersionResponse {
2243                        version_supported: 1,
2244                        connection_state: ConnectionState::SUCCESSFUL,
2245                        padding: 0,
2246                        selected_version_or_connection_id: 0,
2247                    },
2248                    supported_features: SUPPORTED_FEATURE_FLAGS.into_bits(),
2249                },
2250            ));
2251
2252            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2253            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2254        };
2255
2256        let (connection, ()) = (client_connect, server_connect).join().await;
2257        let connection = connection.unwrap();
2258
2259        assert_eq!(connection.version.version, Version::Copper);
2260        assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
2261    }
2262
2263    #[async_test]
2264    async fn test_feature_flags(driver: DefaultDriver) {
2265        let (mut server, mut client) = test_init(&driver);
2266        let client_connect = client.connect(0, None, Guid::ZERO);
2267
2268        let server_connect = async {
2269            check_message(
2270                server.next().await.unwrap(),
2271                protocol::InitiateContact2 {
2272                    initiate_contact: protocol::InitiateContact {
2273                        version_requested: Version::Copper as u32,
2274                        target_message_vp: 0,
2275                        interrupt_page_or_target_info: TargetInfo::new()
2276                            .with_sint(2)
2277                            .with_vtl(0)
2278                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2279                            .into(),
2280                        parent_to_child_monitor_page_gpa: 0,
2281                        child_to_parent_monitor_page_gpa: 0,
2282                    },
2283                    ..FromZeros::new_zeroed()
2284                },
2285            );
2286
2287            // Report the server doesn't support some of the feature flags, and make
2288            // sure this is reflected in the returned version.
2289            server.send(in_msg(
2290                MessageType::VERSION_RESPONSE,
2291                protocol::VersionResponse2 {
2292                    version_response: protocol::VersionResponse {
2293                        version_supported: 1,
2294                        connection_state: ConnectionState::SUCCESSFUL,
2295                        padding: 0,
2296                        selected_version_or_connection_id: 0,
2297                    },
2298                    supported_features: 2,
2299                },
2300            ));
2301
2302            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2303            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2304        };
2305
2306        let (connection, ()) = (client_connect, server_connect).join().await;
2307        let connection = connection.unwrap();
2308
2309        assert_eq!(connection.version.version, Version::Copper);
2310        assert_eq!(
2311            connection.version.feature_flags,
2312            FeatureFlags::new().with_channel_interrupt_redirection(true)
2313        );
2314    }
2315
2316    #[async_test]
2317    async fn test_client_id(driver: DefaultDriver) {
2318        let (mut server, client) = test_init(&driver);
2319        let initiate_contact = ConnectRequest {
2320            client_id: VMBUS_TEST_CLIENT_ID,
2321            ..Default::default()
2322        };
2323        let _recv = client
2324            .access
2325            .client_request_send
2326            .call(ClientRequest::Connect, initiate_contact);
2327
2328        check_message(
2329            server.next().await.unwrap(),
2330            protocol::InitiateContact2 {
2331                initiate_contact: protocol::InitiateContact {
2332                    version_requested: Version::Copper as u32,
2333                    target_message_vp: 0,
2334                    interrupt_page_or_target_info: TargetInfo::new()
2335                        .with_sint(2)
2336                        .with_vtl(0)
2337                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2338                        .into(),
2339                    parent_to_child_monitor_page_gpa: 0,
2340                    child_to_parent_monitor_page_gpa: 0,
2341                },
2342                client_id: VMBUS_TEST_CLIENT_ID,
2343            },
2344        );
2345    }
2346
2347    #[async_test]
2348    async fn test_version_negotiation(driver: DefaultDriver) {
2349        let (mut server, mut client) = test_init(&driver);
2350        let client_connect = client.connect(0, None, Guid::ZERO);
2351
2352        let server_connect = async {
2353            check_message(
2354                server.next().await.unwrap(),
2355                protocol::InitiateContact2 {
2356                    initiate_contact: protocol::InitiateContact {
2357                        version_requested: Version::Copper as u32,
2358                        target_message_vp: 0,
2359                        interrupt_page_or_target_info: TargetInfo::new()
2360                            .with_sint(2)
2361                            .with_vtl(0)
2362                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2363                            .into(),
2364                        parent_to_child_monitor_page_gpa: 0,
2365                        child_to_parent_monitor_page_gpa: 0,
2366                    },
2367                    ..FromZeros::new_zeroed()
2368                },
2369            );
2370
2371            server.send(in_msg(
2372                MessageType::VERSION_RESPONSE,
2373                protocol::VersionResponse {
2374                    version_supported: 0,
2375                    connection_state: ConnectionState::SUCCESSFUL,
2376                    padding: 0,
2377                    selected_version_or_connection_id: 0,
2378                },
2379            ));
2380
2381            check_message(
2382                server.next().await.unwrap(),
2383                protocol::InitiateContact {
2384                    version_requested: Version::Iron as u32,
2385                    target_message_vp: 0,
2386                    interrupt_page_or_target_info: TargetInfo::new()
2387                        .with_sint(2)
2388                        .with_vtl(0)
2389                        .with_feature_flags(FeatureFlags::new().into())
2390                        .into(),
2391                    parent_to_child_monitor_page_gpa: 0,
2392                    child_to_parent_monitor_page_gpa: 0,
2393                },
2394            );
2395
2396            server.send(in_msg(
2397                MessageType::VERSION_RESPONSE,
2398                protocol::VersionResponse {
2399                    version_supported: 1,
2400                    connection_state: ConnectionState::SUCCESSFUL,
2401                    padding: 0,
2402                    selected_version_or_connection_id: 0,
2403                },
2404            ));
2405
2406            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2407            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2408        };
2409
2410        let (connection, ()) = (client_connect, server_connect).join().await;
2411        let connection = connection.unwrap();
2412
2413        assert_eq!(connection.version.version, Version::Iron);
2414        assert_eq!(connection.version.feature_flags, FeatureFlags::new());
2415    }
2416
2417    #[async_test]
2418    async fn test_open_channel_success(driver: DefaultDriver) {
2419        let (mut server, mut client) = test_init(&driver);
2420        let channel = server.get_channel(&mut client).await;
2421
2422        let recv = channel.request_send.call(
2423            ChannelRequest::Open,
2424            OpenRequest {
2425                open_data: OpenData {
2426                    target_vp: Some(0),
2427                    ring_offset: 0,
2428                    ring_gpadl_id: GpadlId(0),
2429                    event_flag: 0,
2430                    connection_id: 0,
2431                    user_data: UserDefinedData::new_zeroed(),
2432                },
2433                incoming_event: None,
2434                use_vtl2_connection_id: false,
2435            },
2436        );
2437
2438        check_message(
2439            server.next().await.unwrap(),
2440            protocol::OpenChannel2 {
2441                open_channel: protocol::OpenChannel {
2442                    channel_id: ChannelId(0),
2443                    open_id: 0,
2444                    ring_buffer_gpadl_id: GpadlId(0),
2445                    target_vp: 0,
2446                    downstream_ring_buffer_page_offset: 0,
2447                    user_data: UserDefinedData::new_zeroed(),
2448                },
2449                connection_id: 0,
2450                event_flag: 0,
2451                flags: Default::default(),
2452            },
2453        );
2454
2455        server.send(in_msg(
2456            MessageType::OPEN_CHANNEL_RESULT,
2457            protocol::OpenResult {
2458                channel_id: ChannelId(0),
2459                open_id: 0,
2460                status: protocol::STATUS_SUCCESS as u32,
2461            },
2462        ));
2463
2464        recv.await.unwrap().unwrap();
2465    }
2466
2467    #[async_test]
2468    async fn test_open_channel_fail(driver: DefaultDriver) {
2469        let (mut server, mut client) = test_init(&driver);
2470        let channel = server.get_channel(&mut client).await;
2471
2472        let recv = channel.request_send.call(
2473            ChannelRequest::Open,
2474            OpenRequest {
2475                open_data: OpenData {
2476                    target_vp: Some(0),
2477                    ring_offset: 0,
2478                    ring_gpadl_id: GpadlId(0),
2479                    event_flag: 0,
2480                    connection_id: 0,
2481                    user_data: UserDefinedData::new_zeroed(),
2482                },
2483                incoming_event: None,
2484                use_vtl2_connection_id: false,
2485            },
2486        );
2487
2488        check_message(
2489            server.next().await.unwrap(),
2490            protocol::OpenChannel2 {
2491                open_channel: protocol::OpenChannel {
2492                    channel_id: ChannelId(0),
2493                    open_id: 0,
2494                    ring_buffer_gpadl_id: GpadlId(0),
2495                    target_vp: 0,
2496                    downstream_ring_buffer_page_offset: 0,
2497                    user_data: UserDefinedData::new_zeroed(),
2498                },
2499                connection_id: 0,
2500                event_flag: 0,
2501                flags: Default::default(),
2502            },
2503        );
2504
2505        server.send(in_msg(
2506            MessageType::OPEN_CHANNEL_RESULT,
2507            protocol::OpenResult {
2508                channel_id: ChannelId(0),
2509                open_id: 0,
2510                status: protocol::STATUS_UNSUCCESSFUL as u32,
2511            },
2512        ));
2513
2514        recv.await.unwrap().unwrap_err();
2515    }
2516
2517    #[async_test]
2518    async fn test_modify_channel(driver: DefaultDriver) {
2519        let (mut server, mut client) = test_init(&driver);
2520        let channel = server.get_channel(&mut client).await;
2521
2522        // N.B. A real server requires the channel to be open before sending this, but the test
2523        //      server doesn't care.
2524        let recv = channel.request_send.call(
2525            ChannelRequest::Modify,
2526            ModifyRequest::TargetVp { target_vp: 1 },
2527        );
2528
2529        check_message(
2530            server.next().await.unwrap(),
2531            protocol::ModifyChannel {
2532                channel_id: ChannelId(0),
2533                target_vp: 1,
2534            },
2535        );
2536
2537        server.send(in_msg(
2538            MessageType::MODIFY_CHANNEL_RESPONSE,
2539            protocol::ModifyChannelResponse {
2540                channel_id: ChannelId(0),
2541                status: protocol::STATUS_SUCCESS,
2542            },
2543        ));
2544
2545        let status = recv.await.unwrap();
2546        assert_eq!(status, protocol::STATUS_SUCCESS);
2547    }
2548
2549    #[async_test]
2550    async fn test_save_restore_connected(driver: DefaultDriver) {
2551        let (mut server, mut client) = test_init(&driver);
2552        server.connect(&mut client).await;
2553        server.stop_client(&mut client).await;
2554        let s0 = client.save().await;
2555        let builder = client.sever().await;
2556        let mut client = builder.build(&driver);
2557        client.restore(s0.clone()).await.unwrap();
2558
2559        let s1 = client.save().await;
2560
2561        assert_eq!(s0, s1);
2562    }
2563
2564    #[async_test]
2565    async fn test_save_restore_connected_with_channel(driver: DefaultDriver) {
2566        let (mut server, mut client) = test_init(&driver);
2567        let c0 = server.get_channel(&mut client).await;
2568        server.stop_client(&mut client).await;
2569        let s0 = client.save().await;
2570        let builder = client.sever().await;
2571        let mut client = builder.build(&driver);
2572        let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2573        let s1 = client.save().await;
2574        assert_eq!(s0, s1);
2575        assert_eq!(connection.offers[0].offer, c0.offer);
2576    }
2577
    /// Verifies save/restore after a channel was rescinded while a modify
    /// request for it was still outstanding.
    ///
    /// The client is stopped while the modify is in flight. After the
    /// save/sever/restore cycle:
    /// - the saved state must round-trip;
    /// - the restored connection must report no offers;
    /// - starting the restored client must produce a `RelIdReleased` for the
    ///   revoked channel.
    #[async_test]
    async fn test_save_restore_connected_with_revoked_channel(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let c0 = server.get_channel(&mut client).await;
        // Rescind the channel and wait until the client reports the revoke.
        server.send(in_msg(
            MessageType::RESCIND_CHANNEL_OFFER,
            protocol::RescindChannelOffer {
                channel_id: ChannelId(0),
            },
        ));
        c0.revoke_recv.await.unwrap();
        // Issue a modify on the already-revoked channel; the check below
        // expects the client to still forward it to the server.
        let rpc = c0.request_send.call(
            ChannelRequest::Modify,
            ModifyRequest::TargetVp { target_vp: 1 },
        );

        check_message(
            server.next().await.unwrap(),
            protocol::ModifyChannel {
                channel_id: ChannelId(0),
                target_vp: 1,
            },
        );

        // Stop the client concurrently with the server side, which first
        // answers the outstanding modify and then expects the client's Pause
        // message and acknowledges it.
        let client_stop = client.stop();
        let server_stop = async {
            server.send(in_msg(
                MessageType::MODIFY_CHANNEL_RESPONSE,
                protocol::ModifyChannelResponse {
                    channel_id: ChannelId(0),
                    status: protocol::STATUS_SUCCESS,
                },
            ));
            check_message(server.next().await.unwrap(), protocol::Pause);
            server.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
        };
        (client_stop, server_stop).join().await;

        // The modify RPC issued before the stop must still complete.
        rpc.await.unwrap();

        // Save, sever, and restore into a fresh client. The state must
        // round-trip and the revoked channel must not be reported as an offer.
        let s0 = client.save().await;
        let builder = client.sever().await;
        let mut client = builder.build(&driver);
        let connection = client.restore(s0.clone()).await.unwrap().unwrap();
        let s1 = client.save().await;
        assert_eq!(s0, s1);
        assert!(connection.offers.is_empty());
        // Once started, the restored client is expected to release the revoked
        // channel's ID.
        server.start_client(&mut client).await;
        check_message(
            server.next().await.unwrap(),
            protocol::RelIdReleased {
                channel_id: ChannelId(0),
            },
        );
    }
2633
2634    #[async_test]
2635    async fn test_connect_fails_on_incorrect_state(driver: DefaultDriver) {
2636        let (mut server, mut client) = test_init(&driver);
2637        server.connect(&mut client).await;
2638        let err = client.connect(0, None, Guid::ZERO).await.unwrap_err();
2639        assert!(matches!(err, ConnectError::InvalidState), "{:?}", err);
2640    }
2641
2642    #[async_test]
2643    async fn test_hot_add_remove(driver: DefaultDriver) {
2644        let (mut server, mut client) = test_init(&driver);
2645
2646        let mut connection = server.connect(&mut client).await;
2647        let offer = protocol::OfferChannel {
2648            interface_id: Guid::new_random(),
2649            instance_id: Guid::new_random(),
2650            rsvd: [0; 4],
2651            flags: OfferFlags::new(),
2652            mmio_megabytes: 0,
2653            user_defined: UserDefinedData::new_zeroed(),
2654            subchannel_index: 0,
2655            mmio_megabytes_optional: 0,
2656            channel_id: ChannelId(5),
2657            monitor_id: 0,
2658            monitor_allocated: 0,
2659            is_dedicated: 0,
2660            connection_id: 0,
2661        };
2662
2663        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2664        let info = connection.offer_recv.next().await.unwrap();
2665
2666        assert_eq!(offer, info.offer);
2667
2668        server.send(in_msg(
2669            MessageType::RESCIND_CHANNEL_OFFER,
2670            protocol::RescindChannelOffer {
2671                channel_id: ChannelId(5),
2672            },
2673        ));
2674
2675        info.revoke_recv.await.unwrap();
2676        drop(info.request_send);
2677
2678        check_message(
2679            server.next().await.unwrap(),
2680            protocol::RelIdReleased {
2681                channel_id: ChannelId(5),
2682            },
2683        );
2684    }
2685
2686    #[async_test]
2687    async fn test_gpadl_success(driver: DefaultDriver) {
2688        let (mut server, mut client) = test_init(&driver);
2689        let channel = server.get_channel(&mut client).await;
2690        let recv = channel.request_send.call(
2691            ChannelRequest::Gpadl,
2692            GpadlRequest {
2693                id: GpadlId(1),
2694                count: 1,
2695                buf: vec![5],
2696            },
2697        );
2698
2699        check_message_with_data(
2700            server.next().await.unwrap(),
2701            protocol::GpadlHeader {
2702                channel_id: ChannelId(0),
2703                gpadl_id: GpadlId(1),
2704                len: 8,
2705                count: 1,
2706            },
2707            0x5u64.as_bytes(),
2708        );
2709
2710        server.send(in_msg(
2711            MessageType::GPADL_CREATED,
2712            protocol::GpadlCreated {
2713                channel_id: ChannelId(0),
2714                gpadl_id: GpadlId(1),
2715                status: protocol::STATUS_SUCCESS,
2716            },
2717        ));
2718
2719        recv.await.unwrap().unwrap();
2720
2721        let rpc = channel
2722            .request_send
2723            .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2724
2725        check_message(
2726            server.next().await.unwrap(),
2727            protocol::GpadlTeardown {
2728                channel_id: ChannelId(0),
2729                gpadl_id: GpadlId(1),
2730            },
2731        );
2732
2733        server.send(in_msg(
2734            MessageType::GPADL_TORNDOWN,
2735            protocol::GpadlTorndown {
2736                gpadl_id: GpadlId(1),
2737            },
2738        ));
2739
2740        rpc.await.unwrap();
2741    }
2742
2743    #[async_test]
2744    async fn test_gpadl_fail(driver: DefaultDriver) {
2745        let (mut server, mut client) = test_init(&driver);
2746        let channel = server.get_channel(&mut client).await;
2747        let recv = channel.request_send.call(
2748            ChannelRequest::Gpadl,
2749            GpadlRequest {
2750                id: GpadlId(1),
2751                count: 1,
2752                buf: vec![7],
2753            },
2754        );
2755
2756        check_message_with_data(
2757            server.next().await.unwrap(),
2758            protocol::GpadlHeader {
2759                channel_id: ChannelId(0),
2760                gpadl_id: GpadlId(1),
2761                len: 8,
2762                count: 1,
2763            },
2764            0x7u64.as_bytes(),
2765        );
2766
2767        server.send(in_msg(
2768            MessageType::GPADL_CREATED,
2769            protocol::GpadlCreated {
2770                channel_id: ChannelId(0),
2771                gpadl_id: GpadlId(1),
2772                status: protocol::STATUS_UNSUCCESSFUL,
2773            },
2774        ));
2775
2776        recv.await.unwrap().unwrap_err();
2777    }
2778
    /// Exercises GPADL creation and teardown interleaved with a channel
    /// rescind.
    ///
    /// Three GPADLs are created, then a teardown of GPADL 1 is started. The
    /// server rescinds the channel before answering it. The test then checks:
    /// - a GPADL requested after the rescind is still forwarded, and its
    ///   failure reaches the caller;
    /// - the pending teardown completes;
    /// - a teardown issued after the revoke is still forwarded;
    /// - the client finally releases the channel ID.
    #[async_test]
    async fn test_gpadl_with_revoke(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let channel = server.get_channel(&mut client).await;
        let channel_id = ChannelId(0);
        // Create GPADLs 1, 2 and 3 successfully.
        for gpadl_id in [1, 2, 3].map(GpadlId) {
            let recv = channel.request_send.call(
                ChannelRequest::Gpadl,
                GpadlRequest {
                    id: gpadl_id,
                    count: 1,
                    buf: vec![3],
                },
            );

            check_message_with_data(
                server.next().await.unwrap(),
                protocol::GpadlHeader {
                    channel_id,
                    gpadl_id,
                    len: 8,
                    count: 1,
                },
                0x3u64.as_bytes(),
            );

            server.send(in_msg(
                MessageType::GPADL_CREATED,
                protocol::GpadlCreated {
                    channel_id,
                    gpadl_id,
                    status: protocol::STATUS_SUCCESS,
                },
            ));

            recv.await.unwrap().unwrap();
        }

        // Start tearing down GPADL 1; the server sees the request but does not
        // answer it yet.
        let rpc = channel
            .request_send
            .call(ChannelRequest::TeardownGpadl, GpadlId(1));

        check_message(
            server.next().await.unwrap(),
            protocol::GpadlTeardown {
                channel_id,
                gpadl_id: GpadlId(1),
            },
        );

        // Rescind the channel while the teardown is still outstanding.
        server.send(in_msg(
            MessageType::RESCIND_CHANNEL_OFFER,
            protocol::RescindChannelOffer { channel_id },
        ));

        // Request GPADL 4 after the rescind was sent (the revoke notification
        // is only awaited further down); the client still forwards it.
        let recv = channel.request_send.call_failable(
            ChannelRequest::Gpadl,
            GpadlRequest {
                id: GpadlId(4),
                count: 1,
                buf: vec![3],
            },
        );

        check_message_with_data(
            server.next().await.unwrap(),
            protocol::GpadlHeader {
                channel_id,
                gpadl_id: GpadlId(4),
                len: 8,
                count: 1,
            },
            0x3u64.as_bytes(),
        );

        // Fail GPADL 4, then complete the pending teardown of GPADL 1.
        server.send(in_msg(
            MessageType::GPADL_CREATED,
            protocol::GpadlCreated {
                channel_id,
                gpadl_id: GpadlId(4),
                status: protocol::STATUS_UNSUCCESSFUL,
            },
        ));

        server.send(in_msg(
            MessageType::GPADL_TORNDOWN,
            protocol::GpadlTorndown {
                gpadl_id: GpadlId(1),
            },
        ));

        rpc.await.unwrap();
        recv.await.unwrap_err();

        channel.revoke_recv.await.unwrap();

        // Even after the revoke, a teardown of GPADL 2 is forwarded to the
        // server. The request sender is dropped right after issuing it.
        let rpc = channel
            .request_send
            .call(ChannelRequest::TeardownGpadl, GpadlId(2));
        drop(channel.request_send);

        check_message(
            server.next().await.unwrap(),
            protocol::GpadlTeardown {
                channel_id,
                gpadl_id: GpadlId(2),
            },
        );

        server.send(in_msg(
            MessageType::GPADL_TORNDOWN,
            protocol::GpadlTorndown {
                gpadl_id: GpadlId(2),
            },
        ));

        rpc.await.unwrap();

        // The channel ID is released at the end, even though GPADL 3 was never
        // explicitly torn down by this test.
        check_message(
            server.next().await.unwrap(),
            protocol::RelIdReleased { channel_id },
        );
    }
2902
2903    #[async_test]
2904    async fn test_modify_connection(driver: DefaultDriver) {
2905        let (mut server, mut client) = test_init(&driver);
2906        server.connect(&mut client).await;
2907        let call = client.access.client_request_send.call(
2908            ClientRequest::Modify,
2909            ModifyConnectionRequest {
2910                monitor_page: Some(MonitorPageGpas {
2911                    child_to_parent: 5,
2912                    parent_to_child: 6,
2913                }),
2914            },
2915        );
2916
2917        check_message(
2918            server.next().await.unwrap(),
2919            protocol::ModifyConnection {
2920                child_to_parent_monitor_page_gpa: 5,
2921                parent_to_child_monitor_page_gpa: 6,
2922            },
2923        );
2924
2925        server.send(in_msg(
2926            MessageType::MODIFY_CONNECTION_RESPONSE,
2927            protocol::ModifyConnectionResponse {
2928                connection_state: ConnectionState::FAILED_LOW_RESOURCES,
2929            },
2930        ));
2931
2932        let result = call.await.unwrap();
2933        assert_eq!(ConnectionState::FAILED_LOW_RESOURCES, result);
2934    }
2935
2936    #[async_test]
2937    async fn test_hvsock(driver: DefaultDriver) {
2938        let (mut server, mut client) = test_init(&driver);
2939        server.connect(&mut client).await;
2940        let request = HvsockConnectRequest {
2941            service_id: Guid::new_random(),
2942            endpoint_id: Guid::new_random(),
2943            silo_id: Guid::new_random(),
2944            hosted_silo_unaware: false,
2945        };
2946
2947        let resp = client.access().connect_hvsock(request);
2948        check_message(
2949            server.next().await.unwrap(),
2950            protocol::TlConnectRequest2 {
2951                base: protocol::TlConnectRequest {
2952                    service_id: request.service_id,
2953                    endpoint_id: request.endpoint_id,
2954                },
2955                silo_id: request.silo_id,
2956            },
2957        );
2958
2959        // Now send a failure result.
2960        server.send(in_msg(
2961            MessageType::TL_CONNECT_REQUEST_RESULT,
2962            protocol::TlConnectResult {
2963                service_id: request.service_id,
2964                endpoint_id: request.endpoint_id,
2965                status: protocol::STATUS_CONNECTION_REFUSED,
2966            },
2967        ));
2968
2969        let result = resp.await;
2970        assert!(result.is_none());
2971    }
2972
2973    #[async_test]
2974    async fn test_synic_event_flags(driver: DefaultDriver) {
2975        let (mut server, mut client) = test_init(&driver);
2976        let connection = server.get_channels(&mut client, 5).await;
2977        let event = Event::new();
2978
2979        for _ in 0..5 {
2980            for (i, channel) in connection.offers.iter().enumerate() {
2981                let recv = channel.request_send.call(
2982                    ChannelRequest::Open,
2983                    OpenRequest {
2984                        open_data: OpenData {
2985                            target_vp: Some(0),
2986                            ring_offset: 0,
2987                            ring_gpadl_id: GpadlId(0),
2988                            event_flag: 0,
2989                            connection_id: 0,
2990                            user_data: UserDefinedData::new_zeroed(),
2991                        },
2992                        incoming_event: Some(event.clone()),
2993                        use_vtl2_connection_id: false,
2994                    },
2995                );
2996
2997                let expected_event_flag = i as u16 + 1;
2998
2999                check_message(
3000                    server.next().await.unwrap(),
3001                    protocol::OpenChannel2 {
3002                        open_channel: protocol::OpenChannel {
3003                            channel_id: channel.offer.channel_id,
3004                            open_id: 0,
3005                            ring_buffer_gpadl_id: GpadlId(0),
3006                            target_vp: 0,
3007                            downstream_ring_buffer_page_offset: 0,
3008                            user_data: UserDefinedData::new_zeroed(),
3009                        },
3010                        connection_id: 0,
3011                        event_flag: expected_event_flag,
3012                        flags: OpenChannelFlags::new().with_redirect_interrupt(true),
3013                    },
3014                );
3015
3016                server.send(in_msg(
3017                    MessageType::OPEN_CHANNEL_RESULT,
3018                    protocol::OpenResult {
3019                        channel_id: channel.offer.channel_id,
3020                        open_id: 0,
3021                        status: protocol::STATUS_SUCCESS as u32,
3022                    },
3023                ));
3024
3025                let output = recv.await.unwrap().unwrap();
3026                assert_eq!(output.redirected_event_flag, Some(expected_event_flag));
3027            }
3028
3029            for (i, channel) in connection.offers.iter().enumerate() {
3030                // Close the channel to prepare for the next iteration of the loop.
3031                // The event flag should be the same each time.
3032                channel
3033                    .request_send
3034                    .call(ChannelRequest::Close, ())
3035                    .await
3036                    .unwrap();
3037
3038                check_message(
3039                    server.next().await.unwrap(),
3040                    protocol::CloseChannel {
3041                        channel_id: ChannelId(i as u32),
3042                    },
3043                );
3044            }
3045        }
3046    }
3047
3048    #[async_test]
3049    async fn test_revoke(driver: DefaultDriver) {
3050        let (mut server, mut client) = test_init(&driver);
3051        let channel = server.get_channel(&mut client).await;
3052
3053        server.send(in_msg(
3054            MessageType::RESCIND_CHANNEL_OFFER,
3055            protocol::RescindChannelOffer {
3056                channel_id: ChannelId(0),
3057            },
3058        ));
3059
3060        channel.revoke_recv.await.unwrap();
3061
3062        channel
3063            .request_send
3064            .call_failable(
3065                ChannelRequest::Open,
3066                OpenRequest {
3067                    open_data: OpenData {
3068                        target_vp: Some(0),
3069                        ring_offset: 0,
3070                        ring_gpadl_id: GpadlId(0),
3071                        event_flag: 0,
3072                        connection_id: 0,
3073                        user_data: UserDefinedData::new_zeroed(),
3074                    },
3075                    incoming_event: None,
3076                    use_vtl2_connection_id: false,
3077                },
3078            )
3079            .await
3080            .unwrap_err();
3081    }
3082
3083    #[async_test]
3084    #[should_panic(expected = "channel should not exist")]
3085    async fn test_reoffer_in_use_rel_id(driver: DefaultDriver) {
3086        let (mut server, mut client) = test_init(&driver);
3087        let mut connection = server.get_channels(&mut client, 1).await;
3088        let [channel] = connection.offers.try_into().unwrap();
3089
3090        server.send(in_msg(
3091            MessageType::RESCIND_CHANNEL_OFFER,
3092            protocol::RescindChannelOffer {
3093                channel_id: ChannelId(0),
3094            },
3095        ));
3096
3097        channel.revoke_recv.await.unwrap();
3098
3099        // This offer will cause a panic since the rel id is still in use.
3100        let offer = protocol::OfferChannel {
3101            interface_id: Guid::new_random(),
3102            instance_id: Guid::new_random(),
3103            rsvd: [0; 4],
3104            flags: OfferFlags::new(),
3105            mmio_megabytes: 0,
3106            user_defined: UserDefinedData::new_zeroed(),
3107            subchannel_index: 0,
3108            mmio_megabytes_optional: 0,
3109            channel_id: ChannelId(0),
3110            monitor_id: 0,
3111            monitor_allocated: 0,
3112            is_dedicated: 0,
3113            connection_id: 0,
3114        };
3115
3116        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3117
3118        connection.offer_recv.next().await;
3119    }
3120
3121    #[async_test]
3122    async fn test_revoke_release_and_reoffer(driver: DefaultDriver) {
3123        let (mut server, mut client) = test_init(&driver);
3124        let mut connection = server.get_channels(&mut client, 1).await;
3125        let [channel] = connection.offers.try_into().unwrap();
3126
3127        server.send(in_msg(
3128            MessageType::RESCIND_CHANNEL_OFFER,
3129            protocol::RescindChannelOffer {
3130                channel_id: ChannelId(0),
3131            },
3132        ));
3133
3134        channel.revoke_recv.await.unwrap();
3135        drop(channel.request_send);
3136
3137        check_message(
3138            server.next().await.unwrap(),
3139            protocol::RelIdReleased {
3140                channel_id: ChannelId(0),
3141            },
3142        );
3143
3144        let offer = protocol::OfferChannel {
3145            interface_id: Guid::new_random(),
3146            instance_id: Guid::new_random(),
3147            rsvd: [0; 4],
3148            flags: OfferFlags::new(),
3149            mmio_megabytes: 0,
3150            user_defined: UserDefinedData::new_zeroed(),
3151            subchannel_index: 0,
3152            mmio_megabytes_optional: 0,
3153            channel_id: ChannelId(0),
3154            monitor_id: 0,
3155            monitor_allocated: 0,
3156            is_dedicated: 0,
3157            connection_id: 0,
3158        };
3159
3160        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3161
3162        connection.offer_recv.next().await.unwrap();
3163    }
3164
    /// Verifies the reverse ordering of `test_revoke_release_and_reoffer`: the
    /// channel is closed by dropping its handle first, and rescinded afterwards.
    ///
    /// Dropping the open channel sends `CloseChannel` but not `RelIdReleased`.
    /// The ID is only released once the server rescinds the offer, after which
    /// a reoffer with the same ID is delivered to the client.
    #[async_test]
    async fn test_release_revoke_and_reoffer(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let mut connection = server.get_channels(&mut client, 1).await;
        let [channel] = connection.offers.try_into().unwrap();

        let open = channel.request_send.call_failable(
            ChannelRequest::Open,
            OpenRequest {
                open_data: OpenData {
                    target_vp: Some(0),
                    ring_offset: 0,
                    ring_gpadl_id: GpadlId(0),
                    event_flag: 0,
                    connection_id: 0,
                    user_data: UserDefinedData::new_zeroed(),
                },
                incoming_event: None,
                use_vtl2_connection_id: false,
            },
        );

        // Server side of the open handshake: expect OpenChannel2, reply with
        // success.
        let server_open = async {
            check_message(
                server.next().await.unwrap(),
                protocol::OpenChannel2 {
                    open_channel: protocol::OpenChannel {
                        channel_id: ChannelId(0),
                        open_id: 0,
                        ring_buffer_gpadl_id: GpadlId(0),
                        target_vp: 0,
                        downstream_ring_buffer_page_offset: 0,
                        user_data: UserDefinedData::new_zeroed(),
                    },
                    connection_id: 0,
                    event_flag: 0,
                    flags: Default::default(),
                },
            );
            server.send(in_msg(
                MessageType::OPEN_CHANNEL_RESULT,
                protocol::OpenResult {
                    channel_id: ChannelId(0),
                    open_id: 0,
                    status: protocol::STATUS_SUCCESS as u32,
                },
            ));
        };

        // Drive the client request and the server reply concurrently; only the
        // open result (tuple element 0) is checked.
        (open, server_open).join().await.0.unwrap();

        // This will close the channel but won't release it yet.
        drop(channel);

        check_message(
            server.next().await.unwrap(),
            protocol::CloseChannel {
                channel_id: ChannelId(0),
            },
        );

        server.send(in_msg(
            MessageType::RESCIND_CHANNEL_OFFER,
            protocol::RescindChannelOffer {
                channel_id: ChannelId(0),
            },
        ));

        // Should be released now that the offer has been rescinded.
        check_message(
            server.next().await.unwrap(),
            protocol::RelIdReleased {
                channel_id: ChannelId(0),
            },
        );

        let offer = protocol::OfferChannel {
            interface_id: Guid::new_random(),
            instance_id: Guid::new_random(),
            rsvd: [0; 4],
            flags: OfferFlags::new(),
            mmio_megabytes: 0,
            user_defined: UserDefinedData::new_zeroed(),
            subchannel_index: 0,
            mmio_megabytes_optional: 0,
            channel_id: ChannelId(0),
            monitor_id: 0,
            monitor_allocated: 0,
            is_dedicated: 0,
            connection_id: 0,
        };

        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));

        // New offer should come through.
        connection.offer_recv.next().await.unwrap();
    }
3262}