vmbus_client/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Client driver for the Hyper-V Virtual Machine Bus (VmBus).
5
6#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9pub mod driver;
10pub mod filter;
11mod hvsock;
12pub mod saved_state;
13
14pub use self::saved_state::SavedState;
15use anyhow::Context as _;
16use anyhow::Result;
17use futures::FutureExt;
18use futures::StreamExt;
19use futures::future::OptionFuture;
20use futures::stream::SelectAll;
21use futures_concurrency::future::Race;
22use guid::Guid;
23use inspect::Inspect;
24use mesh::rpc::FailableRpc;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use pal_async::task::Spawn;
28use pal_async::task::Task;
29use pal_event::Event;
30use std::collections::HashMap;
31use std::collections::VecDeque;
32use std::collections::hash_map;
33use std::convert::TryInto;
34use std::future::Future;
35use std::future::poll_fn;
36use std::ops::Deref;
37use std::ops::DerefMut;
38use std::pin::pin;
39use std::sync::Arc;
40use std::task::Context;
41use std::task::Poll;
42use thiserror::Error;
43use vmbus_async::async_dgram::AsyncRecv;
44use vmbus_async::async_dgram::AsyncRecvExt;
45use vmbus_channel::bus::GpadlRequest;
46use vmbus_channel::bus::ModifyRequest;
47use vmbus_channel::bus::OpenData;
48use vmbus_channel::gpadl::GpadlId;
49use vmbus_core::HvsockConnectRequest;
50use vmbus_core::OutgoingMessage;
51use vmbus_core::TaggedStream;
52use vmbus_core::VersionInfo;
53use vmbus_core::protocol;
54use vmbus_core::protocol::ChannelId;
55use vmbus_core::protocol::ConnectionState;
56use vmbus_core::protocol::FeatureFlags;
57use vmbus_core::protocol::Message;
58use vmbus_core::protocol::OpenChannelFlags;
59use vmbus_core::protocol::Version;
60use vmcore::interrupt::Interrupt;
61use vmcore::synic::MonitorPageGpas;
62use zerocopy::Immutable;
63use zerocopy::IntoBytes;
64use zerocopy::KnownLayout;
65
/// SINT value placed in the `TargetInfo` of the `InitiateContact` request.
const SINT: u8 = 2;
/// VTL value placed in the `TargetInfo` of the `InitiateContact` request.
const VTL: u8 = 0;
/// Protocol versions the client can negotiate, ordered oldest to newest.
/// Negotiation starts with the last entry and falls back one entry at a time
/// toward the first when the server rejects a version.
const SUPPORTED_VERSIONS: &[Version] = &[Version::Iron, Version::Copper];
/// Feature flags requested from the server; only sent when the version being
/// negotiated is Copper or newer.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true);
75
/// Client-side interface to synic events.
pub trait SynicEventClient: Send + Sync {
    /// Maps an incoming event signal on SINT7 to `event`.
    fn map_event(&self, event_flag: u16, event: &Event) -> std::io::Result<()>;

    /// Unmaps an event previously mapped with `map_event`.
    fn unmap_event(&self, event_flag: u16);

    /// Signals an event on the synic for `connection_id` / `event_flag`.
    fn signal_event(&self, connection_id: u32, event_flag: u16) -> std::io::Result<()>;
}
87
/// A stream of vmbus messages that can be paused and resumed.
///
/// Both control methods default to no-ops, so sources that cannot pause need
/// not implement them.
pub trait VmbusMessageSource: AsyncRecv + Send {
    /// Stop accepting new messages from the synic. After this is called, the message source must
    /// return any pending messages already in the queue, and then return EOF.
    fn pause_message_stream(&mut self) {}

    /// Resume accepting new messages from the synic.
    fn resume_message_stream(&mut self) {}
}
97
/// Poll-based sink for outgoing vmbus messages; used by the client task as
/// the poster for its outgoing message queue.
pub trait PollPostMessage: Send {
    /// Attempts to post `msg` with the given `connection_id` and message type
    /// `typ`.
    ///
    /// Following `Poll` conventions, `Poll::Ready(())` presumably means the
    /// message was posted and `Poll::Pending` that the caller should retry
    /// after being woken — TODO confirm against implementations.
    fn poll_post_message(
        &mut self,
        cx: &mut Context<'_>,
        connection_id: u32,
        typ: u32,
        msg: &[u8],
    ) -> Poll<()>;
}
107
/// Handle to a running vmbus client task, created by [`VmbusClientBuilder::build`].
pub struct VmbusClient {
    // Task-level requests (start/stop/save/restore/inspect).
    task_send: mesh::Sender<TaskRequest>,
    // Connection-level requests; cloneable via `access()`.
    access: VmbusClientAccess,
    // The spawned task; yields the `ClientTask` back when it finishes.
    task: Task<ClientTask>,
}
113
/// Errors returned by [`VmbusClient::connect`].
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// The client was not in the disconnected state when asked to connect.
    #[error("invalid state to connect to the server")]
    InvalidState,
    /// The server rejected every version in the supported-version list.
    #[error("no supported protocol versions")]
    NoSupportedVersions,
    /// The server accepted the version but reported a non-success connection state.
    #[error("failed to connect to the server: {0:?}")]
    FailedToConnect(ConnectionState),
}
123
/// Cloneable handle for sending connection-level requests (modify, hvsock
/// connect) to the client task.
#[derive(Clone)]
pub struct VmbusClientAccess {
    client_request_send: mesh::Sender<ClientRequest>,
}
128
/// A builder for creating a [`VmbusClient`].
///
/// Also returned (internally) when a client is shut down, so its parts can be
/// reused.
pub struct VmbusClientBuilder {
    // Used to map, unmap, and signal synic events.
    event_client: Arc<dyn SynicEventClient>,
    // Source of incoming vmbus messages.
    msg_source: Box<dyn VmbusMessageSource>,
    // Sink for outgoing vmbus messages.
    msg_client: Box<dyn PollPostMessage>,
}
135
136impl VmbusClientBuilder {
137    /// Creates a new instance of the builder with the given synic input.
138    pub fn new(
139        event_client: impl SynicEventClient + 'static,
140        msg_source: impl VmbusMessageSource + 'static,
141        msg_client: impl PollPostMessage + 'static,
142    ) -> Self {
143        Self {
144            event_client: Arc::new(event_client),
145            msg_source: Box::new(msg_source),
146            msg_client: Box::new(msg_client),
147        }
148    }
149
150    /// Creates a new instance with a receiver for incoming synic messages.
151    pub fn build(self, spawner: &impl Spawn) -> VmbusClient {
152        let (task_send, task_recv) = mesh::channel();
153        let (client_request_send, client_request_recv) = mesh::channel();
154
155        let inner = ClientTaskInner {
156            messages: OutgoingMessages {
157                poster: self.msg_client,
158                queued: VecDeque::new(),
159                state: OutgoingMessageState::Paused,
160            },
161            teardown_gpadls: HashMap::new(),
162            channel_requests: SelectAll::new(),
163            synic: SynicState {
164                event_flag_state: Vec::new(),
165                event_client: self.event_client,
166            },
167        };
168
169        let mut task = ClientTask {
170            inner,
171            channels: ChannelList::default(),
172            task_recv,
173            running: false,
174            msg_source: self.msg_source,
175            client_request_recv,
176            state: ClientState::Disconnected,
177            modify_request: None,
178            hvsock_tracker: hvsock::HvsockRequestTracker::new(),
179        };
180
181        let task = spawner.spawn("vmbus client", async move {
182            task.run().await;
183            task
184        });
185
186        VmbusClient {
187            access: VmbusClientAccess {
188                client_request_send,
189            },
190            task_send,
191            task,
192        }
193    }
194}
195
196impl VmbusClient {
197    /// Connects to the server, negotiating the protocol version and retrieving
198    /// the initial list of channel offers.
199    pub async fn connect(
200        &mut self,
201        target_message_vp: u32,
202        monitor_page: Option<MonitorPageGpas>,
203        client_id: Guid,
204    ) -> Result<ConnectResult, ConnectError> {
205        let request = ConnectRequest {
206            target_message_vp,
207            monitor_page,
208            client_id,
209        };
210
211        self.access
212            .client_request_send
213            .call(ClientRequest::Connect, request)
214            .await
215            .unwrap()
216    }
217
218    pub async fn unload(self) {
219        self.access
220            .client_request_send
221            .call(ClientRequest::Unload, ())
222            .await
223            .unwrap();
224
225        self.sever().await;
226    }
227
228    pub fn access(&self) -> &VmbusClientAccess {
229        &self.access
230    }
231
232    pub fn start(&mut self) {
233        self.task_send.send(TaskRequest::Start);
234    }
235
236    pub async fn stop(&mut self) {
237        self.task_send
238            .call(TaskRequest::Stop, ())
239            .await
240            .expect("Failed to send stop request");
241    }
242
243    pub async fn save(&self) -> SavedState {
244        self.task_send
245            .call(TaskRequest::Save, ())
246            .await
247            .expect("Failed to send save request")
248    }
249
250    pub async fn restore(
251        &mut self,
252        state: SavedState,
253    ) -> Result<Option<ConnectResult>, RestoreError> {
254        self.task_send
255            .call(TaskRequest::Restore, state)
256            .await
257            .expect("Failed to send restore request")
258    }
259
260    pub async fn post_restore(&mut self) {
261        self.task_send
262            .call(TaskRequest::PostRestore, ())
263            .await
264            .expect("Failed to send post-restore request");
265    }
266
267    async fn sever(self) -> VmbusClientBuilder {
268        drop(self.task_send);
269        let task = self.task.await;
270        VmbusClientBuilder {
271            event_client: task.inner.synic.event_client,
272            msg_source: task.msg_source,
273            msg_client: task.inner.messages.poster,
274        }
275    }
276}
277
278impl Inspect for VmbusClient {
279    fn inspect(&self, req: inspect::Request<'_>) {
280        self.task_send.send(TaskRequest::Inspect(req.defer()));
281    }
282}
283
/// Result of a successful connection to the server.
#[derive(Debug)]
pub struct ConnectResult {
    /// The negotiated protocol version and feature flags.
    pub version: VersionInfo,
    /// Offers delivered by the server before it signaled all offers delivered.
    pub offers: Vec<OfferInfo>,
    /// Receives offers that arrive after the initial list.
    pub offer_recv: mesh::Receiver<OfferInfo>,
}
290
291impl VmbusClientAccess {
292    pub async fn modify(&self, request: ModifyConnectionRequest) -> ConnectionState {
293        self.client_request_send
294            .call(ClientRequest::Modify, request)
295            .await
296            .expect("Failed to send modify request")
297    }
298
299    pub fn connect_hvsock(
300        &self,
301        request: HvsockConnectRequest,
302    ) -> impl Future<Output = Option<OfferInfo>> + use<> {
303        self.client_request_send
304            .call(ClientRequest::HvsockConnect, request)
305            .map(|r| r.ok().flatten())
306    }
307}
308
/// Parameters for opening a channel via [`ChannelRequest::Open`].
#[derive(Debug)]
pub struct OpenRequest {
    /// Data sent to the server to open the channel.
    pub open_data: OpenData,
    /// Optional event to signal for incoming interrupts (assumed to be used for
    /// interrupt redirection — TODO confirm in the open handler).
    pub incoming_event: Option<Event>,
    /// Whether to use the VTL2 connection ID (semantics not visible here).
    pub use_vtl2_connection_id: bool,
}
315
/// Parameters for reclaiming a restored channel via [`ChannelRequest::Restore`].
#[derive(Debug)]
pub struct RestoreRequest {
    /// Optional event for incoming interrupts, as in [`OpenRequest`].
    pub incoming_event: Option<Event>,
    // FUTURE: move to saved state, don't rely on the caller.
    pub redirected_event_flag: Option<u16>,
    // FUTURE: ditto
    pub connection_id: u32,
}
324
/// Expresses an operation requested of the client.
///
/// Sent on the per-channel `request_send` handed out in [`OfferInfo`].
pub enum ChannelRequest {
    /// Open the channel.
    Open(FailableRpc<OpenRequest, OpenOutput>),
    /// Reclaim a channel restored from saved state.
    Restore(FailableRpc<RestoreRequest, OpenOutput>),
    /// Close the channel.
    Close(Rpc<(), ()>),
    /// Create a GPADL for the channel.
    Gpadl(FailableRpc<GpadlRequest, ()>),
    /// Tear down a previously created GPADL.
    TeardownGpadl(Rpc<GpadlId, ()>),
    /// Modify the channel; responds with an `i32` status.
    Modify(Rpc<ModifyRequest, i32>),
}
334
/// Result of a successful channel open or restore.
#[derive(Debug)]
pub struct OpenOutput {
    // FUTURE: remove this once it's part of the saved state.
    pub redirected_event_flag: Option<u16>,
    /// Interrupt the guest uses to signal the host for this channel's
    /// connection ID.
    pub guest_to_host_signal: Interrupt,
}
341
342impl std::fmt::Display for ChannelRequest {
343    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        let s = match self {
345            ChannelRequest::Open(_) => "Open",
346            ChannelRequest::Close(_) => "Close",
347            ChannelRequest::Restore(_) => "Restore",
348            ChannelRequest::Gpadl(_) => "Gpadl",
349            ChannelRequest::TeardownGpadl(_) => "TeardownGpadl",
350            ChannelRequest::Modify(_) => "Modify",
351        };
352        fmt.pad(s)
353    }
354}
355
/// Errors returned by [`VmbusClient::restore`].
#[derive(Debug, Error)]
pub enum RestoreError {
    /// The saved protocol version is not one this client supports.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// The saved feature flags contain unsupported bits.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    /// The saved state lists the same channel ID more than once.
    #[error("duplicate channel id {0}")]
    DuplicateChannelId(u32),

    /// The saved state lists the same GPADL ID more than once.
    #[error("duplicate gpadl id {0}")]
    DuplicateGpadlId(u32),

    /// A saved GPADL refers to a channel that is not in the saved state.
    #[error("gpadl for unknown channel id {0}")]
    GpadlForUnknownChannelId(u32),

    /// A saved pending outgoing message could not be rebuilt.
    #[error("invalid pending message")]
    InvalidPendingMessage(#[source] vmbus_core::MessageTooLarge),

    /// A restored channel could not be offered.
    #[error("failed to offer restored channel")]
    OfferFailed(#[source] anyhow::Error),
}
379
/// Provides the offer details from the server in addition to both a channel
/// to request client actions and a channel to receive server responses.
#[derive(Debug, Inspect)]
pub struct OfferInfo {
    /// The offer as received from the server.
    pub offer: protocol::OfferChannel,
    /// Sends [`ChannelRequest`]s for this channel to the client task.
    #[inspect(skip)]
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Signaled when the server rescinds the channel.
    #[inspect(skip)]
    pub revoke_recv: mesh::OneshotReceiver<()>,
}
390
/// Connection-level requests sent from [`VmbusClient`] / [`VmbusClientAccess`]
/// to the client task.
#[derive(Debug)]
enum ClientRequest {
    /// Negotiate a version and fetch the initial offers.
    Connect(Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>),
    /// Unload the connection.
    Unload(Rpc<(), ()>),
    /// Modify connection parameters.
    Modify(Rpc<ModifyConnectionRequest, ConnectionState>),
    /// Request an hvsocket connection.
    HvsockConnect(Rpc<HvsockConnectRequest, Option<OfferInfo>>),
}
398
399impl std::fmt::Display for ClientRequest {
400    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401        let s = match self {
402            ClientRequest::Connect(..) => "Connect",
403            ClientRequest::Unload { .. } => "Unload",
404            ClientRequest::Modify(..) => "Modify",
405            ClientRequest::HvsockConnect(..) => "HvsockConnect",
406        };
407        fmt.pad(s)
408    }
409}
410
/// Task-level requests sent from [`VmbusClient`] to the client task.
enum TaskRequest {
    /// Respond to an inspection request.
    Inspect(inspect::Deferred),
    /// Produce the saved state.
    Save(Rpc<(), SavedState>),
    /// Restore from a saved state.
    Restore(Rpc<SavedState, Result<Option<ConnectResult>, RestoreError>>),
    /// Post-restore notification.
    PostRestore(Rpc<(), ()>),
    /// Start the task; fire-and-forget (no response).
    Start,
    /// Stop the task.
    Stop(Rpc<(), ()>),
}
419
/// The overall state machine used to drive which actions the client can legally
/// take. This primarily pertains to overall client activity but has a
/// side-effect of limiting whether or not channels can perform actions.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ClientState {
    /// The client has yet to connect to the server.
    Disconnected,
    /// The client has initiated contact with the server.
    Connecting {
        // The version currently being offered to the server.
        version: Version,
        // The caller's connect request; failed on error, or carried into
        // `RequestingOffers` once a version is accepted.
        #[inspect(skip)]
        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
    },
    /// The client has negotiated the protocol version with the server.
    Connected {
        version: VersionInfo,
        // Forwards offers that arrive after the initial offer list.
        #[inspect(skip)]
        offer_send: mesh::Sender<OfferInfo>,
    },
    /// The client has requested offers from the server.
    RequestingOffers {
        version: VersionInfo,
        // Completed with the collected offers when all offers are delivered.
        #[inspect(skip)]
        rpc: Rpc<(), Result<ConnectResult, ConnectError>>,
        // Offers received so far.
        #[inspect(skip)]
        offers: Vec<OfferInfo>,
    },
    /// The client has initiated an unload from the server.
    Disconnecting {
        version: VersionInfo,
        // The caller's unload request (completion handled outside this view).
        #[inspect(skip)]
        rpc: Rpc<(), ()>,
    },
}
455
456impl ClientState {
457    fn get_version(&self) -> Option<VersionInfo> {
458        match self {
459            ClientState::Connected { version, .. } => Some(*version),
460            ClientState::RequestingOffers { version, .. } => Some(*version),
461            ClientState::Disconnecting { version, .. } => Some(*version),
462            ClientState::Disconnected | ClientState::Connecting { .. } => None,
463        }
464    }
465}
466
467impl std::fmt::Display for ClientState {
468    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469        let s = match self {
470            ClientState::Disconnected => "Disconnected",
471            ClientState::Connecting { .. } => "Connecting",
472            ClientState::Connected { .. } => "Connected",
473            ClientState::RequestingOffers { .. } => "RequestingOffers",
474            ClientState::Disconnecting { .. } => "Disconnecting",
475        };
476        fmt.pad(s)
477    }
478}
479
/// Parameters captured by [`VmbusClient::connect`] and sent in `InitiateContact`.
#[derive(Copy, Clone, Debug, Default)]
struct ConnectRequest {
    target_message_vp: u32,
    // `None` is sent as the default monitor page GPAs.
    monitor_page: Option<MonitorPageGpas>,
    // Only sent when negotiating Copper or newer.
    client_id: Guid,
}
486
/// Parameters for [`VmbusClientAccess::modify`].
#[derive(Copy, Clone, Debug, Default)]
pub struct ModifyConnectionRequest {
    /// New monitor page GPAs; `None` is sent as the default GPAs.
    pub monitor_page: Option<MonitorPageGpas>,
}
491
492impl From<ModifyConnectionRequest> for protocol::ModifyConnection {
493    fn from(value: ModifyConnectionRequest) -> Self {
494        let monitor_page = value.monitor_page.unwrap_or_default();
495
496        Self {
497            parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
498            child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
499        }
500    }
501}
502
/// The per-channel state which dictates which whether or not a channel can
/// request an Open/Close. As GPADLs can happen outside this loop there is no
/// state tied to GPADL actions.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum ChannelState {
    /// The channel has been offered to the client.
    Offered,
    /// The channel has requested the server to be opened.
    Opening {
        // Connection ID used to build the guest-to-host interrupt on success.
        connection_id: u32,
        // Event flag allocated for redirection; freed if the open fails or
        // the channel is rescinded.
        redirected_event_flag: Option<u16>,
        // Kept alive alongside the flag (presumably the mapped event — TODO confirm).
        #[inspect(skip)]
        redirected_event: Option<Event>,
        // Completed or failed when the server's OpenResult arrives.
        #[inspect(skip)]
        rpc: FailableRpc<(), OpenOutput>,
    },
    /// The channel has been restored but not claimed.
    Restored,
    /// The channel has been successfully opened.
    Opened {
        connection_id: u32,
        redirected_event_flag: Option<u16>,
        #[inspect(skip)]
        redirected_event: Option<Event>,
    },
    /// The channel has been revoked by the server.
    Revoked,
}
532
533impl std::fmt::Display for ChannelState {
534    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535        let s = match self {
536            ChannelState::Opening { .. } => "Opening",
537            ChannelState::Offered => "Offered",
538            ChannelState::Opened { .. } => "Opened",
539            ChannelState::Restored => "Restored",
540            ChannelState::Revoked => "Revoked",
541        };
542        fmt.pad(s)
543    }
544}
545
/// Client-side bookkeeping for one offered channel.
#[derive(Debug, Inspect)]
struct Channel {
    offer: protocol::OfferChannel,
    // When dropped, notifies the caller the channel has been revoked.
    #[inspect(skip)]
    revoke_send: Option<mesh::OneshotSender<()>>,
    state: ChannelState,
    // Outstanding channel modify request, if any.
    #[inspect(with = "|x| x.is_some()")]
    modify_response_send: Option<Rpc<(), i32>>,
    // GPADLs for this channel, keyed by ID, with their creation/teardown state.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    gpadls: HashMap<GpadlId, GpadlState>,
    // Presumably set once the channel's client side has released it — TODO
    // confirm against `try_release`, which is not visible here.
    is_client_released: bool,
}
559
560impl Channel {
561    fn pending_request(&self) -> Option<&'static str> {
562        if self.modify_response_send.is_some() {
563            return Some("modify");
564        }
565        self.gpadls.iter().find_map(|(_, gpadl)| match gpadl {
566            GpadlState::Offered(_) => Some("creating gpadl"),
567            GpadlState::Created => None,
568            GpadlState::TearingDown { .. } => Some("tearing down gpadl"),
569        })
570    }
571}
572
/// State owned by the spawned vmbus client task.
#[derive(Inspect)]
struct ClientTask {
    #[inspect(flatten)]
    inner: ClientTaskInner,
    // All channels currently known to the client, keyed by channel ID.
    channels: ChannelList,
    state: ClientState,
    // Pending hvsock connect requests, matched against incoming offers.
    hvsock_tracker: hvsock::HvsockRequestTracker,
    // Presumably whether the task has been started — TODO confirm in `run`.
    running: bool,
    // Outstanding ModifyConnection request; at most one at a time.
    #[inspect(with = "|x| x.is_some()")]
    modify_request: Option<Rpc<ModifyConnectionRequest, ConnectionState>>,
    #[inspect(skip)]
    msg_source: Box<dyn VmbusMessageSource>,
    #[inspect(skip)]
    task_recv: mesh::Receiver<TaskRequest>,
    #[inspect(skip)]
    client_request_recv: mesh::Receiver<ClientRequest>,
}
590
591impl ClientTask {
592    fn handle_initiate_contact(
593        &mut self,
594        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
595        version: Version,
596    ) {
597        let ClientState::Disconnected = self.state else {
598            tracing::warn!(client_state = %self.state, "invalid client state for InitiateContact");
599            rpc.complete(Err(ConnectError::InvalidState));
600            return;
601        };
602        let feature_flags = if version >= Version::Copper {
603            SUPPORTED_FEATURE_FLAGS
604        } else {
605            FeatureFlags::new()
606        };
607
608        let request = rpc.input();
609
610        tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting");
611        let target_info = protocol::TargetInfo::new()
612            .with_sint(SINT)
613            .with_vtl(VTL)
614            .with_feature_flags(feature_flags.into());
615        let monitor_page = request.monitor_page.unwrap_or_default();
616        let msg = protocol::InitiateContact2 {
617            initiate_contact: protocol::InitiateContact {
618                version_requested: version as u32,
619                target_message_vp: request.target_message_vp,
620                interrupt_page_or_target_info: target_info.into(),
621                parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
622                child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
623            },
624            client_id: request.client_id,
625        };
626
627        self.state = ClientState::Connecting { version, rpc };
628        if version < Version::Copper {
629            self.inner.messages.send(&msg.initiate_contact)
630        } else {
631            self.inner.messages.send(&msg);
632        }
633    }
634
635    fn handle_unload(&mut self, rpc: Rpc<(), ()>) {
636        tracing::debug!(%self.state, "VmBus client disconnecting");
637        self.state = ClientState::Disconnecting {
638            version: self.state.get_version().expect("invalid state for unload"),
639            rpc,
640        };
641
642        self.inner.messages.send(&protocol::Unload {});
643    }
644
645    fn handle_modify(&mut self, request: Rpc<ModifyConnectionRequest, ConnectionState>) {
646        if !matches!(self.state, ClientState::Connected { version, .. }
647            if version.feature_flags.modify_connection())
648        {
649            tracing::warn!("ModifyConnection not supported");
650            request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
651            return;
652        }
653
654        if self.modify_request.is_some() {
655            tracing::warn!("Duplicate ModifyConnection request");
656            request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
657            return;
658        }
659
660        let message = protocol::ModifyConnection::from(*request.input());
661        self.modify_request = Some(request);
662        self.inner.messages.send(&message);
663    }
664
665    fn handle_tl_connect(&mut self, rpc: Rpc<HvsockConnectRequest, Option<OfferInfo>>) {
666        // The client only supports protocol versions which use the newer message format.
667        // The host will not send a TlConnectRequestResult message on success, so a response to this
668        // message is not guaranteed.
669        let message = protocol::TlConnectRequest2::from(*rpc.input());
670        self.hvsock_tracker.add_request(rpc);
671        self.inner.messages.send(&message);
672    }
673
674    fn handle_client_request(&mut self, request: ClientRequest) {
675        match request {
676            ClientRequest::Connect(rpc) => {
677                self.handle_initiate_contact(rpc, *SUPPORTED_VERSIONS.last().unwrap());
678            }
679            ClientRequest::Unload(rpc) => {
680                self.handle_unload(rpc);
681            }
682            ClientRequest::Modify(request) => self.handle_modify(request),
683            ClientRequest::HvsockConnect(request) => self.handle_tl_connect(request),
684        }
685    }
686
687    fn handle_version_response(&mut self, msg: protocol::VersionResponse2) {
688        let old_state = std::mem::replace(&mut self.state, ClientState::Disconnected);
689        let ClientState::Connecting { version, rpc } = old_state else {
690            self.state = old_state;
691            tracing::warn!(
692                client_state = %self.state,
693                "invalid client state to handle VersionResponse"
694            );
695            return;
696        };
697        if msg.version_response.version_supported > 0 {
698            if msg.version_response.connection_state != ConnectionState::SUCCESSFUL {
699                rpc.complete(Err(ConnectError::FailedToConnect(
700                    msg.version_response.connection_state,
701                )));
702                return;
703            }
704
705            let feature_flags = if version >= Version::Copper {
706                FeatureFlags::from(msg.supported_features)
707            } else {
708                FeatureFlags::new()
709            };
710
711            let version = VersionInfo {
712                version,
713                feature_flags,
714            };
715
716            self.inner.messages.send(&protocol::RequestOffers {});
717            self.state = ClientState::RequestingOffers {
718                version,
719                rpc: rpc.split().1,
720                offers: Vec::new(),
721            };
722            tracing::info!(?version, "VmBus client connected, requesting offers");
723        } else {
724            let index = SUPPORTED_VERSIONS
725                .iter()
726                .position(|v| *v == version)
727                .unwrap();
728
729            if index == 0 {
730                rpc.complete(Err(ConnectError::NoSupportedVersions));
731                return;
732            }
733            let next_version = SUPPORTED_VERSIONS[index - 1];
734            tracing::debug!(
735                version = version as u32,
736                next_version = next_version as u32,
737                "Unsupported version, retrying"
738            );
739            self.handle_initiate_contact(rpc, next_version);
740        }
741    }
742
743    fn create_channel(&mut self, offer: protocol::OfferChannel) -> Result<OfferInfo> {
744        self.create_channel_core(offer, ChannelState::Offered)
745    }
746
747    fn create_channel_core(
748        &mut self,
749        offer: protocol::OfferChannel,
750        state: ChannelState,
751    ) -> Result<OfferInfo> {
752        if self.channels.0.contains_key(&offer.channel_id) {
753            anyhow::bail!("channel {:?} exists", offer.channel_id);
754        }
755        let (request_send, request_recv) = mesh::channel();
756        let (revoke_send, revoke_recv) = mesh::oneshot();
757
758        self.channels.0.insert(
759            offer.channel_id,
760            Channel {
761                revoke_send: Some(revoke_send),
762                offer,
763                state,
764                modify_response_send: None,
765                gpadls: HashMap::new(),
766                is_client_released: false,
767            },
768        );
769
770        self.inner
771            .channel_requests
772            .push(TaggedStream::new(offer.channel_id, request_recv));
773
774        Ok(OfferInfo {
775            offer,
776            revoke_recv,
777            request_send,
778        })
779    }
780
781    fn handle_offer(&mut self, offer: protocol::OfferChannel) {
782        let offer_info = self
783            .create_channel(offer)
784            .expect("channel should not exist");
785
786        tracing::info!(
787                state = %self.state,
788                channel_id = offer.channel_id.0,
789                interface_id = %offer.interface_id,
790                instance_id = %offer.instance_id,
791                subchannel_index = offer.subchannel_index,
792                "received offer");
793
794        if let Some(offer) = self.hvsock_tracker.check_offer(&offer_info.offer) {
795            offer.complete(Some(offer_info));
796        } else {
797            match &mut self.state {
798                ClientState::Connected { offer_send, .. } => {
799                    offer_send.send(offer_info);
800                }
801                ClientState::RequestingOffers { offers, .. } => {
802                    offers.push(offer_info);
803                }
804                state => unreachable!("invalid client state for OfferChannel: {state}"),
805            }
806        }
807    }
808
809    fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) -> TriedRelease {
810        tracing::info!(state = %self.state, channel_id = rescind.channel_id.0, "received rescind");
811
812        let mut channel = self.channels.get_mut(rescind.channel_id);
813        let event_flag = match std::mem::replace(&mut channel.state, ChannelState::Revoked) {
814            ChannelState::Offered => None,
815            ChannelState::Opening {
816                connection_id: _,
817                redirected_event_flag,
818                redirected_event: _,
819                rpc,
820            } => {
821                rpc.fail(anyhow::anyhow!("channel revoked"));
822                redirected_event_flag
823            }
824            ChannelState::Restored => None,
825            ChannelState::Opened {
826                connection_id: _,
827                redirected_event_flag,
828                redirected_event: _,
829            } => redirected_event_flag,
830            ChannelState::Revoked => {
831                panic!("channel id {:?} already revoked", rescind.channel_id);
832            }
833        };
834        if let Some(event_flag) = event_flag {
835            self.inner.synic.free_event_flag(event_flag);
836        }
837
838        // Drop the channel and send the revoked message to the client.
839        channel.revoke_send.take().unwrap().send(());
840
841        channel.try_release(&mut self.inner.messages)
842    }
843
844    fn handle_offers_delivered(&mut self) {
845        match std::mem::replace(&mut self.state, ClientState::Disconnected) {
846            ClientState::RequestingOffers {
847                version,
848                rpc,
849                offers,
850            } => {
851                tracing::info!(version = ?version, "VmBus client connected, offers delivered");
852                let (offer_send, offer_recv) = mesh::channel();
853                self.state = ClientState::Connected {
854                    version,
855                    offer_send,
856                };
857                rpc.complete(Ok(ConnectResult {
858                    version,
859                    offers,
860                    offer_recv,
861                }));
862            }
863            state => {
864                tracing::warn!(client_state = %state, "invalid client state for OffersDelivered");
865                self.state = state;
866            }
867        }
868    }
869
870    fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) -> TriedRelease {
871        let mut channel = self.channels.get_mut(request.channel_id);
872        let Some(gpadl_state) = channel.gpadls.get_mut(&request.gpadl_id) else {
873            panic!("GpadlCreated for unknown gpadl {:#x}", request.gpadl_id.0);
874        };
875
876        let rpc = match std::mem::replace(gpadl_state, GpadlState::Created) {
877            GpadlState::Offered(rpc) => rpc,
878            old_state => {
879                panic!(
880                    "invalid state {old_state:?} for gpadl {:#x}:{:#x}",
881                    request.channel_id.0, request.gpadl_id.0
882                );
883            }
884        };
885
886        let gpadl_created = request.status == protocol::STATUS_SUCCESS;
887        if gpadl_created {
888            rpc.complete(Ok(()));
889        } else {
890            channel.gpadls.remove(&request.gpadl_id).unwrap();
891            rpc.fail(anyhow::anyhow!(
892                "gpadl creation failed: {:#x}",
893                request.status
894            ));
895        };
896        channel.try_release(&mut self.inner.messages)
897    }
898
899    fn handle_open_result(&mut self, result: protocol::OpenResult) {
900        tracing::debug!(
901            channel_id = result.channel_id.0,
902            result = result.status,
903            "received open result"
904        );
905
906        let mut channel = self.channels.get_mut(result.channel_id);
907
908        let channel_opened = result.status == protocol::STATUS_SUCCESS as u32;
909        let old_state = std::mem::replace(&mut channel.state, ChannelState::Offered);
910        let ChannelState::Opening {
911            connection_id,
912            redirected_event_flag,
913            redirected_event,
914            rpc,
915        } = old_state
916        else {
917            tracing::warn!(
918                old_state = ?channel.state,
919                channel_opened,
920                "invalid state for open result"
921            );
922            channel.state = old_state;
923            return;
924        };
925
926        if !channel_opened {
927            if let Some(event_flag) = redirected_event_flag {
928                self.inner.synic.free_event_flag(event_flag);
929            }
930            rpc.fail(anyhow::anyhow!("open failed: {:#x}", result.status));
931            return;
932        }
933
934        channel.state = ChannelState::Opened {
935            connection_id,
936            redirected_event_flag,
937            redirected_event,
938        };
939
940        rpc.complete(Ok(OpenOutput {
941            redirected_event_flag,
942            guest_to_host_signal: self.inner.synic.guest_to_host_interrupt(connection_id),
943        }));
944    }
945
946    fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) -> TriedRelease {
947        let Some(channel_id) = self.inner.teardown_gpadls.remove(&request.gpadl_id) else {
948            panic!("gpadl {:#x} not in teardown list", request.gpadl_id.0);
949        };
950
951        tracing::debug!(
952            gpadl_id = request.gpadl_id.0,
953            channel_id = channel_id.0,
954            "Received GpadlTorndown"
955        );
956
957        let mut channel = self.channels.get_mut(channel_id);
958        let gpadl_state = channel
959            .gpadls
960            .remove(&request.gpadl_id)
961            .expect("gpadl validated above");
962
963        let GpadlState::TearingDown { rpcs } = gpadl_state else {
964            panic!("gpadl should be tearing down if in teardown list, state = {gpadl_state:?}");
965        };
966
967        for rpc in rpcs {
968            rpc.complete(());
969        }
970        channel.try_release(&mut self.inner.messages)
971    }
972
973    fn handle_unload_complete(&mut self) {
974        match std::mem::replace(&mut self.state, ClientState::Disconnected) {
975            ClientState::Disconnecting { version: _, rpc } => {
976                tracing::info!("VmBus client disconnected");
977                rpc.complete(());
978            }
979            state => {
980                tracing::warn!(client_state = %state, "invalid client state for UnloadComplete");
981            }
982        }
983    }
984
985    fn handle_modify_complete(&mut self, response: protocol::ModifyConnectionResponse) {
986        if let Some(request) = self.modify_request.take() {
987            request.complete(response.connection_state)
988        } else {
989            tracing::warn!("Unexpected modify complete request");
990        }
991    }
992
993    fn handle_modify_channel_response(
994        &mut self,
995        response: protocol::ModifyChannelResponse,
996    ) -> TriedRelease {
997        let mut channel = self.channels.get_mut(response.channel_id);
998        let Some(sender) = channel.modify_response_send.take() else {
999            panic!(
1000                "unexpected modify channel response for channel {:#x}",
1001                response.channel_id.0
1002            );
1003        };
1004
1005        sender.complete(response.status);
1006        channel.try_release(&mut self.inner.messages)
1007    }
1008
1009    fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) {
1010        if let Some(rpc) = self.hvsock_tracker.check_result(&response) {
1011            rpc.complete(None);
1012        }
1013    }
1014
1015    /// Returns false if the message was a pause complete message.
1016    fn handle_synic_message(&mut self, data: &[u8]) -> bool {
1017        let msg = Message::parse(data, self.state.get_version()).unwrap();
1018        tracing::trace!(?msg, "received client message from synic");
1019
1020        match msg {
1021            Message::VersionResponse2(version_response, ..) => {
1022                self.handle_version_response(version_response);
1023            }
1024            Message::VersionResponse(version_response, ..) => {
1025                self.handle_version_response(version_response.into());
1026            }
1027            Message::OfferChannel(offer, ..) => {
1028                self.handle_offer(offer);
1029            }
1030            Message::AllOffersDelivered(..) => {
1031                self.handle_offers_delivered();
1032            }
1033            Message::UnloadComplete(..) => {
1034                self.handle_unload_complete();
1035            }
1036            Message::ModifyConnectionResponse(response, ..) => {
1037                self.handle_modify_complete(response);
1038            }
1039            Message::GpadlCreated(gpadl, ..) => {
1040                self.handle_gpadl_created(gpadl);
1041            }
1042            Message::OpenResult(result, ..) => {
1043                self.handle_open_result(result);
1044            }
1045            Message::GpadlTorndown(gpadl, ..) => {
1046                self.handle_gpadl_torndown(gpadl);
1047            }
1048            Message::RescindChannelOffer(rescind, ..) => {
1049                self.handle_rescind(rescind);
1050            }
1051            Message::ModifyChannelResponse(response, ..) => {
1052                self.handle_modify_channel_response(response);
1053            }
1054            Message::TlConnectResult(response, ..) => self.handle_tl_connect_result(response),
1055            // Unsupported messages.
1056            Message::CloseReservedChannelResponse(..) => {
1057                todo!("Unsupported message {msg:?}")
1058            }
1059            Message::PauseResponse(..) => {
1060                return false;
1061            }
1062            // Messages that should only be received by a vmbus server.
1063            Message::RequestOffers(..)
1064            | Message::OpenChannel2(..)
1065            | Message::OpenChannel(..)
1066            | Message::CloseChannel(..)
1067            | Message::GpadlHeader(..)
1068            | Message::GpadlBody(..)
1069            | Message::GpadlTeardown(..)
1070            | Message::RelIdReleased(..)
1071            | Message::InitiateContact(..)
1072            | Message::InitiateContact2(..)
1073            | Message::Unload(..)
1074            | Message::OpenReservedChannel(..)
1075            | Message::CloseReservedChannel(..)
1076            | Message::TlConnectRequest2(..)
1077            | Message::TlConnectRequest(..)
1078            | Message::ModifyChannel(..)
1079            | Message::ModifyConnection(..)
1080            | Message::Pause(..)
1081            | Message::Resume(..) => {
1082                unreachable!("Client received server message {msg:?}");
1083            }
1084        }
1085        true
1086    }
1087
1088    fn handle_open_channel(
1089        &mut self,
1090        channel_id: ChannelId,
1091        rpc: FailableRpc<OpenRequest, OpenOutput>,
1092    ) {
1093        let mut channel = self.channels.get_mut(channel_id);
1094        match &channel.state {
1095            ChannelState::Offered => {}
1096            ChannelState::Revoked => {
1097                rpc.fail(anyhow::anyhow!("channel revoked"));
1098                return;
1099            }
1100            state => {
1101                rpc.fail(anyhow::anyhow!("invalid channel state: {}", state));
1102                return;
1103            }
1104        }
1105
1106        tracing::info!(channel_id = channel_id.0, "opening channel on host");
1107        let (request, rpc) = rpc.split();
1108        let open_data = &request.open_data;
1109
1110        let supports_interrupt_redirection =
1111            if let ClientState::Connected { version, .. } = self.state {
1112                version.feature_flags.guest_specified_signal_parameters()
1113                    || version.feature_flags.channel_interrupt_redirection()
1114            } else {
1115                false
1116            };
1117
1118        if !supports_interrupt_redirection && open_data.event_flag != channel_id.0 as u16 {
1119            rpc.fail(anyhow::anyhow!(
1120                "host does not support specifying the event flag"
1121            ));
1122            return;
1123        }
1124
1125        let open_channel = protocol::OpenChannel {
1126            channel_id,
1127            open_id: 0,
1128            ring_buffer_gpadl_id: open_data.ring_gpadl_id,
1129            target_vp: open_data.target_vp,
1130            downstream_ring_buffer_page_offset: open_data.ring_offset,
1131            user_data: open_data.user_data,
1132        };
1133
1134        let connection_id = if request.use_vtl2_connection_id {
1135            if !supports_interrupt_redirection {
1136                rpc.fail(anyhow::anyhow!(
1137                    "host does not support specfiying the connection ID"
1138                ));
1139                return;
1140            }
1141            protocol::ConnectionId::new(channel_id.0, 2.try_into().unwrap(), 7).0
1142        } else {
1143            open_data.connection_id
1144        };
1145
1146        // No failure paths after the one for allocating the event flag, since
1147        // otherwise we would need to free the event flag.
1148        let mut flags = OpenChannelFlags::new();
1149        let event_flag = if let Some(event) = &request.incoming_event {
1150            if !supports_interrupt_redirection {
1151                rpc.fail(anyhow::anyhow!(
1152                    "host does not support redirecting interrupts"
1153                ));
1154                return;
1155            }
1156
1157            flags.set_redirect_interrupt(true);
1158            match self.inner.synic.allocate_event_flag(event) {
1159                Ok(flag) => flag,
1160                Err(err) => {
1161                    rpc.fail(err.context("failed to allocate event flag"));
1162                    return;
1163                }
1164            }
1165        } else {
1166            open_data.event_flag
1167        };
1168
1169        if supports_interrupt_redirection {
1170            self.inner.messages.send(&protocol::OpenChannel2 {
1171                open_channel,
1172                connection_id,
1173                event_flag,
1174                flags,
1175            });
1176        } else {
1177            self.inner.messages.send(&open_channel);
1178        }
1179
1180        channel.state = ChannelState::Opening {
1181            connection_id,
1182            redirected_event_flag: (request.incoming_event.is_some()).then_some(event_flag),
1183            redirected_event: request.incoming_event,
1184            rpc,
1185        }
1186    }
1187
1188    fn handle_restore_channel(
1189        &mut self,
1190        channel_id: ChannelId,
1191        request: RestoreRequest,
1192    ) -> Result<OpenOutput> {
1193        let mut channel = self.channels.get_mut(channel_id);
1194        if !matches!(channel.state, ChannelState::Restored) {
1195            anyhow::bail!("invalid channel state: {}", channel.state);
1196        }
1197
1198        if request.incoming_event.is_some() != request.redirected_event_flag.is_some() {
1199            anyhow::bail!("incoming event and redirected event flag must both be set or unset");
1200        }
1201
1202        if let Some((flag, event)) = request
1203            .redirected_event_flag
1204            .zip(request.incoming_event.as_ref())
1205        {
1206            self.inner.synic.restore_event_flag(flag, event)?;
1207        }
1208
1209        channel.state = ChannelState::Opened {
1210            connection_id: request.connection_id,
1211            redirected_event_flag: request.redirected_event_flag,
1212            redirected_event: request.incoming_event,
1213        };
1214        Ok(OpenOutput {
1215            redirected_event_flag: request.redirected_event_flag,
1216            guest_to_host_signal: self
1217                .inner
1218                .synic
1219                .guest_to_host_interrupt(request.connection_id),
1220        })
1221    }
1222
1223    fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: FailableRpc<GpadlRequest, ()>) {
1224        let (request, rpc) = rpc.split();
1225        let mut channel = self.channels.get_mut(channel_id);
1226        if channel
1227            .gpadls
1228            .insert(request.id, GpadlState::Offered(rpc))
1229            .is_some()
1230        {
1231            panic!(
1232                "duplicate gpadl ID {:?} for channel {:?}.",
1233                request.id, channel_id
1234            );
1235        }
1236
1237        tracing::trace!(
1238            channel_id = channel_id.0,
1239            gpadl_id = request.id.0,
1240            count = request.count,
1241            len = request.buf.len(),
1242            "received gpadl request"
1243        );
1244
1245        // Split off the values that fit in the header.
1246        let (first, remaining) = if request.buf.len() > protocol::GpadlHeader::MAX_DATA_VALUES {
1247            request.buf.split_at(protocol::GpadlHeader::MAX_DATA_VALUES)
1248        } else {
1249            (request.buf.as_slice(), [].as_slice())
1250        };
1251
1252        let message = protocol::GpadlHeader {
1253            channel_id,
1254            gpadl_id: request.id,
1255            len: (request.buf.len() * size_of::<u64>())
1256                .try_into()
1257                .expect("Too many GPA values"),
1258            count: request.count,
1259        };
1260
1261        self.inner
1262            .messages
1263            .send_with_data(&message, first.as_bytes());
1264
1265        // Send GpadlBody messages for the remaining values.
1266        let message = protocol::GpadlBody {
1267            rsvd: 0,
1268            gpadl_id: request.id,
1269        };
1270        for chunk in remaining.chunks(protocol::GpadlBody::MAX_DATA_VALUES) {
1271            self.inner
1272                .messages
1273                .send_with_data(&message, chunk.as_bytes());
1274        }
1275    }
1276
1277    fn handle_gpadl_teardown(&mut self, channel_id: ChannelId, rpc: Rpc<GpadlId, ()>) {
1278        let (gpadl_id, rpc) = rpc.split();
1279        let mut channel = self.channels.get_mut(channel_id);
1280        let Some(gpadl_state) = channel.gpadls.get_mut(&gpadl_id) else {
1281            tracing::warn!(
1282                gpadl_id = gpadl_id.0,
1283                channel_id = channel_id.0,
1284                "Gpadl teardown for unknown gpadl or revoked channel"
1285            );
1286            return;
1287        };
1288
1289        match gpadl_state {
1290            GpadlState::Offered(_) => {
1291                tracing::warn!(
1292                    gpadl_id = gpadl_id.0,
1293                    channel_id = channel_id.0,
1294                    "gpadl teardown for offered gpadl"
1295                );
1296            }
1297            GpadlState::Created => {
1298                *gpadl_state = GpadlState::TearingDown { rpcs: vec![rpc] };
1299                // The caller must guarantee that GPADL teardown requests are only made
1300                // for unique GPADL IDs. This is currently enforced in vmbus_server by
1301                // blocking GPADL teardown messages for reserved channels.
1302                assert!(
1303                    self.inner
1304                        .teardown_gpadls
1305                        .insert(gpadl_id, channel_id)
1306                        .is_none(),
1307                    "Gpadl state validated above"
1308                );
1309
1310                self.inner.messages.send(&protocol::GpadlTeardown {
1311                    channel_id,
1312                    gpadl_id,
1313                });
1314            }
1315            GpadlState::TearingDown { rpcs } => {
1316                rpcs.push(rpc);
1317            }
1318        }
1319    }
1320
1321    fn handle_close_channel(&mut self, channel_id: ChannelId) {
1322        let mut channel = self.channels.get_mut(channel_id);
1323        self.inner.close_channel(channel_id, &mut channel);
1324    }
1325
1326    fn handle_modify_channel(&mut self, channel_id: ChannelId, rpc: Rpc<ModifyRequest, i32>) {
1327        // The client doesn't support versions below Iron, so we always expect the host to send a
1328        // ModifyChannelResponse. This means we don't need to worry about sending a ChannelResponse
1329        // if that weren't supported.
1330        assert!(self.check_version(Version::Iron));
1331        let mut channel = self.channels.get_mut(channel_id);
1332        if channel.modify_response_send.is_some() {
1333            panic!("duplicate channel modify request {channel_id:?}");
1334        }
1335
1336        let (request, response) = rpc.split();
1337        channel.modify_response_send = Some(response);
1338        let payload = match request {
1339            ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel {
1340                channel_id,
1341                target_vp,
1342            },
1343        };
1344
1345        self.inner.messages.send(&payload);
1346    }
1347
1348    fn handle_channel_request(&mut self, channel_id: ChannelId, request: ChannelRequest) {
1349        match request {
1350            ChannelRequest::Open(rpc) => self.handle_open_channel(channel_id, rpc),
1351            ChannelRequest::Restore(rpc) => {
1352                rpc.handle_failable_sync(|request| self.handle_restore_channel(channel_id, request))
1353            }
1354            ChannelRequest::Gpadl(req) => self.handle_gpadl(channel_id, req),
1355            ChannelRequest::TeardownGpadl(req) => self.handle_gpadl_teardown(channel_id, req),
1356            ChannelRequest::Close(req) => {
1357                req.handle_sync(|()| self.handle_close_channel(channel_id))
1358            }
1359            ChannelRequest::Modify(req) => self.handle_modify_channel(channel_id, req),
1360        }
1361    }
1362
1363    async fn handle_task(&mut self, task: TaskRequest) {
1364        match task {
1365            TaskRequest::Inspect(deferred) => {
1366                deferred.inspect(&*self);
1367            }
1368            TaskRequest::Save(rpc) => rpc.handle_sync(|()| self.handle_save()),
1369            TaskRequest::Restore(rpc) => {
1370                rpc.handle_sync(|saved_state| self.handle_restore(saved_state))
1371            }
1372            TaskRequest::PostRestore(rpc) => rpc.handle_sync(|()| self.handle_post_restore()),
1373            TaskRequest::Start => self.handle_start(),
1374            TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
1375        }
1376    }
1377
1378    /// Makes sure a channel is closed if the channel request stream was dropped.
1379    fn handle_device_removal(&mut self, channel_id: ChannelId) -> TriedRelease {
1380        let mut channel = self.channels.get_mut(channel_id);
1381        channel.is_client_released = true;
1382        // Close the channel if it is still open.
1383        if let ChannelState::Opened { .. } = channel.state {
1384            tracing::warn!(
1385                channel_id = channel_id.0,
1386                "Channel dropped without closing first"
1387            );
1388            self.inner.close_channel(channel_id, &mut channel);
1389        }
1390        channel.try_release(&mut self.inner.messages)
1391    }
1392
1393    /// Determines if the client is connected with at least the specified version.
1394    fn check_version(&self, version: Version) -> bool {
1395        matches!(self.state, ClientState::Connected { version: v, .. } if v.version >= version)
1396    }
1397
    /// Starts processing: resumes the incoming message stream, allows
    /// outgoing messages to be sent again, and marks the task as running.
    ///
    /// Panics if the task is already running, or (via
    /// `OutgoingMessages::resume`) if outgoing messages are not paused.
    fn handle_start(&mut self) {
        assert!(!self.running);
        self.msg_source.resume_message_stream();
        self.inner.messages.resume();
        self.running = true;
    }
1404
    /// Stops message processing and leaves both directions of the message
    /// stream paused.
    ///
    /// The sequence, repeated until it converges:
    /// 1. Process incoming messages until no revoked channel has a pending
    ///    request.
    /// 2. Pause the host. If the host supports it, this uses an in-band
    ///    `Pause` message; otherwise the sint is masked and outgoing messages
    ///    are force-paused.
    /// 3. Process messages until a pause response or end of stream.
    ///
    /// If a revoked channel gained a pending request meanwhile, both
    /// directions are resumed and the sequence repeats.
    ///
    /// Panics if the task is not running.
    async fn handle_stop(&mut self) {
        assert!(self.running);

        loop {
            // Process messages until there are no more channels waiting for
            // responses. This is necessary to ensure that the saved state does
            // not have to support encoding revoked channels for which we are
            // waiting for GPADL or modify responses.
            while let Some((id, request)) = self.channels.revoked_channel_with_pending_request() {
                tracelimit::info_ratelimited!(
                    channel_id = id.0,
                    request,
                    "waiting for responses for channel"
                );
                // Nothing has been paused yet at this point, so end of stream
                // or a pause response here is treated as fatal.
                assert!(self.process_next_message().await);
            }

            if self.can_pause_resume() {
                self.inner.messages.pause();
            } else {
                // Mask the sint to pause the message stream. The host will
                // retry any queued messages after the sint is unmasked.
                self.msg_source.pause_message_stream();
                self.inner.messages.force_pause();
            }

            // Continue processing messages until we hit EOF or get a pause
            // response.
            while self.process_next_message().await {}

            // Ensure there are still no pending requests. If there are, resume
            // and go around again.
            if self
                .channels
                .revoked_channel_with_pending_request()
                .is_none()
            {
                break;
            }
            if !self.can_pause_resume() {
                self.msg_source.resume_message_stream();
            }
            self.inner.messages.resume();
        }

        tracing::debug!("messages drained");
        // Because the run loop awaits all async operations, there is no need for rundown.
        self.running = false;
    }
1454
    /// Receives and handles a single message from the host, while flushing
    /// queued outgoing messages concurrently.
    ///
    /// Returns `false` if the message source reported end of stream (a
    /// zero-length read) or the message was a `PauseResponse`, and `true`
    /// otherwise. Panics if reading from the synic fails.
    async fn process_next_message(&mut self) -> bool {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        let recv = self.msg_source.recv(&mut buf);
        // Concurrently flush until there is no more work to do, since pending
        // messages may be blocking responses from the host.
        let flush = async {
            self.inner.messages.flush_messages().await;
            // Never completes, so the race below finishes only once a message
            // has been received.
            std::future::pending().await
        };
        let size = (recv, flush)
            .race()
            .await
            .expect("Fatal error reading messages from synic");
        if size == 0 {
            return false;
        }
        self.handle_synic_message(&buf[..size])
    }
1473
1474    /// Returns whether the server supports in-band messages to pause/resume the
1475    /// message stream.
1476    ///
1477    /// For hosts where this is not supported, we mask the sint to pause new
1478    /// messages being queued to the sint, then drain the messages. This does
1479    /// not work with some host implementations, which cannot support draining
1480    /// the message queue while the sint is masked (due to the use of
1481    /// HvPostMessageDirect).
1482    fn can_pause_resume(&self) -> bool {
1483        if let ClientState::Connected { version, .. } = self.state {
1484            version.feature_flags.pause_resume()
1485        } else {
1486            false
1487        }
1488    }
1489
    /// Main loop of the client task.
    ///
    /// Each iteration waits for whichever is ready first:
    /// - a control request (`task_recv`);
    /// - a client request;
    /// - a per-channel request;
    /// - an incoming host message;
    /// - progress on flushing queued outgoing messages.
    ///
    /// Incoming messages and flushing are only polled while running. Client
    /// and channel requests are only polled while running *and* the outgoing
    /// queue is empty.
    ///
    /// Returns when the control request stream or the client request stream
    /// ends, or when no source is left to poll. Panics on end of stream or
    /// on a read error from the synic message source.
    async fn run(&mut self) {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        loop {
            let mut message_recv =
                OptionFuture::from(self.running.then(|| self.msg_source.recv(&mut buf).fuse()));

            // If there are pending outgoing messages, the host is backed up.
            // Try to flush the queue, and in the meantime, stop generating new
            // messages by stopping processing client requests, so as to avoid
            // the outgoing message queue growing without bound.
            //
            // We still need to process incoming messages when in this state,
            // even though they may generate additional outgoing messages, to
            // avoid a deadlock with the host. The host can always DoS the
            // guest, so this is not an attack vector.
            let host_backed_up = !self.inner.messages.is_empty();
            let flush_messages = OptionFuture::from(
                (self.running && host_backed_up)
                    .then(|| self.inner.messages.flush_messages().fuse()),
            );

            let mut client_request_recv = OptionFuture::from(
                (self.running && !host_backed_up).then(|| self.client_request_recv.next()),
            );

            let mut channel_requests = OptionFuture::from(
                (self.running && !host_backed_up)
                    .then(|| self.inner.channel_requests.select_next_some()),
            );

            futures::select! { // merge semantics
                _r = pin!(flush_messages) => {}
                r = self.task_recv.next() => {
                    if let Some(task) = r {
                        self.handle_task(task).await;
                    } else {
                        break;
                    }
                }
                r = client_request_recv => {
                    // `Some(None)` means the client request stream ended.
                    if let Some(Some(request)) = r {
                        self.handle_client_request(request);
                    } else {
                        break;
                    }
                }
                r = channel_requests => {
                    // A `None` request means that channel's request stream was
                    // dropped by the client.
                    match r.unwrap() {
                        (id, Some(request)) => self.handle_channel_request(id, request),
                        (id, _) => {
                            self.handle_device_removal(id);
                        }
                    }
                }
                r = message_recv => {
                    match r.unwrap() {
                        Ok(size) => {
                            if size == 0 {
                                panic!("Unexpected end of file reading messages from synic.");
                            }

                            self.handle_synic_message(&buf[..size]);
                        }
                        Err(err) => {
                            panic!("Error reading messages from synic: {err:?}");
                        }
                    }
                }
                complete => break,
            }
        }
    }
1562}
1563
1564impl ClientTaskInner {
1565    fn close_channel(&mut self, channel_id: ChannelId, channel: &mut Channel) {
1566        if let ChannelState::Opened {
1567            redirected_event_flag,
1568            ..
1569        } = channel.state
1570        {
1571            if let Some(flag) = redirected_event_flag {
1572                self.synic.free_event_flag(flag);
1573            }
1574            tracing::info!(channel_id = channel_id.0, "closing channel on host");
1575            self.messages.send(&protocol::CloseChannel { channel_id });
1576            channel.state = ChannelState::Offered;
1577        } else {
1578            tracing::warn!(
1579                id = %channel_id.0,
1580                channel_state = %channel.state,
1581                "invalid channel state for close channel"
1582            );
1583        }
1584    }
1585}
1586
/// Lifecycle of a GPADL on a channel. Stored per GPADL ID in the channel's
/// `gpadls` map.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum GpadlState {
    /// GpadlHeader has been sent to the host. Holds the client's RPC, which
    /// is completed or failed when the host's GpadlCreated arrives.
    Offered(#[inspect(skip)] FailableRpc<(), ()>),
    /// Host has responded with GpadlCreated.
    Created,
    /// GpadlTeardown message has been sent to the host.
    TearingDown {
        /// Every client teardown RPC waiting for the host's GpadlTorndown.
        #[inspect(skip)]
        rpcs: Vec<Rpc<(), ()>>,
    },
}
1600
/// Outgoing message path to the host.
///
/// Messages are posted directly when possible. Otherwise they are queued and
/// sent in order later by `flush_messages`.
#[derive(Inspect)]
struct OutgoingMessages {
    /// Posts messages to the host.
    #[inspect(skip)]
    poster: Box<dyn PollPostMessage>,
    /// Messages not yet posted, in send order.
    #[inspect(with = "|x| x.len()")]
    queued: VecDeque<OutgoingMessage>,
    /// Whether queued messages may currently be sent.
    state: OutgoingMessageState,
}
1609
/// Send state of the outgoing message queue.
#[derive(Inspect, PartialEq, Eq, Debug)]
enum OutgoingMessageState {
    /// Messages are sent normally.
    Running,
    /// A `Pause` message still has to be sent to the host; queued messages
    /// are held until then and after.
    SendingPauseMessage,
    /// Nothing is sent; new messages are queued until `resume`.
    Paused,
}
1616
impl OutgoingMessages {
    /// Sends `msg` to the host with no trailing data. See `send_with_data`.
    fn send<T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout>(
        &mut self,
        msg: &T,
    ) {
        self.send_with_data(msg, &[])
    }

    /// Sends `msg` followed by `data` to the host.
    ///
    /// If the queue is empty and the state is `Running`, the message is
    /// posted directly with a single non-blocking poll. If that poll is not
    /// ready, or if messages are already queued or sending is paused, the
    /// message is appended to the queue instead.
    fn send_with_data<
        T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
    >(
        &mut self,
        msg: &T,
        data: &[u8],
    ) {
        tracing::trace!(typ = ?T::MESSAGE_TYPE, "Sending message to host");
        let msg = OutgoingMessage::with_data(msg, data);
        // Only bypass the queue when it is empty, so messages are never
        // reordered.
        if self.queued.is_empty() && self.state == OutgoingMessageState::Running {
            // Poll once with a no-op waker: this tries to post without waiting.
            let r = self.poster.poll_post_message(
                &mut Context::from_waker(std::task::Waker::noop()),
                protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                1,
                msg.data(),
            );
            if let Poll::Ready(()) = r {
                return;
            }
        }
        tracing::trace!("queueing message");
        self.queued.push_back(msg);
    }

    /// Sends pending messages, waiting as needed for the host to accept them.
    ///
    /// - `Running`: sends queued messages in order until the queue is empty.
    /// - `SendingPauseMessage`: sends only a `Pause` message, leaves queued
    ///   messages held, and moves to `Paused`.
    /// - `Paused`: does nothing.
    ///
    /// A queued message is removed only after it has been posted, and the
    /// state changes only after `Pause` has been posted. If this future is
    /// dropped mid-send, the next call retries.
    async fn flush_messages(&mut self) {
        let mut send = async |msg: &OutgoingMessage| {
            poll_fn(|cx| {
                self.poster.poll_post_message(
                    cx,
                    protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                    1,
                    msg.data(),
                )
            })
            .await
        };
        match self.state {
            OutgoingMessageState::Running => {
                while let Some(msg) = self.queued.front() {
                    send(msg).await;
                    tracing::trace!("sent queued message");
                    self.queued.pop_front();
                }
            }
            OutgoingMessageState::SendingPauseMessage => {
                send(&OutgoingMessage::new(&protocol::Pause)).await;
                tracing::trace!("sent pause message");
                self.state = OutgoingMessageState::Paused;
            }
            OutgoingMessageState::Paused => {}
        }
    }

    /// Pause by sending a pause message to the host. This will cause the host
    /// to stop sending messages after sending a pause response.
    ///
    /// The `Pause` message itself is sent by the next `flush_messages`.
    /// Panics unless the state is `Running`.
    fn pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::SendingPauseMessage;
        // Queue a resume message to be sent later. It goes at the front so
        // that, after `resume`, it is flushed before any held messages.
        self.queued
            .push_front(OutgoingMessage::new(&protocol::Resume));
    }

    /// Force a pause by setting the state to Paused. This is used when the
    /// host does not support in-band pause/resume messages, in which case
    /// the SINT is masked to force the host to stop sending messages.
    ///
    /// Panics unless the state is `Running`.
    fn force_pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::Paused;
    }

    /// Returns to `Running` so that queued messages (including any `Resume`
    /// queued by `pause`) are sent by the next flush. Panics unless the state
    /// is `Paused`.
    fn resume(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Paused);
        self.state = OutgoingMessageState::Running;
    }

    /// Returns `true` if no messages are waiting in the queue.
    fn is_empty(&self) -> bool {
        self.queued.is_empty()
    }
}
1705
/// State used by the client task: outgoing messages, GPADL teardown tracking,
/// per-channel request streams, and SynIC event bookkeeping.
#[derive(Inspect)]
struct ClientTaskInner {
    /// Outgoing message queue plus its pause/resume state toward the host.
    messages: OutgoingMessages,
    /// GPADLs being torn down, keyed by GPADL id and mapped to a channel id.
    /// Presumably the owning channel of each GPADL; verify against the code
    /// that inserts into this map.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")]
    teardown_gpadls: HashMap<GpadlId, ChannelId>,
    /// All per-channel request receivers merged into one stream. Each stream is
    /// wrapped in a `TaggedStream` keyed by its `ChannelId`.
    #[inspect(skip)]
    channel_requests: SelectAll<TaggedStream<ChannelId, mesh::Receiver<ChannelRequest>>>,
    /// Event-flag allocation and host signaling.
    synic: SynicState,
}
1715
/// Tracks which SynIC event flags this client has mapped, and holds the client
/// used to talk to the SynIC.
#[derive(Inspect)]
struct SynicState {
    /// Used to map, unmap, and signal SynIC events.
    #[inspect(skip)]
    event_client: Arc<dyn SynicEventClient>,
    /// Allocation table: entry `i` is `true` while event flag `i + 1` is in
    /// use. Flag 0 is never allocated.
    #[inspect(iter_by_index)]
    event_flag_state: Vec<bool>,
}
1723
/// The channels tracked by the client, keyed by `ChannelId`.
///
/// Inspected transparently as a map keyed by the raw channel id number.
#[derive(Inspect, Default)]
#[inspect(transparent)]
struct ChannelList(
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] HashMap<ChannelId, Channel>,
);
1729
/// A reference to a channel that can be used to remove the channel from the map
/// as well.
///
/// Wraps an occupied `HashMap` entry. It dereferences (mutably or not) to the
/// [`Channel`], and [`ChannelRef::try_release`] consumes it, possibly removing
/// the entry.
struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>);
1733
/// A tag value used to indicate that [`ChannelRef::try_release`] has been called.
/// This is useful as a return value for methods that might transition a channel
/// into a fully released state.
///
/// Holding one only means a release was *attempted*: `try_release` returns it
/// whether or not the channel was actually removed.
struct TriedRelease(());
1738
1739impl ChannelRef<'_> {
1740    /// If the channel has been fully released (revoked, released by the client,
1741    /// no pending requests), notifes the server and removes this channel from
1742    /// the map.
1743    fn try_release(self, messages: &mut OutgoingMessages) -> TriedRelease {
1744        if self.is_client_released
1745            && matches!(self.state, ChannelState::Revoked)
1746            && self.pending_request().is_none()
1747        {
1748            let channel_id = *self.0.key();
1749            tracelimit::info_ratelimited!(channel_id = channel_id.0, "releasing channel");
1750            messages.send(&protocol::RelIdReleased { channel_id });
1751            self.0.remove();
1752        }
1753        TriedRelease(())
1754    }
1755}
1756
1757impl Deref for ChannelRef<'_> {
1758    type Target = Channel;
1759
1760    fn deref(&self) -> &Self::Target {
1761        self.0.get()
1762    }
1763}
1764
1765impl DerefMut for ChannelRef<'_> {
1766    fn deref_mut(&mut self) -> &mut Self::Target {
1767        self.0.get_mut()
1768    }
1769}
1770
1771impl ChannelList {
1772    fn revoked_channel_with_pending_request(&self) -> Option<(ChannelId, &'static str)> {
1773        self.0.iter().find_map(|(&id, channel)| {
1774            if !matches!(channel.state, ChannelState::Revoked) {
1775                return None;
1776            }
1777            Some((id, channel.pending_request()?))
1778        })
1779    }
1780
1781    #[track_caller]
1782    fn get_mut(&mut self, channel_id: ChannelId) -> ChannelRef<'_> {
1783        match self.0.entry(channel_id) {
1784            hash_map::Entry::Occupied(entry) => ChannelRef(entry),
1785            hash_map::Entry::Vacant(_) => {
1786                panic!("channel {:?} not found", channel_id);
1787            }
1788        }
1789    }
1790}
1791
1792impl SynicState {
1793    fn guest_to_host_interrupt(&self, connection_id: u32) -> Interrupt {
1794        Interrupt::from_fn({
1795            let event_client = self.event_client.clone();
1796            move || {
1797                if let Err(err) = event_client.signal_event(connection_id, 0) {
1798                    tracelimit::warn_ratelimited!(
1799                        error = &err as &dyn std::error::Error,
1800                        "failed to signal event"
1801                    );
1802                }
1803            }
1804        })
1805    }
1806
1807    const MAX_EVENT_FLAGS: u16 = 2047;
1808
1809    fn allocate_event_flag(&mut self, event: &Event) -> Result<u16> {
1810        let i = self
1811            .event_flag_state
1812            .iter()
1813            .position(|&used| !used)
1814            .ok_or(())
1815            .or_else(|()| {
1816                if self.event_flag_state.len() >= Self::MAX_EVENT_FLAGS as usize {
1817                    anyhow::bail!("out of event flags");
1818                }
1819                self.event_flag_state.push(false);
1820                Ok(self.event_flag_state.len() - 1)
1821            })?;
1822
1823        let event_flag = (i + 1) as u16;
1824        self.event_client
1825            .map_event(event_flag, event)
1826            .context("failed to map event")?;
1827        self.event_flag_state[i] = true;
1828        Ok(event_flag)
1829    }
1830
1831    fn restore_event_flag(&mut self, flag: u16, event: &Event) -> Result<()> {
1832        let i = (flag as usize)
1833            .checked_sub(1)
1834            .context("invalid event flag")?;
1835        if i >= Self::MAX_EVENT_FLAGS as usize {
1836            anyhow::bail!("invalid event flag");
1837        }
1838        if self.event_flag_state.len() <= i {
1839            self.event_flag_state.resize(i + 1, false);
1840        }
1841        if self.event_flag_state[i] {
1842            anyhow::bail!("event flag already in use");
1843        }
1844        self.event_client
1845            .map_event(flag, event)
1846            .context("failed to map event")?;
1847        self.event_flag_state[i] = true;
1848        Ok(())
1849    }
1850
1851    fn free_event_flag(&mut self, flag: u16) {
1852        let i = flag as usize - 1;
1853        assert!(i < self.event_flag_state.len());
1854        self.event_flag_state[i] = false;
1855    }
1856}
1857
1858#[cfg(test)]
1859mod tests {
1860    use super::*;
1861    use futures_concurrency::future::Join;
1862    use guid::Guid;
1863    use pal_async::DefaultDriver;
1864    use pal_async::async_test;
1865    use pal_async::timer::PolledTimer;
1866    use protocol::TargetInfo;
1867    use std::fmt::Debug;
1868    use std::task::ready;
1869    use std::time::Duration;
1870    use test_with_tracing::test;
1871    use vmbus_core::protocol::MessageHeader;
1872    use vmbus_core::protocol::MessageType;
1873    use vmbus_core::protocol::OfferFlags;
1874    use vmbus_core::protocol::UserDefinedData;
1875    use vmbus_core::protocol::VmbusMessage;
1876    use zerocopy::FromBytes;
1877    use zerocopy::FromZeros;
1878    use zerocopy::Immutable;
1879    use zerocopy::IntoBytes;
1880    use zerocopy::KnownLayout;
1881
    // Fixed, easily recognizable client id; `test_client_id` checks that it is
    // carried through to the InitiateContact2 message.
    const VMBUS_TEST_CLIENT_ID: Guid = guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6");
1883
1884    fn in_msg<T: IntoBytes + Immutable + KnownLayout>(message_type: MessageType, t: T) -> Vec<u8> {
1885        let mut data = Vec::new();
1886        data.extend_from_slice(&message_type.0.to_ne_bytes());
1887        data.extend_from_slice(&0u32.to_ne_bytes());
1888        data.extend_from_slice(t.as_bytes());
1889        data
1890    }
1891
1892    #[track_caller]
1893    fn check_message<T>(msg: OutgoingMessage, chk: T)
1894    where
1895        T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1896    {
1897        check_message_with_data(msg, chk, &[]);
1898    }
1899
1900    #[track_caller]
1901    fn check_message_with_data<T>(msg: OutgoingMessage, chk: T, data: &[u8])
1902    where
1903        T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1904    {
1905        let chk_data = OutgoingMessage::with_data(&chk, data);
1906        if msg.data() != chk_data.data() {
1907            let (header, rest) = MessageHeader::read_from_prefix(msg.data()).unwrap();
1908            assert_eq!(header.message_type(), <T as VmbusMessage>::MESSAGE_TYPE);
1909            let (msg, rest) = T::read_from_prefix(rest).expect("incorrect message size");
1910            if msg.as_bytes() != chk.as_bytes() {
1911                panic!("mismatched messages, expected {:#?}, got {:#?}", chk, msg);
1912            }
1913            if rest != data {
1914                panic!("mismatched data, expected {:#?}, got {:#?}", data, rest);
1915            }
1916        }
1917    }
1918
    /// The host side of the in-memory test transport built by `test_init`.
    struct TestServer {
        /// Messages the client posts (guest to host), via `TestServerClient`.
        messages: mesh::Receiver<OutgoingMessage>,
        /// Raw messages delivered to the client (host to guest), via
        /// `TestMessageSource`.
        send: mesh::Sender<Vec<u8>>,
    }
1923
    impl TestServer {
        /// Waits for the next message the client posts to the host.
        async fn next(&mut self) -> Option<OutgoingMessage> {
            self.messages.next().await
        }

        /// Delivers an already-serialized message (see `in_msg`) to the client.
        fn send(&self, msg: Vec<u8>) {
            self.send.send(msg);
        }

        /// Runs the connection handshake without offering any channels.
        async fn connect(&mut self, client: &mut VmbusClient) -> ConnectResult {
            self.connect_with_channels(client, |_| {}).await
        }

        /// Drives `client.connect` while playing the host side of the
        /// handshake, in this order:
        /// 1. Take the first posted message without inspecting it.
        /// 2. Reply with a successful version response.
        /// 3. Expect `RequestOffers`.
        /// 4. Let `send_offers` deliver any offers.
        /// 5. Send `ALL_OFFERS_DELIVERED`.
        ///
        /// Asserts that the connection negotiated `Version::Copper` with
        /// `SUPPORTED_FEATURE_FLAGS`.
        async fn connect_with_channels(
            &mut self,
            client: &mut VmbusClient,
            send_offers: impl FnOnce(&mut Self),
        ) -> ConnectResult {
            let client_connect = client.connect(0, None, Guid::ZERO);

            let server_connect = async {
                // The initial contact message is consumed but not checked here;
                // the dedicated connect tests check its contents.
                let _ = self.next().await.unwrap();

                self.send(in_msg(
                    MessageType::VERSION_RESPONSE,
                    protocol::VersionResponse2 {
                        version_response: protocol::VersionResponse {
                            version_supported: 1,
                            connection_state: ConnectionState::SUCCESSFUL,
                            padding: 0,
                            selected_version_or_connection_id: 0,
                        },
                        supported_features: SUPPORTED_FEATURE_FLAGS.into(),
                    },
                ));

                check_message(self.next().await.unwrap(), protocol::RequestOffers {});

                send_offers(self);
                self.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
            };

            // Run both sides concurrently: the client waits on server replies.
            let (connection, ()) = (client_connect, server_connect).join().await;

            let connection = connection.unwrap();
            assert_eq!(connection.version.version, Version::Copper);
            assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
            connection
        }

        /// Connects with exactly one offered channel and returns its offer.
        async fn get_channel(&mut self, client: &mut VmbusClient) -> OfferInfo {
            let [channel] = self
                .get_channels(client, 1)
                .await
                .offers
                .try_into()
                .unwrap();
            channel
        }

        /// Connects while offering `count` channels with ids `0..count` and
        /// random interface/instance GUIDs.
        async fn get_channels(&mut self, client: &mut VmbusClient, count: usize) -> ConnectResult {
            self.connect_with_channels(client, |this| {
                for i in 0..count {
                    let offer = protocol::OfferChannel {
                        interface_id: Guid::new_random(),
                        instance_id: Guid::new_random(),
                        rsvd: [0; 4],
                        flags: OfferFlags::new(),
                        mmio_megabytes: 0,
                        user_defined: UserDefinedData::new_zeroed(),
                        subchannel_index: 0,
                        mmio_megabytes_optional: 0,
                        channel_id: ChannelId(i as u32),
                        monitor_id: 0,
                        monitor_allocated: 0,
                        is_dedicated: 0,
                        connection_id: 0,
                    };

                    this.send(in_msg(MessageType::OFFER_CHANNEL, offer));
                }
            })
            .await
        }

        /// Stops the client, answering its `Pause` message with a pause
        /// response.
        async fn stop_client(&mut self, client: &mut VmbusClient) {
            let client_stop = client.stop();
            let server_stop = async {
                check_message(self.next().await.unwrap(), protocol::Pause);
                self.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
            };
            (client_stop, server_stop).join().await;
        }

        /// Starts the client and expects it to post a `Resume` message.
        async fn start_client(&mut self, client: &mut VmbusClient) {
            client.start();
            check_message(self.next().await.unwrap(), protocol::Resume);
        }
    }
2023
    /// Message poster given to the client under test. It forwards every posted
    /// message to `TestServer::messages`, sometimes after a short random delay.
    struct TestServerClient {
        sender: mesh::Sender<OutgoingMessage>,
        /// Timer used to wait until `deadline`.
        timer: PolledTimer,
        /// When set, posting is delayed until this instant.
        deadline: Option<pal_async::timer::Instant>,
    }
2029
2030    impl PollPostMessage for TestServerClient {
2031        fn poll_post_message(
2032            &mut self,
2033            cx: &mut Context<'_>,
2034            _connection_id: u32,
2035            _typ: u32,
2036            msg: &[u8],
2037        ) -> Poll<()> {
2038            loop {
2039                if let Some(deadline) = self.deadline {
2040                    ready!(self.timer.poll_until(cx, deadline));
2041                    self.deadline = None;
2042                }
2043                // Randomly choose whether to delay the message.
2044                //
2045                // FUTURE: use some kind of deterministic test framework for this to
2046                // allow for reproducible tests.
2047                let mut b = [0];
2048                getrandom::fill(&mut b).unwrap();
2049                if b[0] % 4 == 0 {
2050                    self.deadline =
2051                        Some(pal_async::timer::Instant::now() + Duration::from_millis(10));
2052                } else {
2053                    let msg = OutgoingMessage::from_message(msg).unwrap();
2054                    tracing::info!(
2055                        msg = ?MessageHeader::read_from_prefix(msg.data()),
2056                        "sending message"
2057                    );
2058                    self.sender.send(msg);
2059                    break Poll::Ready(());
2060                }
2061            }
2062        }
2063    }
2064
    /// SynIC event client that accepts every mapping but cannot signal events.
    struct NoopSynicEvents;
2066
2067    impl SynicEventClient for NoopSynicEvents {
2068        fn map_event(&self, _event_flag: u16, _event: &Event) -> std::io::Result<()> {
2069            Ok(())
2070        }
2071
2072        fn unmap_event(&self, _event_flag: u16) {}
2073
2074        fn signal_event(&self, _connection_id: u32, _event_flag: u16) -> std::io::Result<()> {
2075            Err(std::io::ErrorKind::Unsupported.into())
2076        }
2077    }
2078
    /// Incoming message source for the client, fed by `TestServer::send`.
    struct TestMessageSource {
        msg_recv: mesh::Receiver<Vec<u8>>,
        /// While set, an empty queue reads as a zero-length message instead of
        /// returning `Pending`.
        paused: bool,
    }
2083
2084    impl AsyncRecv for TestMessageSource {
2085        fn poll_recv(
2086            &mut self,
2087            cx: &mut Context<'_>,
2088            mut bufs: &mut [std::io::IoSliceMut<'_>],
2089        ) -> Poll<std::io::Result<usize>> {
2090            let value = match self.msg_recv.poll_recv(cx) {
2091                Poll::Ready(v) => v.unwrap(),
2092                Poll::Pending => {
2093                    if self.paused {
2094                        return Poll::Ready(Ok(0));
2095                    } else {
2096                        return Poll::Pending;
2097                    }
2098                }
2099            };
2100            let mut remaining = value.as_slice();
2101            let mut total_size = 0;
2102            while !remaining.is_empty() && !bufs.is_empty() {
2103                let size = bufs[0].len().min(remaining.len());
2104                bufs[0][..size].copy_from_slice(&remaining[..size]);
2105                remaining = &remaining[size..];
2106                bufs = &mut bufs[1..];
2107                total_size += size;
2108            }
2109
2110            Ok(total_size).into()
2111        }
2112    }
2113
2114    impl VmbusMessageSource for TestMessageSource {
2115        fn pause_message_stream(&mut self) {
2116            self.paused = true;
2117        }
2118
2119        fn resume_message_stream(&mut self) {
2120            self.paused = false;
2121        }
2122    }
2123
2124    fn test_init(driver: &DefaultDriver) -> (TestServer, VmbusClient) {
2125        let (msg_send, msg_recv) = mesh::channel();
2126        let (synic_send, synic_recv) = mesh::channel();
2127        let server = TestServer {
2128            messages: synic_recv,
2129            send: msg_send,
2130        };
2131        let mut client = VmbusClientBuilder::new(
2132            NoopSynicEvents,
2133            TestMessageSource {
2134                msg_recv,
2135                paused: false,
2136            },
2137            TestServerClient {
2138                sender: synic_send,
2139                deadline: None,
2140                timer: PolledTimer::new(driver),
2141            },
2142        )
2143        .build(driver);
2144        client.start();
2145        (server, client)
2146    }
2147
    // On Connect, the first message the client posts must be an InitiateContact2
    // with these properties:
    // - requests Version::Copper;
    // - targets SINT 2 / VTL 0, matching the crate's `SINT` and `VTL` constants;
    // - advertises SUPPORTED_FEATURE_FLAGS;
    // - leaves every other field, including the client id, zeroed.
    #[async_test]
    async fn test_initiate_contact_success(driver: DefaultDriver) {
        let (mut server, client) = test_init(&driver);
        // Bound to `_recv` rather than `_` so the RPC response receiver lives
        // until the end of the test instead of being dropped immediately.
        let _recv = client
            .access
            .client_request_send
            .call(ClientRequest::Connect, ConnectRequest::default());
        check_message(
            server.next().await.unwrap(),
            protocol::InitiateContact2 {
                initiate_contact: protocol::InitiateContact {
                    version_requested: Version::Copper as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                        .into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
                ..FromZeros::new_zeroed()
            },
        );
    }
2173
    // Full handshake with no offers. The server checks the InitiateContact2,
    // accepts it with all SUPPORTED_FEATURE_FLAGS, expects RequestOffers, and
    // sends ALL_OFFERS_DELIVERED. The client must report Copper with all
    // supported feature flags.
    #[async_test]
    async fn test_connect_success(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: Version::Copper as u32,
                        target_message_vp: 0,
                        interrupt_page_or_target_info: TargetInfo::new()
                            .with_sint(2)
                            .with_vtl(0)
                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                            .into(),
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    ..FromZeros::new_zeroed()
                },
            );

            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse2 {
                    version_response: protocol::VersionResponse {
                        version_supported: 1,
                        connection_state: ConnectionState::SUCCESSFUL,
                        padding: 0,
                        selected_version_or_connection_id: 0,
                    },
                    supported_features: SUPPORTED_FEATURE_FLAGS.into_bits(),
                },
            ));

            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        // Both sides must run concurrently: the client waits on server replies.
        let (connection, ()) = (client_connect, server_connect).join().await;
        let connection = connection.unwrap();

        assert_eq!(connection.version.version, Version::Copper);
        assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
    }
2221
    // The negotiated feature flags must be limited to what the server reports.
    // The server replies with supported_features = 2, and the test expects this
    // to yield exactly `channel_interrupt_redirection`.
    #[async_test]
    async fn test_feature_flags(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: Version::Copper as u32,
                        target_message_vp: 0,
                        interrupt_page_or_target_info: TargetInfo::new()
                            .with_sint(2)
                            .with_vtl(0)
                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                            .into(),
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    ..FromZeros::new_zeroed()
                },
            );

            // Report the server doesn't support some of the feature flags, and make
            // sure this is reflected in the returned version.
            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse2 {
                    version_response: protocol::VersionResponse {
                        version_supported: 1,
                        connection_state: ConnectionState::SUCCESSFUL,
                        padding: 0,
                        selected_version_or_connection_id: 0,
                    },
                    supported_features: 2,
                },
            ));

            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        let (connection, ()) = (client_connect, server_connect).join().await;
        let connection = connection.unwrap();

        assert_eq!(connection.version.version, Version::Copper);
        assert_eq!(
            connection.version.feature_flags,
            FeatureFlags::new().with_channel_interrupt_redirection(true)
        );
    }
2274
    // A non-default `ConnectRequest::client_id` must appear in the client_id
    // field of the InitiateContact2 message.
    #[async_test]
    async fn test_client_id(driver: DefaultDriver) {
        let (mut server, client) = test_init(&driver);
        let initiate_contact = ConnectRequest {
            client_id: VMBUS_TEST_CLIENT_ID,
            ..Default::default()
        };
        // Keep the RPC response receiver alive for the whole test.
        let _recv = client
            .access
            .client_request_send
            .call(ClientRequest::Connect, initiate_contact);

        check_message(
            server.next().await.unwrap(),
            protocol::InitiateContact2 {
                initiate_contact: protocol::InitiateContact {
                    version_requested: Version::Copper as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                        .into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
                client_id: VMBUS_TEST_CLIENT_ID,
            },
        );
    }
2305
    // Version fallback:
    // 1. The server rejects the Copper request (version_supported: 0).
    // 2. The client must retry with Version::Iron, using the legacy
    //    InitiateContact message with no feature flags.
    // 3. After the server accepts, the connection must report Iron with empty
    //    feature flags.
    #[async_test]
    async fn test_version_negotiation(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: Version::Copper as u32,
                        target_message_vp: 0,
                        interrupt_page_or_target_info: TargetInfo::new()
                            .with_sint(2)
                            .with_vtl(0)
                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                            .into(),
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    ..FromZeros::new_zeroed()
                },
            );

            // Reject Copper.
            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse {
                    version_supported: 0,
                    connection_state: ConnectionState::SUCCESSFUL,
                    padding: 0,
                    selected_version_or_connection_id: 0,
                },
            ));

            // The retry uses the older message format and no feature flags.
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact {
                    version_requested: Version::Iron as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(FeatureFlags::new().into())
                        .into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
            );

            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse {
                    version_supported: 1,
                    connection_state: ConnectionState::SUCCESSFUL,
                    padding: 0,
                    selected_version_or_connection_id: 0,
                },
            ));

            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        let (connection, ()) = (client_connect, server_connect).join().await;
        let connection = connection.unwrap();

        assert_eq!(connection.version.version, Version::Iron);
        assert_eq!(connection.version.feature_flags, FeatureFlags::new());
    }
2375
    // Opening offered channel 0 with all-zero open data:
    // - the client must post the matching OpenChannel2;
    // - the server completes it with STATUS_SUCCESS;
    // - the Open RPC must then succeed.
    #[async_test]
    async fn test_open_channel_success(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let channel = server.get_channel(&mut client).await;

        let recv = channel.request_send.call(
            ChannelRequest::Open,
            OpenRequest {
                open_data: OpenData {
                    target_vp: 0,
                    ring_offset: 0,
                    ring_gpadl_id: GpadlId(0),
                    event_flag: 0,
                    connection_id: 0,
                    user_data: UserDefinedData::new_zeroed(),
                },
                incoming_event: None,
                use_vtl2_connection_id: false,
            },
        );

        check_message(
            server.next().await.unwrap(),
            protocol::OpenChannel2 {
                open_channel: protocol::OpenChannel {
                    channel_id: ChannelId(0),
                    open_id: 0,
                    ring_buffer_gpadl_id: GpadlId(0),
                    target_vp: 0,
                    downstream_ring_buffer_page_offset: 0,
                    user_data: UserDefinedData::new_zeroed(),
                },
                connection_id: 0,
                event_flag: 0,
                flags: Default::default(),
            },
        );

        server.send(in_msg(
            MessageType::OPEN_CHANNEL_RESULT,
            protocol::OpenResult {
                channel_id: ChannelId(0),
                open_id: 0,
                status: protocol::STATUS_SUCCESS as u32,
            },
        ));

        // Outer unwrap: the RPC completed. Inner unwrap: the open succeeded.
        recv.await.unwrap().unwrap();
    }
2425
    // Same exchange as test_open_channel_success, but the server answers with
    // STATUS_UNSUCCESSFUL, so the Open RPC must complete with an error.
    #[async_test]
    async fn test_open_channel_fail(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let channel = server.get_channel(&mut client).await;

        let recv = channel.request_send.call(
            ChannelRequest::Open,
            OpenRequest {
                open_data: OpenData {
                    target_vp: 0,
                    ring_offset: 0,
                    ring_gpadl_id: GpadlId(0),
                    event_flag: 0,
                    connection_id: 0,
                    user_data: UserDefinedData::new_zeroed(),
                },
                incoming_event: None,
                use_vtl2_connection_id: false,
            },
        );

        check_message(
            server.next().await.unwrap(),
            protocol::OpenChannel2 {
                open_channel: protocol::OpenChannel {
                    channel_id: ChannelId(0),
                    open_id: 0,
                    ring_buffer_gpadl_id: GpadlId(0),
                    target_vp: 0,
                    downstream_ring_buffer_page_offset: 0,
                    user_data: UserDefinedData::new_zeroed(),
                },
                connection_id: 0,
                event_flag: 0,
                flags: Default::default(),
            },
        );

        server.send(in_msg(
            MessageType::OPEN_CHANNEL_RESULT,
            protocol::OpenResult {
                channel_id: ChannelId(0),
                open_id: 0,
                status: protocol::STATUS_UNSUCCESSFUL as u32,
            },
        ));

        // The RPC itself completes; the open result it carries is an error.
        recv.await.unwrap().unwrap_err();
    }
2475
    // A Modify(TargetVp) request must be posted as ModifyChannel, and the status
    // from the server's ModifyChannelResponse is returned to the caller.
    #[async_test]
    async fn test_modify_channel(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let channel = server.get_channel(&mut client).await;

        // N.B. A real server requires the channel to be open before sending this, but the test
        //      server doesn't care.
        let recv = channel.request_send.call(
            ChannelRequest::Modify,
            ModifyRequest::TargetVp { target_vp: 1 },
        );

        check_message(
            server.next().await.unwrap(),
            protocol::ModifyChannel {
                channel_id: ChannelId(0),
                target_vp: 1,
            },
        );

        server.send(in_msg(
            MessageType::MODIFY_CHANNEL_RESPONSE,
            protocol::ModifyChannelResponse {
                channel_id: ChannelId(0),
                status: protocol::STATUS_SUCCESS,
            },
        ));

        let status = recv.await.unwrap();
        assert_eq!(status, protocol::STATUS_SUCCESS);
    }
2507
2508    #[async_test]
2509    async fn test_save_restore_connected(driver: DefaultDriver) {
2510        let (mut server, mut client) = test_init(&driver);
2511        server.connect(&mut client).await;
2512        server.stop_client(&mut client).await;
2513        let s0 = client.save().await;
2514        let builder = client.sever().await;
2515        let mut client = builder.build(&driver);
2516        client.restore(s0.clone()).await.unwrap();
2517
2518        let s1 = client.save().await;
2519
2520        assert_eq!(s0, s1);
2521    }
2522
2523    #[async_test]
2524    async fn test_save_restore_connected_with_channel(driver: DefaultDriver) {
2525        let (mut server, mut client) = test_init(&driver);
2526        let c0 = server.get_channel(&mut client).await;
2527        server.stop_client(&mut client).await;
2528        let s0 = client.save().await;
2529        let builder = client.sever().await;
2530        let mut client = builder.build(&driver);
2531        let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2532        let s1 = client.save().await;
2533        assert_eq!(s0, s1);
2534        assert_eq!(connection.offers[0].offer, c0.offer);
2535    }
2536
    #[async_test]
    async fn test_save_restore_connected_with_revoked_channel(driver: DefaultDriver) {
        // Verifies that a channel rescinded before save is not reported as an offer by
        // restore, and that the restored client releases its rel id once started.
        let (mut server, mut client) = test_init(&driver);
        let c0 = server.get_channel(&mut client).await;
        // Rescind channel 0 and wait for the client to signal the revocation.
        server.send(in_msg(
            MessageType::RESCIND_CHANNEL_OFFER,
            protocol::RescindChannelOffer {
                channel_id: ChannelId(0),
            },
        ));
        c0.revoke_recv.await.unwrap();
        // A request on the revoked channel is still forwarded to the server (checked below);
        // its response is only delivered while the client is being stopped.
        let rpc = c0.request_send.call(
            ChannelRequest::Modify,
            ModifyRequest::TargetVp { target_vp: 1 },
        );

        check_message(
            server.next().await.unwrap(),
            protocol::ModifyChannel {
                channel_id: ChannelId(0),
                target_vp: 1,
            },
        );

        // Stop the client concurrently with the server side: the server first answers the
        // outstanding modify, then expects the client's Pause and acknowledges it.
        let client_stop = client.stop();
        let server_stop = async {
            server.send(in_msg(
                MessageType::MODIFY_CHANNEL_RESPONSE,
                protocol::ModifyChannelResponse {
                    channel_id: ChannelId(0),
                    status: protocol::STATUS_SUCCESS,
                },
            ));
            check_message(server.next().await.unwrap(), protocol::Pause);
            server.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
        };
        (client_stop, server_stop).join().await;

        // The modify response sent above completes the outstanding request.
        rpc.await.unwrap();

        // Round-trip through save/sever/restore; saved state must match and the revoked
        // channel must not appear among the restored offers.
        let s0 = client.save().await;
        let builder = client.sever().await;
        let mut client = builder.build(&driver);
        let connection = client.restore(s0.clone()).await.unwrap().unwrap();
        let s1 = client.save().await;
        assert_eq!(s0, s1);
        assert!(connection.offers.is_empty());
        // After the restored client is started, it releases the revoked channel's rel id.
        server.start_client(&mut client).await;
        check_message(
            server.next().await.unwrap(),
            protocol::RelIdReleased {
                channel_id: ChannelId(0),
            },
        );
    }
2592
2593    #[async_test]
2594    async fn test_connect_fails_on_incorrect_state(driver: DefaultDriver) {
2595        let (mut server, mut client) = test_init(&driver);
2596        server.connect(&mut client).await;
2597        let err = client.connect(0, None, Guid::ZERO).await.unwrap_err();
2598        assert!(matches!(err, ConnectError::InvalidState), "{:?}", err);
2599    }
2600
2601    #[async_test]
2602    async fn test_hot_add_remove(driver: DefaultDriver) {
2603        let (mut server, mut client) = test_init(&driver);
2604
2605        let mut connection = server.connect(&mut client).await;
2606        let offer = protocol::OfferChannel {
2607            interface_id: Guid::new_random(),
2608            instance_id: Guid::new_random(),
2609            rsvd: [0; 4],
2610            flags: OfferFlags::new(),
2611            mmio_megabytes: 0,
2612            user_defined: UserDefinedData::new_zeroed(),
2613            subchannel_index: 0,
2614            mmio_megabytes_optional: 0,
2615            channel_id: ChannelId(5),
2616            monitor_id: 0,
2617            monitor_allocated: 0,
2618            is_dedicated: 0,
2619            connection_id: 0,
2620        };
2621
2622        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2623        let info = connection.offer_recv.next().await.unwrap();
2624
2625        assert_eq!(offer, info.offer);
2626
2627        server.send(in_msg(
2628            MessageType::RESCIND_CHANNEL_OFFER,
2629            protocol::RescindChannelOffer {
2630                channel_id: ChannelId(5),
2631            },
2632        ));
2633
2634        info.revoke_recv.await.unwrap();
2635        drop(info.request_send);
2636
2637        check_message(
2638            server.next().await.unwrap(),
2639            protocol::RelIdReleased {
2640                channel_id: ChannelId(5),
2641            },
2642        );
2643    }
2644
2645    #[async_test]
2646    async fn test_gpadl_success(driver: DefaultDriver) {
2647        let (mut server, mut client) = test_init(&driver);
2648        let channel = server.get_channel(&mut client).await;
2649        let recv = channel.request_send.call(
2650            ChannelRequest::Gpadl,
2651            GpadlRequest {
2652                id: GpadlId(1),
2653                count: 1,
2654                buf: vec![5],
2655            },
2656        );
2657
2658        check_message_with_data(
2659            server.next().await.unwrap(),
2660            protocol::GpadlHeader {
2661                channel_id: ChannelId(0),
2662                gpadl_id: GpadlId(1),
2663                len: 8,
2664                count: 1,
2665            },
2666            0x5u64.as_bytes(),
2667        );
2668
2669        server.send(in_msg(
2670            MessageType::GPADL_CREATED,
2671            protocol::GpadlCreated {
2672                channel_id: ChannelId(0),
2673                gpadl_id: GpadlId(1),
2674                status: protocol::STATUS_SUCCESS,
2675            },
2676        ));
2677
2678        recv.await.unwrap().unwrap();
2679
2680        let rpc = channel
2681            .request_send
2682            .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2683
2684        check_message(
2685            server.next().await.unwrap(),
2686            protocol::GpadlTeardown {
2687                channel_id: ChannelId(0),
2688                gpadl_id: GpadlId(1),
2689            },
2690        );
2691
2692        server.send(in_msg(
2693            MessageType::GPADL_TORNDOWN,
2694            protocol::GpadlTorndown {
2695                gpadl_id: GpadlId(1),
2696            },
2697        ));
2698
2699        rpc.await.unwrap();
2700    }
2701
2702    #[async_test]
2703    async fn test_gpadl_fail(driver: DefaultDriver) {
2704        let (mut server, mut client) = test_init(&driver);
2705        let channel = server.get_channel(&mut client).await;
2706        let recv = channel.request_send.call(
2707            ChannelRequest::Gpadl,
2708            GpadlRequest {
2709                id: GpadlId(1),
2710                count: 1,
2711                buf: vec![7],
2712            },
2713        );
2714
2715        check_message_with_data(
2716            server.next().await.unwrap(),
2717            protocol::GpadlHeader {
2718                channel_id: ChannelId(0),
2719                gpadl_id: GpadlId(1),
2720                len: 8,
2721                count: 1,
2722            },
2723            0x7u64.as_bytes(),
2724        );
2725
2726        server.send(in_msg(
2727            MessageType::GPADL_CREATED,
2728            protocol::GpadlCreated {
2729                channel_id: ChannelId(0),
2730                gpadl_id: GpadlId(1),
2731                status: protocol::STATUS_UNSUCCESSFUL,
2732            },
2733        ));
2734
2735        recv.await.unwrap().unwrap_err();
2736    }
2737
    #[async_test]
    async fn test_gpadl_with_revoke(driver: DefaultDriver) {
        // Exercises GPADL create/teardown interleaved with a channel rescind: a teardown that
        // is in flight when the rescind arrives, a create issued after the rescind was sent,
        // a teardown issued after the revocation is observed, and the final rel id release.
        let (mut server, mut client) = test_init(&driver);
        let channel = server.get_channel(&mut client).await;
        let channel_id = ChannelId(0);
        // Create GPADLs 1, 2 and 3, each accepted by the server.
        for gpadl_id in [1, 2, 3].map(GpadlId) {
            let recv = channel.request_send.call(
                ChannelRequest::Gpadl,
                GpadlRequest {
                    id: gpadl_id,
                    count: 1,
                    buf: vec![3],
                },
            );

            check_message_with_data(
                server.next().await.unwrap(),
                protocol::GpadlHeader {
                    channel_id,
                    gpadl_id,
                    len: 8,
                    count: 1,
                },
                0x3u64.as_bytes(),
            );

            server.send(in_msg(
                MessageType::GPADL_CREATED,
                protocol::GpadlCreated {
                    channel_id,
                    gpadl_id,
                    status: protocol::STATUS_SUCCESS,
                },
            ));

            recv.await.unwrap().unwrap();
        }

        // Start tearing down GPADL 1; the server does not answer it until later.
        let rpc = channel
            .request_send
            .call(ChannelRequest::TeardownGpadl, GpadlId(1));

        check_message(
            server.next().await.unwrap(),
            protocol::GpadlTeardown {
                channel_id,
                gpadl_id: GpadlId(1),
            },
        );

        // Rescind the channel while the GPADL 1 teardown is still outstanding.
        server.send(in_msg(
            MessageType::RESCIND_CHANNEL_OFFER,
            protocol::RescindChannelOffer { channel_id },
        ));

        // A create issued after the rescind was sent is still forwarded to the server
        // (verified below), which then fails it.
        let recv = channel.request_send.call_failable(
            ChannelRequest::Gpadl,
            GpadlRequest {
                id: GpadlId(4),
                count: 1,
                buf: vec![3],
            },
        );

        check_message_with_data(
            server.next().await.unwrap(),
            protocol::GpadlHeader {
                channel_id,
                gpadl_id: GpadlId(4),
                len: 8,
                count: 1,
            },
            0x3u64.as_bytes(),
        );

        server.send(in_msg(
            MessageType::GPADL_CREATED,
            protocol::GpadlCreated {
                channel_id,
                gpadl_id: GpadlId(4),
                status: protocol::STATUS_UNSUCCESSFUL,
            },
        ));

        // Now complete the outstanding teardown of GPADL 1.
        server.send(in_msg(
            MessageType::GPADL_TORNDOWN,
            protocol::GpadlTorndown {
                gpadl_id: GpadlId(1),
            },
        ));

        // The GPADL 1 teardown succeeds; the rejected GPADL 4 create reports an error.
        rpc.await.unwrap();
        recv.await.unwrap_err();

        channel.revoke_recv.await.unwrap();

        // A teardown issued after the revocation is observed is still sent to the server.
        // The request sender is dropped right away, while this teardown is still pending.
        let rpc = channel
            .request_send
            .call(ChannelRequest::TeardownGpadl, GpadlId(2));
        drop(channel.request_send);

        check_message(
            server.next().await.unwrap(),
            protocol::GpadlTeardown {
                channel_id,
                gpadl_id: GpadlId(2),
            },
        );

        server.send(in_msg(
            MessageType::GPADL_TORNDOWN,
            protocol::GpadlTorndown {
                gpadl_id: GpadlId(2),
            },
        ));

        rpc.await.unwrap();

        // After the teardown completes and the sender has been dropped, the client
        // releases the channel's rel id.
        check_message(
            server.next().await.unwrap(),
            protocol::RelIdReleased { channel_id },
        );
    }
2861
2862    #[async_test]
2863    async fn test_modify_connection(driver: DefaultDriver) {
2864        let (mut server, mut client) = test_init(&driver);
2865        server.connect(&mut client).await;
2866        let call = client.access.client_request_send.call(
2867            ClientRequest::Modify,
2868            ModifyConnectionRequest {
2869                monitor_page: Some(MonitorPageGpas {
2870                    child_to_parent: 5,
2871                    parent_to_child: 6,
2872                }),
2873            },
2874        );
2875
2876        check_message(
2877            server.next().await.unwrap(),
2878            protocol::ModifyConnection {
2879                child_to_parent_monitor_page_gpa: 5,
2880                parent_to_child_monitor_page_gpa: 6,
2881            },
2882        );
2883
2884        server.send(in_msg(
2885            MessageType::MODIFY_CONNECTION_RESPONSE,
2886            protocol::ModifyConnectionResponse {
2887                connection_state: ConnectionState::FAILED_LOW_RESOURCES,
2888            },
2889        ));
2890
2891        let result = call.await.unwrap();
2892        assert_eq!(ConnectionState::FAILED_LOW_RESOURCES, result);
2893    }
2894
2895    #[async_test]
2896    async fn test_hvsock(driver: DefaultDriver) {
2897        let (mut server, mut client) = test_init(&driver);
2898        server.connect(&mut client).await;
2899        let request = HvsockConnectRequest {
2900            service_id: Guid::new_random(),
2901            endpoint_id: Guid::new_random(),
2902            silo_id: Guid::new_random(),
2903            hosted_silo_unaware: false,
2904        };
2905
2906        let resp = client.access().connect_hvsock(request);
2907        check_message(
2908            server.next().await.unwrap(),
2909            protocol::TlConnectRequest2 {
2910                base: protocol::TlConnectRequest {
2911                    service_id: request.service_id,
2912                    endpoint_id: request.endpoint_id,
2913                },
2914                silo_id: request.silo_id,
2915            },
2916        );
2917
2918        // Now send a failure result.
2919        server.send(in_msg(
2920            MessageType::TL_CONNECT_REQUEST_RESULT,
2921            protocol::TlConnectResult {
2922                service_id: request.service_id,
2923                endpoint_id: request.endpoint_id,
2924                status: protocol::STATUS_CONNECTION_REFUSED,
2925            },
2926        ));
2927
2928        let result = resp.await;
2929        assert!(result.is_none());
2930    }
2931
2932    #[async_test]
2933    async fn test_synic_event_flags(driver: DefaultDriver) {
2934        let (mut server, mut client) = test_init(&driver);
2935        let connection = server.get_channels(&mut client, 5).await;
2936        let event = Event::new();
2937
2938        for _ in 0..5 {
2939            for (i, channel) in connection.offers.iter().enumerate() {
2940                let recv = channel.request_send.call(
2941                    ChannelRequest::Open,
2942                    OpenRequest {
2943                        open_data: OpenData {
2944                            target_vp: 0,
2945                            ring_offset: 0,
2946                            ring_gpadl_id: GpadlId(0),
2947                            event_flag: 0,
2948                            connection_id: 0,
2949                            user_data: UserDefinedData::new_zeroed(),
2950                        },
2951                        incoming_event: Some(event.clone()),
2952                        use_vtl2_connection_id: false,
2953                    },
2954                );
2955
2956                let expected_event_flag = i as u16 + 1;
2957
2958                check_message(
2959                    server.next().await.unwrap(),
2960                    protocol::OpenChannel2 {
2961                        open_channel: protocol::OpenChannel {
2962                            channel_id: channel.offer.channel_id,
2963                            open_id: 0,
2964                            ring_buffer_gpadl_id: GpadlId(0),
2965                            target_vp: 0,
2966                            downstream_ring_buffer_page_offset: 0,
2967                            user_data: UserDefinedData::new_zeroed(),
2968                        },
2969                        connection_id: 0,
2970                        event_flag: expected_event_flag,
2971                        flags: OpenChannelFlags::new().with_redirect_interrupt(true),
2972                    },
2973                );
2974
2975                server.send(in_msg(
2976                    MessageType::OPEN_CHANNEL_RESULT,
2977                    protocol::OpenResult {
2978                        channel_id: channel.offer.channel_id,
2979                        open_id: 0,
2980                        status: protocol::STATUS_SUCCESS as u32,
2981                    },
2982                ));
2983
2984                let output = recv.await.unwrap().unwrap();
2985                assert_eq!(output.redirected_event_flag, Some(expected_event_flag));
2986            }
2987
2988            for (i, channel) in connection.offers.iter().enumerate() {
2989                // Close the channel to prepare for the next iteration of the loop.
2990                // The event flag should be the same each time.
2991                channel
2992                    .request_send
2993                    .call(ChannelRequest::Close, ())
2994                    .await
2995                    .unwrap();
2996
2997                check_message(
2998                    server.next().await.unwrap(),
2999                    protocol::CloseChannel {
3000                        channel_id: ChannelId(i as u32),
3001                    },
3002                );
3003            }
3004        }
3005    }
3006
3007    #[async_test]
3008    async fn test_revoke(driver: DefaultDriver) {
3009        let (mut server, mut client) = test_init(&driver);
3010        let channel = server.get_channel(&mut client).await;
3011
3012        server.send(in_msg(
3013            MessageType::RESCIND_CHANNEL_OFFER,
3014            protocol::RescindChannelOffer {
3015                channel_id: ChannelId(0),
3016            },
3017        ));
3018
3019        channel.revoke_recv.await.unwrap();
3020
3021        channel
3022            .request_send
3023            .call_failable(
3024                ChannelRequest::Open,
3025                OpenRequest {
3026                    open_data: OpenData {
3027                        target_vp: 0,
3028                        ring_offset: 0,
3029                        ring_gpadl_id: GpadlId(0),
3030                        event_flag: 0,
3031                        connection_id: 0,
3032                        user_data: UserDefinedData::new_zeroed(),
3033                    },
3034                    incoming_event: None,
3035                    use_vtl2_connection_id: false,
3036                },
3037            )
3038            .await
3039            .unwrap_err();
3040    }
3041
3042    #[async_test]
3043    #[should_panic(expected = "channel should not exist")]
3044    async fn test_reoffer_in_use_rel_id(driver: DefaultDriver) {
3045        let (mut server, mut client) = test_init(&driver);
3046        let mut connection = server.get_channels(&mut client, 1).await;
3047        let [channel] = connection.offers.try_into().unwrap();
3048
3049        server.send(in_msg(
3050            MessageType::RESCIND_CHANNEL_OFFER,
3051            protocol::RescindChannelOffer {
3052                channel_id: ChannelId(0),
3053            },
3054        ));
3055
3056        channel.revoke_recv.await.unwrap();
3057
3058        // This offer will cause a panic since the rel id is still in use.
3059        let offer = protocol::OfferChannel {
3060            interface_id: Guid::new_random(),
3061            instance_id: Guid::new_random(),
3062            rsvd: [0; 4],
3063            flags: OfferFlags::new(),
3064            mmio_megabytes: 0,
3065            user_defined: UserDefinedData::new_zeroed(),
3066            subchannel_index: 0,
3067            mmio_megabytes_optional: 0,
3068            channel_id: ChannelId(0),
3069            monitor_id: 0,
3070            monitor_allocated: 0,
3071            is_dedicated: 0,
3072            connection_id: 0,
3073        };
3074
3075        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3076
3077        connection.offer_recv.next().await;
3078    }
3079
3080    #[async_test]
3081    async fn test_revoke_release_and_reoffer(driver: DefaultDriver) {
3082        let (mut server, mut client) = test_init(&driver);
3083        let mut connection = server.get_channels(&mut client, 1).await;
3084        let [channel] = connection.offers.try_into().unwrap();
3085
3086        server.send(in_msg(
3087            MessageType::RESCIND_CHANNEL_OFFER,
3088            protocol::RescindChannelOffer {
3089                channel_id: ChannelId(0),
3090            },
3091        ));
3092
3093        channel.revoke_recv.await.unwrap();
3094        drop(channel.request_send);
3095
3096        check_message(
3097            server.next().await.unwrap(),
3098            protocol::RelIdReleased {
3099                channel_id: ChannelId(0),
3100            },
3101        );
3102
3103        let offer = protocol::OfferChannel {
3104            interface_id: Guid::new_random(),
3105            instance_id: Guid::new_random(),
3106            rsvd: [0; 4],
3107            flags: OfferFlags::new(),
3108            mmio_megabytes: 0,
3109            user_defined: UserDefinedData::new_zeroed(),
3110            subchannel_index: 0,
3111            mmio_megabytes_optional: 0,
3112            channel_id: ChannelId(0),
3113            monitor_id: 0,
3114            monitor_allocated: 0,
3115            is_dedicated: 0,
3116            connection_id: 0,
3117        };
3118
3119        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3120
3121        connection.offer_recv.next().await.unwrap();
3122    }
3123
3124    #[async_test]
3125    async fn test_release_revoke_and_reoffer(driver: DefaultDriver) {
3126        let (mut server, mut client) = test_init(&driver);
3127        let mut connection = server.get_channels(&mut client, 1).await;
3128        let [channel] = connection.offers.try_into().unwrap();
3129
3130        let open = channel.request_send.call_failable(
3131            ChannelRequest::Open,
3132            OpenRequest {
3133                open_data: OpenData {
3134                    target_vp: 0,
3135                    ring_offset: 0,
3136                    ring_gpadl_id: GpadlId(0),
3137                    event_flag: 0,
3138                    connection_id: 0,
3139                    user_data: UserDefinedData::new_zeroed(),
3140                },
3141                incoming_event: None,
3142                use_vtl2_connection_id: false,
3143            },
3144        );
3145
3146        let server_open = async {
3147            check_message(
3148                server.next().await.unwrap(),
3149                protocol::OpenChannel2 {
3150                    open_channel: protocol::OpenChannel {
3151                        channel_id: ChannelId(0),
3152                        open_id: 0,
3153                        ring_buffer_gpadl_id: GpadlId(0),
3154                        target_vp: 0,
3155                        downstream_ring_buffer_page_offset: 0,
3156                        user_data: UserDefinedData::new_zeroed(),
3157                    },
3158                    connection_id: 0,
3159                    event_flag: 0,
3160                    flags: Default::default(),
3161                },
3162            );
3163            server.send(in_msg(
3164                MessageType::OPEN_CHANNEL_RESULT,
3165                protocol::OpenResult {
3166                    channel_id: ChannelId(0),
3167                    open_id: 0,
3168                    status: protocol::STATUS_SUCCESS as u32,
3169                },
3170            ));
3171        };
3172
3173        (open, server_open).join().await.0.unwrap();
3174
3175        // This will close the channel but won't release it yet.
3176        drop(channel);
3177
3178        check_message(
3179            server.next().await.unwrap(),
3180            protocol::CloseChannel {
3181                channel_id: ChannelId(0),
3182            },
3183        );
3184
3185        server.send(in_msg(
3186            MessageType::RESCIND_CHANNEL_OFFER,
3187            protocol::RescindChannelOffer {
3188                channel_id: ChannelId(0),
3189            },
3190        ));
3191
3192        // Should be released.
3193        check_message(
3194            server.next().await.unwrap(),
3195            protocol::RelIdReleased {
3196                channel_id: ChannelId(0),
3197            },
3198        );
3199
3200        let offer = protocol::OfferChannel {
3201            interface_id: Guid::new_random(),
3202            instance_id: Guid::new_random(),
3203            rsvd: [0; 4],
3204            flags: OfferFlags::new(),
3205            mmio_megabytes: 0,
3206            user_defined: UserDefinedData::new_zeroed(),
3207            subchannel_index: 0,
3208            mmio_megabytes_optional: 0,
3209            channel_id: ChannelId(0),
3210            monitor_id: 0,
3211            monitor_allocated: 0,
3212            is_dedicated: 0,
3213            connection_id: 0,
3214        };
3215
3216        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3217
3218        // New offer should come through.
3219        connection.offer_recv.next().await.unwrap();
3220    }
3221}