vmbus_client/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Client driver for the Hyper-V Virtual Machine Bus (VmBus).
5
6#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9pub mod driver;
10pub mod filter;
11mod hvsock;
12pub mod saved_state;
13
14pub use self::saved_state::SavedState;
15use anyhow::Context as _;
16use anyhow::Result;
17use futures::FutureExt;
18use futures::StreamExt;
19use futures::future::OptionFuture;
20use futures::stream::SelectAll;
21use futures_concurrency::future::Race;
22use guid::Guid;
23use inspect::Inspect;
24use mesh::rpc::FailableRpc;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use pal_async::task::Spawn;
28use pal_async::task::Task;
29use pal_event::Event;
30use std::collections::HashMap;
31use std::collections::VecDeque;
32use std::collections::hash_map;
33use std::convert::TryInto;
34use std::future::Future;
35use std::future::poll_fn;
36use std::ops::Deref;
37use std::ops::DerefMut;
38use std::pin::pin;
39use std::sync::Arc;
40use std::sync::atomic::AtomicU32;
41use std::sync::atomic::Ordering;
42use std::task::Context;
43use std::task::Poll;
44use thiserror::Error;
45use vmbus_async::async_dgram::AsyncRecv;
46use vmbus_async::async_dgram::AsyncRecvExt;
47use vmbus_channel::bus::GpadlRequest;
48use vmbus_channel::bus::ModifyRequest;
49use vmbus_channel::bus::OfferKey;
50use vmbus_channel::bus::OpenData;
51use vmbus_channel::gpadl::GpadlId;
52use vmbus_core::HvsockConnectRequest;
53use vmbus_core::OutgoingMessage;
54use vmbus_core::TaggedStream;
55use vmbus_core::VersionInfo;
56use vmbus_core::protocol;
57use vmbus_core::protocol::ChannelId;
58use vmbus_core::protocol::ConnectionState;
59use vmbus_core::protocol::FeatureFlags;
60use vmbus_core::protocol::Message;
61use vmbus_core::protocol::OpenChannelFlags;
62use vmbus_core::protocol::Version;
63use vmcore::interrupt::Interrupt;
64use vmcore::synic::MonitorPageGpas;
65use zerocopy::Immutable;
66use zerocopy::IntoBytes;
67use zerocopy::KnownLayout;
68
/// SINT index written into the `TargetInfo` of the InitiateContact message.
const SINT: u8 = 2;
/// VTL written into the `TargetInfo` of the InitiateContact message.
const VTL: u8 = 0;
/// Protocol versions this client can negotiate, ordered from least to most
/// preferred. Connection attempts begin with the last entry and step backward
/// through the list each time the server rejects a version.
const SUPPORTED_VERSIONS: &[Version] = &[Version::Iron, Version::Copper];
/// Feature flags requested when negotiating `Version::Copper` or newer; older
/// versions are negotiated with an empty flag set.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true);
78
/// The client interface for synic events.
///
/// The implementation passed to [`VmbusClientBuilder::new`] is shared with the
/// client task behind an `Arc`.
pub trait SynicEventClient: Send + Sync {
    /// Maps an incoming event signal on SINT7 to `event`.
    // NOTE(review): the message SINT configured in this file is 2; confirm
    // that events really are delivered on a different SINT than messages.
    fn map_event(&self, event_flag: u16, event: &Event) -> std::io::Result<()>;

    /// Unmaps an event previously mapped with `map_event`.
    fn unmap_event(&self, event_flag: u16);

    /// Signals an event on the synic.
    fn signal_event(&self, connection_id: u32, event_flag: u16) -> std::io::Result<()>;
}
90
/// A stream of vmbus messages that can be paused and resumed.
///
/// Both methods have default implementations that do nothing, so a source
/// that has no notion of pausing only needs to implement `AsyncRecv`.
pub trait VmbusMessageSource: AsyncRecv + Send {
    /// Stop accepting new messages from the synic. After this is called, the message source must
    /// return any pending messages already in the queue, and then return EOF.
    fn pause_message_stream(&mut self) {}

    /// Resume accepting new messages from the synic.
    fn resume_message_stream(&mut self) {}
}
100
/// Poll-based sink for outgoing messages. The implementation supplied to
/// [`VmbusClientBuilder::new`] becomes the client task's message poster.
pub trait PollPostMessage: Send {
    /// Polls an attempt to post `msg`, tagged with message type `typ`, to the
    /// connection identified by `connection_id`.
    // NOTE(review): presumably returns `Poll::Pending` (after registering the
    // waker from `cx`) until the message can be posted -- confirm against the
    // implementors.
    fn poll_post_message(
        &mut self,
        cx: &mut Context<'_>,
        connection_id: u32,
        typ: u32,
        msg: &[u8],
    ) -> Poll<()>;
}
110
/// Handle to a spawned VmBus client task, as returned by
/// [`VmbusClientBuilder::build`].
#[derive(Inspect)]
pub struct VmbusClient {
    /// Sends lifecycle requests (start/stop/save/restore) and inspect
    /// requests to the client task.
    #[inspect(flatten, send = "TaskRequest::Inspect")]
    task_send: mesh::Sender<TaskRequest>,
    /// Cloneable handle used to send connection-level requests to the task.
    #[inspect(skip)]
    access: VmbusClientAccess,
    /// The spawned task; it yields the `ClientTask` back when it finishes.
    #[inspect(skip)]
    task: Task<ClientTask>,
}
120
/// Errors returned by [`VmbusClient::connect`].
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// A connect was requested while the client was not in the disconnected
    /// state.
    #[error("invalid state to connect to the server")]
    InvalidState,
    /// The server rejected every version in the supported version list.
    #[error("no supported protocol versions")]
    NoSupportedVersions,
    /// The server accepted the version but reported a connection state other
    /// than success; the reported state is attached.
    #[error("failed to connect to the server: {0:?}")]
    FailedToConnect(ConnectionState),
}
130
/// A cloneable handle for sending connection-level requests (modify
/// connection, hvsock connect) to the client task.
#[derive(Clone)]
pub struct VmbusClientAccess {
    /// Sender for requests handled by the client task.
    client_request_send: mesh::Sender<ClientRequest>,
}
135
/// A builder for creating a [`VmbusClient`].
///
/// It holds the three synic-facing resources the client task needs; the same
/// type is produced again when a client is torn down and those resources are
/// recovered from the finished task.
pub struct VmbusClientBuilder {
    /// Used to map, unmap and signal synic events.
    event_client: Arc<dyn SynicEventClient>,
    /// Source of incoming vmbus messages.
    msg_source: Box<dyn VmbusMessageSource>,
    /// Sink used to post outgoing vmbus messages.
    msg_client: Box<dyn PollPostMessage>,
}
142
143impl VmbusClientBuilder {
144    /// Creates a new instance of the builder with the given synic input.
145    pub fn new(
146        event_client: impl SynicEventClient + 'static,
147        msg_source: impl VmbusMessageSource + 'static,
148        msg_client: impl PollPostMessage + 'static,
149    ) -> Self {
150        Self {
151            event_client: Arc::new(event_client),
152            msg_source: Box::new(msg_source),
153            msg_client: Box::new(msg_client),
154        }
155    }
156
157    /// Creates a new instance with a receiver for incoming synic messages.
158    pub fn build(self, spawner: &impl Spawn) -> VmbusClient {
159        let (task_send, task_recv) = mesh::channel();
160        let (client_request_send, client_request_recv) = mesh::channel();
161
162        let inner = ClientTaskInner {
163            messages: OutgoingMessages {
164                poster: self.msg_client,
165                queued: VecDeque::new(),
166                state: OutgoingMessageState::Paused,
167            },
168            teardown_gpadls: HashMap::new(),
169            channel_requests: SelectAll::new(),
170            synic: SynicState {
171                event_flag_state: Vec::new(),
172                event_client: self.event_client,
173            },
174        };
175
176        let mut task = ClientTask {
177            inner,
178            channels: ChannelList::default(),
179            task_recv,
180            running: false,
181            msg_source: self.msg_source,
182            client_request_recv,
183            state: ClientState::Disconnected,
184            modify_request: None,
185            hvsock_tracker: hvsock::HvsockRequestTracker::new(),
186        };
187
188        let task = spawner.spawn("vmbus client", async move {
189            task.run().await;
190            task
191        });
192
193        VmbusClient {
194            access: VmbusClientAccess {
195                client_request_send,
196            },
197            task_send,
198            task,
199        }
200    }
201}
202
203impl VmbusClient {
204    /// Connects to the server, negotiating the protocol version and retrieving
205    /// the initial list of channel offers.
206    pub async fn connect(
207        &mut self,
208        target_message_vp: u32,
209        monitor_page: Option<MonitorPageGpas>,
210        client_id: Guid,
211    ) -> Result<ConnectResult, ConnectError> {
212        let request = ConnectRequest {
213            target_message_vp,
214            monitor_page,
215            client_id,
216        };
217
218        self.access
219            .client_request_send
220            .call(ClientRequest::Connect, request)
221            .await
222            .unwrap()
223    }
224
225    pub async fn unload(self) {
226        self.access
227            .client_request_send
228            .call(ClientRequest::Unload, ())
229            .await
230            .unwrap();
231
232        self.sever().await;
233    }
234
235    pub fn access(&self) -> &VmbusClientAccess {
236        &self.access
237    }
238
239    pub fn start(&mut self) {
240        self.task_send.send(TaskRequest::Start);
241    }
242
243    pub async fn stop(&mut self) {
244        self.task_send
245            .call(TaskRequest::Stop, ())
246            .await
247            .expect("Failed to send stop request");
248    }
249
250    pub async fn save(&self) -> SavedState {
251        self.task_send
252            .call(TaskRequest::Save, ())
253            .await
254            .expect("Failed to send save request")
255    }
256
257    pub async fn restore(
258        &mut self,
259        state: SavedState,
260    ) -> Result<Option<ConnectResult>, RestoreError> {
261        self.task_send
262            .call(TaskRequest::Restore, state)
263            .await
264            .expect("Failed to send restore request")
265    }
266
267    pub async fn post_restore(&mut self) {
268        self.task_send
269            .call(TaskRequest::PostRestore, ())
270            .await
271            .expect("Failed to send post-restore request");
272    }
273
274    async fn sever(self) -> VmbusClientBuilder {
275        drop(self.task_send);
276        let task = self.task.await;
277        VmbusClientBuilder {
278            event_client: task.inner.synic.event_client,
279            msg_source: task.msg_source,
280            msg_client: task.inner.messages.poster,
281        }
282    }
283}
284
/// The outcome of a successful connection to the server.
#[derive(Debug)]
pub struct ConnectResult {
    /// The negotiated protocol version and feature flags.
    pub version: VersionInfo,
    /// The offers received while the initial offer request was outstanding.
    pub offers: Vec<OfferInfo>,
    /// Receives offers that arrive after the connection is established.
    pub offer_recv: mesh::Receiver<OfferInfo>,
}
291
292impl VmbusClientAccess {
293    pub async fn modify(&self, request: ModifyConnectionRequest) -> ConnectionState {
294        self.client_request_send
295            .call(ClientRequest::Modify, request)
296            .await
297            .expect("Failed to send modify request")
298    }
299
300    pub fn connect_hvsock(
301        &self,
302        request: HvsockConnectRequest,
303    ) -> impl Future<Output = Option<OfferInfo>> + use<> {
304        self.client_request_send
305            .call(ClientRequest::HvsockConnect, request)
306            .map(|r| r.ok().flatten())
307    }
308}
309
/// Input for [`ChannelRequest::Open`].
#[derive(Debug)]
pub struct OpenRequest {
    /// The open parameters for the channel.
    pub open_data: OpenData,
    /// Optional event associated with incoming signals for the channel.
    // NOTE(review): how this event is mapped is decided by the open handler,
    // which is outside this view -- confirm before relying on it.
    pub incoming_event: Option<Event>,
    /// Selects the VTL2 connection ID for the channel.
    // NOTE(review): semantics inferred from the field name only -- confirm.
    pub use_vtl2_connection_id: bool,
}
316
/// Input for [`ChannelRequest::Restore`].
#[derive(Debug)]
pub struct RestoreRequest {
    /// Optional event associated with incoming signals for the channel.
    pub incoming_event: Option<Event>,
    /// The redirected event flag supplied by the caller, if any.
    // FUTURE: move to saved state, don't rely on the caller.
    pub redirected_event_flag: Option<u16>,
    /// The connection ID supplied by the caller.
    // FUTURE: ditto
    pub connection_id: u32,
}
325
/// Expresses an operation requested of the client.
///
/// Requests are sent on the `request_send` sender of the channel's
/// [`OfferInfo`]; each variant carries the RPC used to report the outcome.
pub enum ChannelRequest {
    /// Open request; can fail, and yields an [`OpenOutput`] on success.
    Open(FailableRpc<OpenRequest, OpenOutput>),
    /// Restore request; can fail, and yields an [`OpenOutput`] on success.
    Restore(FailableRpc<RestoreRequest, OpenOutput>),
    /// Close request.
    Close(Rpc<(), ()>),
    /// GPADL request; can fail.
    Gpadl(FailableRpc<GpadlRequest, ()>),
    /// Request to tear down the GPADL with the given ID.
    TeardownGpadl(Rpc<GpadlId, ()>),
    /// Modify request; completes with an `i32` status.
    Modify(Rpc<ModifyRequest, i32>),
}
335
/// The result of a successful [`ChannelRequest::Open`] or
/// [`ChannelRequest::Restore`].
#[derive(Debug)]
pub struct OpenOutput {
    /// The redirected event flag in use for the channel, if any.
    // FUTURE: remove this once it's part of the saved state.
    pub redirected_event_flag: Option<u16>,
}
341
342impl std::fmt::Display for ChannelRequest {
343    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        let s = match self {
345            ChannelRequest::Open(_) => "Open",
346            ChannelRequest::Close(_) => "Close",
347            ChannelRequest::Restore(_) => "Restore",
348            ChannelRequest::Gpadl(_) => "Gpadl",
349            ChannelRequest::TeardownGpadl(_) => "TeardownGpadl",
350            ChannelRequest::Modify(_) => "Modify",
351        };
352        fmt.pad(s)
353    }
354}
355
/// Errors returned by [`VmbusClient::restore`].
#[derive(Debug, Error)]
pub enum RestoreError {
    /// The protocol version (attached) is not supported.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// The feature flags (attached) are not supported.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    /// The attached channel ID was seen more than once.
    #[error("duplicate channel id {0}")]
    DuplicateChannelId(u32),

    /// The attached GPADL ID was seen more than once.
    #[error("duplicate gpadl id {0}")]
    DuplicateGpadlId(u32),

    /// A GPADL referred to the attached channel ID, which is not known.
    #[error("gpadl for unknown channel id {0}")]
    GpadlForUnknownChannelId(u32),

    /// A pending message was rejected as too large.
    #[error("invalid pending message")]
    InvalidPendingMessage(#[source] vmbus_core::MessageTooLarge),

    /// Offering a restored channel failed.
    #[error("failed to offer restored channel")]
    OfferFailed(#[source] anyhow::Error),
}
379
/// Provides the offer details from the server in addition to both a channel
/// to request client actions and a channel to receive server responses.
#[derive(Debug, Inspect)]
pub struct OfferInfo {
    /// The offer as received from the server.
    pub offer: protocol::OfferChannel,
    /// Interrupt for signaling from guest to host, created by the client's
    /// synic state for this channel.
    #[inspect(skip)]
    pub guest_to_host_interrupt: Interrupt,
    /// Sends [`ChannelRequest`]s for this channel to the client task.
    #[inspect(skip)]
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receives a value when the server rescinds the offer.
    #[inspect(skip)]
    pub revoke_recv: mesh::OneshotReceiver<()>,
}
392
/// Connection-level requests handled by the client task.
#[derive(Debug)]
enum ClientRequest {
    /// Start version negotiation with the server.
    Connect(Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>),
    /// Unload from the server.
    Unload(Rpc<(), ()>),
    /// Send a ModifyConnection message; completes with the reported state.
    Modify(Rpc<ModifyConnectionRequest, ConnectionState>),
    /// Send an hvsock connect request; completes with the resulting offer,
    /// if any.
    HvsockConnect(Rpc<HvsockConnectRequest, Option<OfferInfo>>),
}
400
401impl std::fmt::Display for ClientRequest {
402    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403        let s = match self {
404            ClientRequest::Connect(..) => "Connect",
405            ClientRequest::Unload { .. } => "Unload",
406            ClientRequest::Modify(..) => "Modify",
407            ClientRequest::HvsockConnect(..) => "HvsockConnect",
408        };
409        fmt.pad(s)
410    }
411}
412
/// Lifecycle and diagnostic requests sent to the client task by
/// [`VmbusClient`].
enum TaskRequest {
    /// Deferred inspect request for the task's state.
    Inspect(inspect::Deferred),
    /// Produce the task's saved state.
    Save(Rpc<(), SavedState>),
    /// Restore the task from saved state.
    Restore(Rpc<SavedState, Result<Option<ConnectResult>, RestoreError>>),
    /// Notification sent after a restore.
    PostRestore(Rpc<(), ()>),
    /// Start the task; sent without waiting for a reply.
    Start,
    /// Stop the task; the RPC lets the sender wait for acknowledgement.
    Stop(Rpc<(), ()>),
}
421
/// The overall state machine used to drive which actions the client can legally
/// take. This primarily pertains to overall client activity but has a
/// side-effect of limiting whether or not channels can perform actions.
///
/// A successful connection moves through `Disconnected` -> `Connecting` ->
/// `RequestingOffers` -> `Connected`.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ClientState {
    /// The client has yet to connect to the server.
    Disconnected,
    /// The client has initiated contact with the server and is waiting for
    /// the version response.
    Connecting {
        /// The protocol version proposed in the outstanding request.
        version: Version,
        /// The caller's connect request.
        #[inspect(skip)]
        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
    },
    /// The client has negotiated the protocol version with the server and
    /// the initial offers have been delivered.
    Connected {
        /// The negotiated version and feature flags.
        version: VersionInfo,
        /// Forwards offers that arrive while connected.
        #[inspect(skip)]
        offer_send: mesh::Sender<OfferInfo>,
    },
    /// The protocol version has been accepted and the client has requested
    /// offers from the server.
    RequestingOffers {
        /// The negotiated version and feature flags.
        version: VersionInfo,
        /// Completed with the connect result once the offers are delivered.
        #[inspect(skip)]
        rpc: Rpc<(), Result<ConnectResult, ConnectError>>,
        /// Offers received so far.
        #[inspect(skip)]
        offers: Vec<OfferInfo>,
    },
    /// The client has initiated an unload from the server.
    Disconnecting {
        /// The version that was negotiated for the connection being unloaded.
        version: VersionInfo,
        /// The caller's unload request.
        #[inspect(skip)]
        rpc: Rpc<(), ()>,
    },
}
457
458impl ClientState {
459    fn get_version(&self) -> Option<VersionInfo> {
460        match self {
461            ClientState::Connected { version, .. } => Some(*version),
462            ClientState::RequestingOffers { version, .. } => Some(*version),
463            ClientState::Disconnecting { version, .. } => Some(*version),
464            ClientState::Disconnected | ClientState::Connecting { .. } => None,
465        }
466    }
467}
468
469impl std::fmt::Display for ClientState {
470    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471        let s = match self {
472            ClientState::Disconnected => "Disconnected",
473            ClientState::Connecting { .. } => "Connecting",
474            ClientState::Connected { .. } => "Connected",
475            ClientState::RequestingOffers { .. } => "RequestingOffers",
476            ClientState::Disconnecting { .. } => "Disconnecting",
477        };
478        fmt.pad(s)
479    }
480}
481
/// The parameters of a connect request, copied into the InitiateContact
/// message.
#[derive(Copy, Clone, Debug, Default)]
struct ConnectRequest {
    /// Sent as the message's `target_message_vp`.
    target_message_vp: u32,
    /// Monitor page addresses; the default value is sent when `None`.
    monitor_page: Option<MonitorPageGpas>,
    /// Sent as the message's `client_id` for versions that carry one.
    client_id: Guid,
}
488
/// The parameters of a modify-connection request.
#[derive(Copy, Clone, Debug, Default)]
pub struct ModifyConnectionRequest {
    /// Monitor page addresses; the default value is sent when `None`.
    pub monitor_page: Option<MonitorPageGpas>,
}
493
494impl From<ModifyConnectionRequest> for protocol::ModifyConnection {
495    fn from(value: ModifyConnectionRequest) -> Self {
496        let monitor_page = value.monitor_page.unwrap_or_default();
497
498        Self {
499            parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
500            child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
501        }
502    }
503}
504
/// The per-channel state which dictates whether or not a channel can
/// request an Open/Close. As GPADLs can happen outside this loop there is no
/// state tied to GPADL actions.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum ChannelState {
    /// The channel has been offered to the client.
    Offered,
    /// The channel has requested the server to be opened.
    Opening {
        /// The redirected event flag allocated for this open, if any.
        redirected_event_flag: Option<u16>,
        /// The event paired with `redirected_event_flag`, if any.
        #[inspect(skip)]
        redirected_event: Option<Event>,
        /// Completed with the result of the open.
        #[inspect(skip)]
        rpc: FailableRpc<(), OpenOutput>,
    },
    /// The channel has been restored but not claimed.
    Restored,
    /// The channel has been successfully opened.
    Opened {
        /// The redirected event flag in use, if any.
        redirected_event_flag: Option<u16>,
        /// The event paired with `redirected_event_flag`, if any.
        #[inspect(skip)]
        redirected_event: Option<Event>,
    },
    /// The channel has been revoked by the server.
    Revoked,
}
532
533impl std::fmt::Display for ChannelState {
534    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535        let s = match self {
536            ChannelState::Opening { .. } => "Opening",
537            ChannelState::Offered => "Offered",
538            ChannelState::Opened { .. } => "Opened",
539            ChannelState::Restored => "Restored",
540            ChannelState::Revoked => "Revoked",
541        };
542        fmt.pad(s)
543    }
544}
545
/// The client task's bookkeeping for a single offered channel.
#[derive(Debug, Inspect)]
struct Channel {
    /// The offer that created this channel.
    offer: protocol::OfferChannel,
    // Taken and signaled when the server rescinds the offer, which notifies
    // the holder of the matching receiver that the channel has been revoked.
    #[inspect(skip)]
    revoke_send: Option<mesh::OneshotSender<()>>,
    /// Where the channel is in its open/close lifecycle.
    state: ChannelState,
    /// The outstanding modify request, if any.
    #[inspect(with = "|x| x.is_some()")]
    modify_response_send: Option<Rpc<(), i32>>,
    /// The channel's GPADLs and their states, keyed by GPADL ID.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    gpadls: HashMap<GpadlId, GpadlState>,
    // NOTE(review): presumably set once the channel's user has released it;
    // the code that sets it is outside this view -- confirm.
    is_client_released: bool,
    /// Shared with the channel's guest-to-host interrupt; starts at 0.
    connection_id: Arc<AtomicU32>,
}
560
561impl Channel {
562    fn pending_request(&self) -> Option<&'static str> {
563        if self.modify_response_send.is_some() {
564            return Some("modify");
565        }
566        self.gpadls.iter().find_map(|(_, gpadl)| match gpadl {
567            GpadlState::Offered(_) => Some("creating gpadl"),
568            GpadlState::Created => None,
569            GpadlState::TearingDown { .. } => Some("tearing down gpadl"),
570        })
571    }
572}
573
/// The state owned by the spawned client task.
#[derive(Inspect)]
struct ClientTask {
    /// Outgoing messages, synic state, per-channel request streams and the
    /// GPADL teardown list.
    #[inspect(flatten)]
    inner: ClientTaskInner,
    /// The channels currently known to the client.
    channels: ChannelList,
    /// The connection-level state machine.
    state: ClientState,
    /// Tracks hvsock connect requests that are waiting for an offer.
    hvsock_tracker: hvsock::HvsockRequestTracker,
    /// Starts out `false` when the task is built.
    running: bool,
    /// The outstanding ModifyConnection request, if any.
    #[inspect(with = "|x| x.is_some()")]
    modify_request: Option<Rpc<ModifyConnectionRequest, ConnectionState>>,
    /// Source of incoming vmbus messages.
    #[inspect(skip)]
    msg_source: Box<dyn VmbusMessageSource>,
    /// Receives lifecycle and inspect requests from [`VmbusClient`].
    #[inspect(skip)]
    task_recv: mesh::Receiver<TaskRequest>,
    /// Receives connection-level requests.
    #[inspect(skip)]
    client_request_recv: mesh::Receiver<ClientRequest>,
}
591
592impl ClientTask {
593    fn handle_initiate_contact(
594        &mut self,
595        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
596        version: Version,
597    ) {
598        let ClientState::Disconnected = self.state else {
599            tracing::warn!(client_state = %self.state, "invalid client state for InitiateContact");
600            rpc.complete(Err(ConnectError::InvalidState));
601            return;
602        };
603        let feature_flags = if version >= Version::Copper {
604            SUPPORTED_FEATURE_FLAGS
605        } else {
606            FeatureFlags::new()
607        };
608
609        let request = rpc.input();
610
611        tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting");
612        let target_info = protocol::TargetInfo::new()
613            .with_sint(SINT)
614            .with_vtl(VTL)
615            .with_feature_flags(feature_flags.into());
616        let monitor_page = request.monitor_page.unwrap_or_default();
617        let msg = protocol::InitiateContact2 {
618            initiate_contact: protocol::InitiateContact {
619                version_requested: version as u32,
620                target_message_vp: request.target_message_vp,
621                interrupt_page_or_target_info: target_info.into(),
622                parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
623                child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
624            },
625            client_id: request.client_id,
626        };
627
628        self.state = ClientState::Connecting { version, rpc };
629        if version < Version::Copper {
630            self.inner.messages.send(&msg.initiate_contact)
631        } else {
632            self.inner.messages.send(&msg);
633        }
634    }
635
636    fn handle_unload(&mut self, rpc: Rpc<(), ()>) {
637        tracing::debug!(%self.state, "VmBus client disconnecting");
638        self.state = ClientState::Disconnecting {
639            version: self.state.get_version().expect("invalid state for unload"),
640            rpc,
641        };
642
643        self.inner.messages.send(&protocol::Unload {});
644    }
645
646    fn handle_modify(&mut self, request: Rpc<ModifyConnectionRequest, ConnectionState>) {
647        if !matches!(self.state, ClientState::Connected { version, .. }
648            if version.feature_flags.modify_connection())
649        {
650            tracing::warn!("ModifyConnection not supported");
651            request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
652            return;
653        }
654
655        if self.modify_request.is_some() {
656            tracing::warn!("Duplicate ModifyConnection request");
657            request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
658            return;
659        }
660
661        let message = protocol::ModifyConnection::from(*request.input());
662        self.modify_request = Some(request);
663        self.inner.messages.send(&message);
664    }
665
666    fn handle_tl_connect(&mut self, rpc: Rpc<HvsockConnectRequest, Option<OfferInfo>>) {
667        // The client only supports protocol versions which use the newer message format.
668        // The host will not send a TlConnectRequestResult message on success, so a response to this
669        // message is not guaranteed.
670        let message = protocol::TlConnectRequest2::from(*rpc.input());
671        self.hvsock_tracker.add_request(rpc);
672        self.inner.messages.send(&message);
673    }
674
675    fn handle_client_request(&mut self, request: ClientRequest) {
676        match request {
677            ClientRequest::Connect(rpc) => {
678                self.handle_initiate_contact(rpc, *SUPPORTED_VERSIONS.last().unwrap());
679            }
680            ClientRequest::Unload(rpc) => {
681                self.handle_unload(rpc);
682            }
683            ClientRequest::Modify(request) => self.handle_modify(request),
684            ClientRequest::HvsockConnect(request) => self.handle_tl_connect(request),
685        }
686    }
687
688    fn handle_version_response(&mut self, msg: protocol::VersionResponse2) {
689        let old_state = std::mem::replace(&mut self.state, ClientState::Disconnected);
690        let ClientState::Connecting { version, rpc } = old_state else {
691            self.state = old_state;
692            tracing::warn!(
693                client_state = %self.state,
694                "invalid client state to handle VersionResponse"
695            );
696            return;
697        };
698        if msg.version_response.version_supported > 0 {
699            if msg.version_response.connection_state != ConnectionState::SUCCESSFUL {
700                rpc.complete(Err(ConnectError::FailedToConnect(
701                    msg.version_response.connection_state,
702                )));
703                return;
704            }
705
706            let feature_flags = if version >= Version::Copper {
707                FeatureFlags::from(msg.supported_features)
708            } else {
709                FeatureFlags::new()
710            };
711
712            let version = VersionInfo {
713                version,
714                feature_flags,
715            };
716
717            self.inner.messages.send(&protocol::RequestOffers {});
718            self.state = ClientState::RequestingOffers {
719                version,
720                rpc: rpc.split().1,
721                offers: Vec::new(),
722            };
723            tracing::info!(?version, "VmBus client connected, requesting offers");
724        } else {
725            let index = SUPPORTED_VERSIONS
726                .iter()
727                .position(|v| *v == version)
728                .unwrap();
729
730            if index == 0 {
731                rpc.complete(Err(ConnectError::NoSupportedVersions));
732                return;
733            }
734            let next_version = SUPPORTED_VERSIONS[index - 1];
735            tracing::debug!(
736                version = version as u32,
737                next_version = next_version as u32,
738                "Unsupported version, retrying"
739            );
740            self.handle_initiate_contact(rpc, next_version);
741        }
742    }
743
744    fn create_channel(&mut self, offer: protocol::OfferChannel) -> Result<OfferInfo> {
745        self.create_channel_core(offer, ChannelState::Offered)
746    }
747
748    fn create_channel_core(
749        &mut self,
750        offer: protocol::OfferChannel,
751        state: ChannelState,
752    ) -> Result<OfferInfo> {
753        if self.channels.0.contains_key(&offer.channel_id) {
754            anyhow::bail!("channel {:?} exists", offer.channel_id);
755        }
756        let (request_send, request_recv) = mesh::channel();
757        let (revoke_send, revoke_recv) = mesh::oneshot();
758
759        let connection_id = Arc::new(AtomicU32::new(0));
760        self.channels.0.insert(
761            offer.channel_id,
762            Channel {
763                revoke_send: Some(revoke_send),
764                offer,
765                state,
766                modify_response_send: None,
767                gpadls: HashMap::new(),
768                is_client_released: false,
769                connection_id: connection_id.clone(),
770            },
771        );
772
773        self.inner
774            .channel_requests
775            .push(TaggedStream::new(offer.channel_id, request_recv));
776
777        Ok(OfferInfo {
778            offer,
779            guest_to_host_interrupt: self.inner.synic.guest_to_host_interrupt(connection_id),
780            revoke_recv,
781            request_send,
782        })
783    }
784
785    fn handle_offer(&mut self, offer: protocol::OfferChannel) {
786        let offer_info = self
787            .create_channel(offer)
788            .expect("channel should not exist");
789
790        tracing::info!(
791                state = %self.state,
792                channel_id = offer.channel_id.0,
793                interface_id = %offer.interface_id,
794                instance_id = %offer.instance_id,
795                subchannel_index = offer.subchannel_index,
796                "received offer");
797
798        if let Some(offer) = self.hvsock_tracker.check_offer(&offer_info.offer) {
799            offer.complete(Some(offer_info));
800        } else {
801            match &mut self.state {
802                ClientState::Connected { offer_send, .. } => {
803                    offer_send.send(offer_info);
804                }
805                ClientState::RequestingOffers { offers, .. } => {
806                    offers.push(offer_info);
807                }
808                state => unreachable!("invalid client state for OfferChannel: {state}"),
809            }
810        }
811    }
812
813    fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) -> TriedRelease {
814        let mut channel = self.channels.get_mut(rescind.channel_id);
815        tracing::info!(
816            state = %self.state,
817            channel_id = rescind.channel_id.0,
818            key = %OfferKey::from(&channel.offer),
819            "received rescind"
820        );
821        let event_flag = match std::mem::replace(&mut channel.state, ChannelState::Revoked) {
822            ChannelState::Offered => None,
823            ChannelState::Opening {
824                redirected_event_flag,
825                redirected_event: _,
826                rpc,
827            } => {
828                rpc.fail(anyhow::anyhow!("channel revoked"));
829                redirected_event_flag
830            }
831            ChannelState::Restored => None,
832            ChannelState::Opened {
833                redirected_event_flag,
834                redirected_event: _,
835            } => redirected_event_flag,
836            ChannelState::Revoked => {
837                panic!("channel id {:?} already revoked", rescind.channel_id);
838            }
839        };
840        if let Some(event_flag) = event_flag {
841            self.inner.synic.free_event_flag(event_flag);
842        }
843
844        // Drop the channel and send the revoked message to the client.
845        channel.revoke_send.take().unwrap().send(());
846
847        channel.try_release(&mut self.inner.messages)
848    }
849
850    fn handle_offers_delivered(&mut self) {
851        match std::mem::replace(&mut self.state, ClientState::Disconnected) {
852            ClientState::RequestingOffers {
853                version,
854                rpc,
855                offers,
856            } => {
857                tracing::info!(version = ?version, "VmBus client connected, offers delivered");
858                let (offer_send, offer_recv) = mesh::channel();
859                self.state = ClientState::Connected {
860                    version,
861                    offer_send,
862                };
863                rpc.complete(Ok(ConnectResult {
864                    version,
865                    offers,
866                    offer_recv,
867                }));
868            }
869            state => {
870                tracing::warn!(client_state = %state, "invalid client state for OffersDelivered");
871                self.state = state;
872            }
873        }
874    }
875
876    fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) -> TriedRelease {
877        let mut channel = self.channels.get_mut(request.channel_id);
878        let Some(gpadl_state) = channel.gpadls.get_mut(&request.gpadl_id) else {
879            panic!("GpadlCreated for unknown gpadl {:#x}", request.gpadl_id.0);
880        };
881
882        let rpc = match std::mem::replace(gpadl_state, GpadlState::Created) {
883            GpadlState::Offered(rpc) => rpc,
884            old_state => {
885                panic!(
886                    "invalid state {old_state:?} for gpadl {:#x}:{:#x}",
887                    request.channel_id.0, request.gpadl_id.0
888                );
889            }
890        };
891
892        let gpadl_created = request.status == protocol::STATUS_SUCCESS;
893        if gpadl_created {
894            rpc.complete(Ok(()));
895        } else {
896            channel.gpadls.remove(&request.gpadl_id).unwrap();
897            rpc.fail(anyhow::anyhow!(
898                "gpadl creation failed: {:#x}",
899                request.status
900            ));
901        };
902        channel.try_release(&mut self.inner.messages)
903    }
904
905    fn handle_open_result(&mut self, result: protocol::OpenResult) {
906        let mut channel = self.channels.get_mut(result.channel_id);
907        tracing::debug!(
908            channel_id = result.channel_id.0,
909            key = %OfferKey::from(&channel.offer),
910            result = result.status,
911            "received open result"
912        );
913
914        let channel_opened = result.status == protocol::STATUS_SUCCESS as u32;
915        let old_state = std::mem::replace(&mut channel.state, ChannelState::Offered);
916        let ChannelState::Opening {
917            redirected_event_flag,
918            redirected_event,
919            rpc,
920        } = old_state
921        else {
922            tracing::warn!(
923                key = %OfferKey::from(&channel.offer),
924                old_state = ?channel.state,
925                channel_opened,
926                "invalid state for open result"
927            );
928            channel.state = old_state;
929            return;
930        };
931
932        if !channel_opened {
933            if let Some(event_flag) = redirected_event_flag {
934                self.inner.synic.free_event_flag(event_flag);
935            }
936            rpc.fail(anyhow::anyhow!("open failed: {:#x}", result.status));
937            return;
938        }
939
940        channel.state = ChannelState::Opened {
941            redirected_event_flag,
942            redirected_event,
943        };
944
945        rpc.complete(Ok(OpenOutput {
946            redirected_event_flag,
947        }));
948    }
949
950    fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) -> TriedRelease {
951        let Some(channel_id) = self.inner.teardown_gpadls.remove(&request.gpadl_id) else {
952            panic!("gpadl {:#x} not in teardown list", request.gpadl_id.0);
953        };
954
955        let mut channel = self.channels.get_mut(channel_id);
956        tracing::debug!(
957            gpadl_id = request.gpadl_id.0,
958            channel_id = channel_id.0,
959            key = %OfferKey::from(&channel.offer),
960            "Received GpadlTorndown"
961        );
962
963        let gpadl_state = channel
964            .gpadls
965            .remove(&request.gpadl_id)
966            .expect("gpadl validated above");
967
968        let GpadlState::TearingDown { rpcs } = gpadl_state else {
969            panic!("gpadl should be tearing down if in teardown list, state = {gpadl_state:?}");
970        };
971
972        for rpc in rpcs {
973            rpc.complete(());
974        }
975        channel.try_release(&mut self.inner.messages)
976    }
977
978    fn handle_unload_complete(&mut self) {
979        match std::mem::replace(&mut self.state, ClientState::Disconnected) {
980            ClientState::Disconnecting { version: _, rpc } => {
981                tracing::info!("VmBus client disconnected");
982                rpc.complete(());
983            }
984            state => {
985                tracing::warn!(client_state = %state, "invalid client state for UnloadComplete");
986            }
987        }
988    }
989
990    fn handle_modify_complete(&mut self, response: protocol::ModifyConnectionResponse) {
991        if let Some(request) = self.modify_request.take() {
992            request.complete(response.connection_state)
993        } else {
994            tracing::warn!("Unexpected modify complete request");
995        }
996    }
997
998    fn handle_modify_channel_response(
999        &mut self,
1000        response: protocol::ModifyChannelResponse,
1001    ) -> TriedRelease {
1002        let mut channel = self.channels.get_mut(response.channel_id);
1003        let Some(sender) = channel.modify_response_send.take() else {
1004            panic!(
1005                "unexpected modify channel response for channel {:#x}",
1006                response.channel_id.0
1007            );
1008        };
1009
1010        sender.complete(response.status);
1011        channel.try_release(&mut self.inner.messages)
1012    }
1013
1014    fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) {
1015        if let Some(rpc) = self.hvsock_tracker.check_result(&response) {
1016            rpc.complete(None);
1017        }
1018    }
1019
1020    /// Returns false if the message was a pause complete message.
1021    fn handle_synic_message(&mut self, data: &[u8]) -> bool {
1022        let msg = Message::parse(data, self.state.get_version()).unwrap();
1023        tracing::trace!(?msg, "received client message from synic");
1024
1025        match msg {
1026            Message::VersionResponse3(version_response, ..) => {
1027                // The client never sends the server-specified monitor pages feature flag, but
1028                // since version response messages are distinguished only by size, the response can
1029                // still look like `VersionResponse3` if the size was not set exactly by the server.
1030                // Since the feature flag can't be set, the extra data can be ignored.
1031                self.handle_version_response(version_response.version_response2);
1032            }
1033            Message::VersionResponse2(version_response, ..) => {
1034                self.handle_version_response(version_response);
1035            }
1036            Message::VersionResponse(version_response, ..) => {
1037                self.handle_version_response(version_response.into());
1038            }
1039            Message::OfferChannel(offer, ..) => {
1040                self.handle_offer(offer);
1041            }
1042            Message::AllOffersDelivered(..) => {
1043                self.handle_offers_delivered();
1044            }
1045            Message::UnloadComplete(..) => {
1046                self.handle_unload_complete();
1047            }
1048            Message::ModifyConnectionResponse(response, ..) => {
1049                self.handle_modify_complete(response);
1050            }
1051            Message::GpadlCreated(gpadl, ..) => {
1052                self.handle_gpadl_created(gpadl);
1053            }
1054            Message::OpenResult(result, ..) => {
1055                self.handle_open_result(result);
1056            }
1057            Message::GpadlTorndown(gpadl, ..) => {
1058                self.handle_gpadl_torndown(gpadl);
1059            }
1060            Message::RescindChannelOffer(rescind, ..) => {
1061                self.handle_rescind(rescind);
1062            }
1063            Message::ModifyChannelResponse(response, ..) => {
1064                self.handle_modify_channel_response(response);
1065            }
1066            Message::TlConnectResult(response, ..) => self.handle_tl_connect_result(response),
1067            // Unsupported messages.
1068            Message::CloseReservedChannelResponse(..) => {
1069                todo!("Unsupported message {msg:?}")
1070            }
1071            Message::PauseResponse(..) => {
1072                return false;
1073            }
1074            // Messages that should only be received by a vmbus server.
1075            Message::RequestOffers(..)
1076            | Message::OpenChannel2(..)
1077            | Message::OpenChannel(..)
1078            | Message::CloseChannel(..)
1079            | Message::GpadlHeader(..)
1080            | Message::GpadlBody(..)
1081            | Message::GpadlTeardown(..)
1082            | Message::RelIdReleased(..)
1083            | Message::InitiateContact(..)
1084            | Message::InitiateContact2(..)
1085            | Message::Unload(..)
1086            | Message::OpenReservedChannel(..)
1087            | Message::CloseReservedChannel(..)
1088            | Message::TlConnectRequest2(..)
1089            | Message::TlConnectRequest(..)
1090            | Message::ModifyChannel(..)
1091            | Message::ModifyConnection(..)
1092            | Message::Pause(..)
1093            | Message::Resume(..) => {
1094                unreachable!("Client received server message {msg:?}");
1095            }
1096        }
1097        true
1098    }
1099
1100    fn handle_open_channel(
1101        &mut self,
1102        channel_id: ChannelId,
1103        rpc: FailableRpc<OpenRequest, OpenOutput>,
1104    ) {
1105        let mut channel = self.channels.get_mut(channel_id);
1106        match &channel.state {
1107            ChannelState::Offered => {}
1108            ChannelState::Revoked => {
1109                rpc.fail(anyhow::anyhow!("channel revoked"));
1110                return;
1111            }
1112            state => {
1113                rpc.fail(anyhow::anyhow!("invalid channel state: {}", state));
1114                return;
1115            }
1116        }
1117
1118        tracing::info!(
1119            channel_id = channel_id.0,
1120            key = %OfferKey::from(&channel.offer),
1121            "opening channel on host"
1122        );
1123
1124        let (request, rpc) = rpc.split();
1125        let open_data = &request.open_data;
1126
1127        let supports_interrupt_redirection =
1128            if let ClientState::Connected { version, .. } = self.state {
1129                version.feature_flags.guest_specified_signal_parameters()
1130                    || version.feature_flags.channel_interrupt_redirection()
1131            } else {
1132                false
1133            };
1134
1135        if !supports_interrupt_redirection && open_data.event_flag != channel_id.0 as u16 {
1136            rpc.fail(anyhow::anyhow!(
1137                "host does not support specifying the event flag"
1138            ));
1139            return;
1140        }
1141
1142        let open_channel = protocol::OpenChannel {
1143            channel_id,
1144            open_id: 0,
1145            ring_buffer_gpadl_id: open_data.ring_gpadl_id,
1146            target_vp: open_data.target_vp,
1147            downstream_ring_buffer_page_offset: open_data.ring_offset,
1148            user_data: open_data.user_data,
1149        };
1150
1151        let connection_id = if request.use_vtl2_connection_id {
1152            if !supports_interrupt_redirection {
1153                rpc.fail(anyhow::anyhow!(
1154                    "host does not support specfiying the connection ID"
1155                ));
1156                return;
1157            }
1158            protocol::ConnectionId::new(channel_id.0, 2.try_into().unwrap(), 7).0
1159        } else {
1160            open_data.connection_id
1161        };
1162
1163        // No failure paths after the one for allocating the event flag, since
1164        // otherwise we would need to free the event flag.
1165        let mut flags = OpenChannelFlags::new();
1166        let event_flag = if let Some(event) = &request.incoming_event {
1167            if !supports_interrupt_redirection {
1168                rpc.fail(anyhow::anyhow!(
1169                    "host does not support redirecting interrupts"
1170                ));
1171                return;
1172            }
1173
1174            flags.set_redirect_interrupt(true);
1175            match self.inner.synic.allocate_event_flag(event) {
1176                Ok(flag) => flag,
1177                Err(err) => {
1178                    rpc.fail(err.context("failed to allocate event flag"));
1179                    return;
1180                }
1181            }
1182        } else {
1183            open_data.event_flag
1184        };
1185
1186        if supports_interrupt_redirection {
1187            self.inner.messages.send(&protocol::OpenChannel2 {
1188                open_channel,
1189                connection_id,
1190                event_flag,
1191                flags,
1192            });
1193        } else {
1194            self.inner.messages.send(&open_channel);
1195        }
1196
1197        channel
1198            .connection_id
1199            .store(connection_id, Ordering::Release);
1200        channel.state = ChannelState::Opening {
1201            redirected_event_flag: (request.incoming_event.is_some()).then_some(event_flag),
1202            redirected_event: request.incoming_event,
1203            rpc,
1204        }
1205    }
1206
1207    fn handle_restore_channel(
1208        &mut self,
1209        channel_id: ChannelId,
1210        request: RestoreRequest,
1211    ) -> Result<OpenOutput> {
1212        let mut channel = self.channels.get_mut(channel_id);
1213        if !matches!(channel.state, ChannelState::Restored) {
1214            anyhow::bail!("invalid channel state: {}", channel.state);
1215        }
1216
1217        if request.incoming_event.is_some() != request.redirected_event_flag.is_some() {
1218            anyhow::bail!("incoming event and redirected event flag must both be set or unset");
1219        }
1220
1221        if let Some((flag, event)) = request
1222            .redirected_event_flag
1223            .zip(request.incoming_event.as_ref())
1224        {
1225            self.inner.synic.restore_event_flag(flag, event)?;
1226        }
1227
1228        channel
1229            .connection_id
1230            .store(request.connection_id, Ordering::Release);
1231        channel.state = ChannelState::Opened {
1232            redirected_event_flag: request.redirected_event_flag,
1233            redirected_event: request.incoming_event,
1234        };
1235        Ok(OpenOutput {
1236            redirected_event_flag: request.redirected_event_flag,
1237        })
1238    }
1239
1240    fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: FailableRpc<GpadlRequest, ()>) {
1241        let (request, rpc) = rpc.split();
1242        let mut channel = self.channels.get_mut(channel_id);
1243        if channel
1244            .gpadls
1245            .insert(request.id, GpadlState::Offered(rpc))
1246            .is_some()
1247        {
1248            panic!(
1249                "duplicate gpadl ID {:?} for channel {:?}.",
1250                request.id, channel_id
1251            );
1252        }
1253
1254        tracing::trace!(
1255            channel_id = channel_id.0,
1256            key = %OfferKey::from(&channel.offer),
1257            gpadl_id = request.id.0,
1258            count = request.count,
1259            len = request.buf.len(),
1260            "received gpadl request"
1261        );
1262
1263        // Split off the values that fit in the header.
1264        let (first, remaining) = if request.buf.len() > protocol::GpadlHeader::MAX_DATA_VALUES {
1265            request.buf.split_at(protocol::GpadlHeader::MAX_DATA_VALUES)
1266        } else {
1267            (request.buf.as_slice(), [].as_slice())
1268        };
1269
1270        let message = protocol::GpadlHeader {
1271            channel_id,
1272            gpadl_id: request.id,
1273            len: (request.buf.len() * size_of::<u64>())
1274                .try_into()
1275                .expect("Too many GPA values"),
1276            count: request.count,
1277        };
1278
1279        self.inner
1280            .messages
1281            .send_with_data(&message, first.as_bytes());
1282
1283        // Send GpadlBody messages for the remaining values.
1284        let message = protocol::GpadlBody {
1285            rsvd: 0,
1286            gpadl_id: request.id,
1287        };
1288        for chunk in remaining.chunks(protocol::GpadlBody::MAX_DATA_VALUES) {
1289            self.inner
1290                .messages
1291                .send_with_data(&message, chunk.as_bytes());
1292        }
1293    }
1294
1295    fn handle_gpadl_teardown(&mut self, channel_id: ChannelId, rpc: Rpc<GpadlId, ()>) {
1296        let (gpadl_id, rpc) = rpc.split();
1297        let mut channel = self.channels.get_mut(channel_id);
1298        let Some(gpadl_state) = channel.gpadls.get_mut(&gpadl_id) else {
1299            tracing::warn!(
1300                gpadl_id = gpadl_id.0,
1301                channel_id = channel_id.0,
1302                key = %OfferKey::from(&channel.offer),
1303                "Gpadl teardown for unknown gpadl or revoked channel"
1304            );
1305            return;
1306        };
1307
1308        match gpadl_state {
1309            GpadlState::Offered(_) => {
1310                tracing::warn!(
1311                    gpadl_id = gpadl_id.0,
1312                    channel_id = channel_id.0,
1313                    key = %OfferKey::from(&channel.offer),
1314                    "gpadl teardown for offered gpadl"
1315                );
1316            }
1317            GpadlState::Created => {
1318                *gpadl_state = GpadlState::TearingDown { rpcs: vec![rpc] };
1319                // The caller must guarantee that GPADL teardown requests are only made
1320                // for unique GPADL IDs. This is currently enforced in vmbus_server by
1321                // blocking GPADL teardown messages for reserved channels.
1322                assert!(
1323                    self.inner
1324                        .teardown_gpadls
1325                        .insert(gpadl_id, channel_id)
1326                        .is_none(),
1327                    "Gpadl state validated above"
1328                );
1329
1330                self.inner.messages.send(&protocol::GpadlTeardown {
1331                    channel_id,
1332                    gpadl_id,
1333                });
1334            }
1335            GpadlState::TearingDown { rpcs } => {
1336                rpcs.push(rpc);
1337            }
1338        }
1339    }
1340
1341    fn handle_close_channel(&mut self, channel_id: ChannelId) {
1342        let mut channel = self.channels.get_mut(channel_id);
1343        self.inner.close_channel(channel_id, &mut channel);
1344    }
1345
1346    fn handle_modify_channel(&mut self, channel_id: ChannelId, rpc: Rpc<ModifyRequest, i32>) {
1347        // The client doesn't support versions below Iron, so we always expect the host to send a
1348        // ModifyChannelResponse. This means we don't need to worry about sending a ChannelResponse
1349        // if that weren't supported.
1350        assert!(self.check_version(Version::Iron));
1351        let mut channel = self.channels.get_mut(channel_id);
1352        if channel.modify_response_send.is_some() {
1353            panic!("duplicate channel modify request {channel_id:?}");
1354        }
1355
1356        let (request, response) = rpc.split();
1357        channel.modify_response_send = Some(response);
1358        let payload = match request {
1359            ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel {
1360                channel_id,
1361                target_vp,
1362            },
1363        };
1364
1365        self.inner.messages.send(&payload);
1366    }
1367
1368    fn handle_channel_request(&mut self, channel_id: ChannelId, request: ChannelRequest) {
1369        match request {
1370            ChannelRequest::Open(rpc) => self.handle_open_channel(channel_id, rpc),
1371            ChannelRequest::Restore(rpc) => {
1372                rpc.handle_failable_sync(|request| self.handle_restore_channel(channel_id, request))
1373            }
1374            ChannelRequest::Gpadl(req) => self.handle_gpadl(channel_id, req),
1375            ChannelRequest::TeardownGpadl(req) => self.handle_gpadl_teardown(channel_id, req),
1376            ChannelRequest::Close(req) => {
1377                req.handle_sync(|()| self.handle_close_channel(channel_id))
1378            }
1379            ChannelRequest::Modify(req) => self.handle_modify_channel(channel_id, req),
1380        }
1381    }
1382
1383    async fn handle_task(&mut self, task: TaskRequest) {
1384        match task {
1385            TaskRequest::Inspect(deferred) => {
1386                deferred.inspect(&*self);
1387            }
1388            TaskRequest::Save(rpc) => rpc.handle_sync(|()| self.handle_save()),
1389            TaskRequest::Restore(rpc) => {
1390                rpc.handle_sync(|saved_state| self.handle_restore(saved_state))
1391            }
1392            TaskRequest::PostRestore(rpc) => rpc.handle_sync(|()| self.handle_post_restore()),
1393            TaskRequest::Start => self.handle_start(),
1394            TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
1395        }
1396    }
1397
1398    /// Makes sure a channel is closed if the channel request stream was dropped.
1399    fn handle_device_removal(&mut self, channel_id: ChannelId) -> TriedRelease {
1400        let mut channel = self.channels.get_mut(channel_id);
1401        channel.is_client_released = true;
1402        // Close the channel if it is still open.
1403        if let ChannelState::Opened { .. } = channel.state {
1404            tracing::warn!(
1405                channel_id = channel_id.0,
1406                key = %OfferKey::from(&channel.offer),
1407                "Channel dropped without closing first"
1408            );
1409            self.inner.close_channel(channel_id, &mut channel);
1410        }
1411        channel.try_release(&mut self.inner.messages)
1412    }
1413
1414    /// Determines if the client is connected with at least the specified version.
1415    fn check_version(&self, version: Version) -> bool {
1416        matches!(self.state, ClientState::Connected { version: v, .. } if v.version >= version)
1417    }
1418
    /// Starts message processing.
    ///
    /// Resumes the incoming message stream and the outgoing message queue, then
    /// marks the task as running.
    ///
    /// # Panics
    ///
    /// Panics if the task is already running. `OutgoingMessages::resume` also
    /// asserts that the outgoing queue is currently paused.
    fn handle_start(&mut self) {
        assert!(!self.running);
        self.msg_source.resume_message_stream();
        self.inner.messages.resume();
        self.running = true;
    }
1425
    /// Stops message processing, draining incoming messages first.
    ///
    /// Each pass of the loop waits until no revoked channel has a pending
    /// request, pauses the message stream (in-band if the host supports it,
    /// otherwise by masking the sint), and then drains messages until
    /// `process_next_message` returns `false`. If a revoked channel has a
    /// pending request again after draining, the stream is resumed and the
    /// loop repeats.
    ///
    /// # Panics
    ///
    /// Panics if the task is not running, or if `process_next_message` returns
    /// `false` while still waiting for responses for a revoked channel.
    async fn handle_stop(&mut self) {
        assert!(self.running);

        loop {
            // Process messages until there are no more channels waiting for
            // responses. This is necessary to ensure that the saved state does
            // not have to support encoding revoked channels for which we are
            // waiting for GPADL or modify responses.
            while let Some((id, request)) = self.channels.revoked_channel_with_pending_request() {
                tracelimit::info_ratelimited!(
                    channel_id = id.0,
                    request,
                    "waiting for responses for channel"
                );
                assert!(self.process_next_message().await);
            }

            if self.can_pause_resume() {
                // In-band pause: the pause message itself is sent by the flush
                // that runs alongside `process_next_message` below.
                self.inner.messages.pause();
            } else {
                // Mask the sint to pause the message stream. The host will
                // retry any queued messages after the sint is unmasked.
                self.msg_source.pause_message_stream();
                self.inner.messages.force_pause();
            }

            // Continue processing messages until we hit EOF or get a pause
            // response.
            while self.process_next_message().await {}

            // Ensure there are still no pending requests. If there are, resume
            // and go around again.
            if self
                .channels
                .revoked_channel_with_pending_request()
                .is_none()
            {
                break;
            }
            // Undo the pause in the same way it was applied above.
            if !self.can_pause_resume() {
                self.msg_source.resume_message_stream();
            }
            self.inner.messages.resume();
        }

        tracing::debug!("messages drained");
        // Because the run loop awaits all async operations, there is no need for rundown.
        self.running = false;
    }
1475
1476    async fn process_next_message(&mut self) -> bool {
1477        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
1478        let recv = self.msg_source.recv(&mut buf);
1479        // Concurrently flush until there is no more work to do, since pending
1480        // messages may be blocking responses from the host.
1481        let flush = async {
1482            self.inner.messages.flush_messages().await;
1483            std::future::pending().await
1484        };
1485        let size = (recv, flush)
1486            .race()
1487            .await
1488            .expect("Fatal error reading messages from synic");
1489        if size == 0 {
1490            return false;
1491        }
1492        self.handle_synic_message(&buf[..size])
1493    }
1494
1495    /// Returns whether the server supports in-band messages to pause/resume the
1496    /// message stream.
1497    ///
1498    /// For hosts where this is not supported, we mask the sint to pause new
1499    /// messages being queued to the sint, then drain the messages. This does
1500    /// not work with some host implementations, which cannot support draining
1501    /// the message queue while the sint is masked (due to the use of
1502    /// HvPostMessageDirect).
1503    fn can_pause_resume(&self) -> bool {
1504        if let ClientState::Connected { version, .. } = self.state {
1505            version.feature_flags.pause_resume()
1506        } else {
1507            false
1508        }
1509    }
1510
    /// Main loop of the client task.
    ///
    /// Each iteration waits for whichever of the following is ready first:
    /// flushing of queued outgoing messages, a task request, a client request,
    /// a channel request, or an incoming synic message. Task requests are
    /// always serviced; the other sources are only polled while the task is
    /// running, and client/channel requests are additionally held back while
    /// outgoing messages are queued.
    ///
    /// The loop exits when the task request stream or the client request
    /// stream ends, or when `select!` reports that every future has completed.
    ///
    /// # Panics
    ///
    /// Panics if reading from the synic fails or returns a zero-length message.
    async fn run(&mut self) {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        loop {
            // Only receive from the synic while running.
            let mut message_recv =
                OptionFuture::from(self.running.then(|| self.msg_source.recv(&mut buf).fuse()));

            // If there are pending outgoing messages, the host is backed up.
            // Try to flush the queue, and in the meantime, stop generating new
            // messages by stopping processing client requests, so as to avoid
            // the outgoing message queue growing without bound.
            //
            // We still need to process incoming messages when in this state,
            // even though they may generate additional outgoing messages, to
            // avoid a deadlock with the host. The host can always DoS the
            // guest, so this is not an attack vector.
            let host_backed_up = !self.inner.messages.is_empty();
            let flush_messages = OptionFuture::from(
                (self.running && host_backed_up)
                    .then(|| self.inner.messages.flush_messages().fuse()),
            );

            let mut client_request_recv = OptionFuture::from(
                (self.running && !host_backed_up).then(|| self.client_request_recv.next()),
            );

            let mut channel_requests = OptionFuture::from(
                (self.running && !host_backed_up)
                    .then(|| self.inner.channel_requests.select_next_some()),
            );

            futures::select! { // merge semantics
                // Nothing to do once the flush finishes; the next iteration
                // re-evaluates whether the host is still backed up.
                _r = pin!(flush_messages) => {}
                r = self.task_recv.next() => {
                    if let Some(task) = r {
                        self.handle_task(task).await;
                    } else {
                        break;
                    }
                }
                r = client_request_recv => {
                    if let Some(Some(request)) = r {
                        self.handle_client_request(request);
                    } else {
                        break;
                    }
                }
                r = channel_requests => {
                    match r.unwrap() {
                        (id, Some(request)) => self.handle_channel_request(id, request),
                        // The channel's request stream ended, i.e. the device
                        // dropped it.
                        (id, _) => {
                            self.handle_device_removal(id);
                        }
                    }
                }
                r = message_recv => {
                    match r.unwrap() {
                        Ok(size) => {
                            if size == 0 {
                                panic!("Unexpected end of file reading messages from synic.");
                            }

                            self.handle_synic_message(&buf[..size]);
                        }
                        Err(err) => {
                            panic!("Error reading messages from synic: {err:?}");
                        }
                    }
                }
                complete => break,
            }
        }
    }
1583}
1584
1585impl ClientTaskInner {
1586    fn close_channel(&mut self, channel_id: ChannelId, channel: &mut Channel) {
1587        if let ChannelState::Opened {
1588            redirected_event_flag,
1589            ..
1590        } = channel.state
1591        {
1592            if let Some(flag) = redirected_event_flag {
1593                self.synic.free_event_flag(flag);
1594            }
1595            tracing::info!(
1596                channel_id = channel_id.0,
1597                key = %OfferKey::from(&channel.offer),
1598                "closing channel on host"
1599            );
1600
1601            self.messages.send(&protocol::CloseChannel { channel_id });
1602            channel.state = ChannelState::Offered;
1603            channel.connection_id.store(0, Ordering::Release);
1604        } else {
1605            tracing::warn!(
1606                channel_id = channel_id.0,
1607                key = %OfferKey::from(&channel.offer),
1608                channel_state = %channel.state,
1609                "invalid channel state for close channel"
1610            );
1611        }
1612    }
1613}
1614
/// Lifecycle state of a single GPADL tracked for a channel.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum GpadlState {
    /// GpadlHeader has been sent to the host.
    ///
    /// Holds the creation request, which is completed or failed when the
    /// host's `GpadlCreated` message is handled.
    Offered(#[inspect(skip)] FailableRpc<(), ()>),
    /// Host has responded with GpadlCreated.
    Created,
    /// GpadlTeardown message has been sent to the host.
    TearingDown {
        /// Teardown requests waiting on this GPADL; all of them are completed
        /// when the host's `GpadlTorndown` message is handled.
        #[inspect(skip)]
        rpcs: Vec<Rpc<(), ()>>,
    },
}
1628
/// Queue of messages to be posted to the host, together with its pause state.
#[derive(Inspect)]
struct OutgoingMessages {
    /// Used to post messages to the host.
    #[inspect(skip)]
    poster: Box<dyn PollPostMessage>,
    /// Messages that could not be posted immediately; exposed to inspect as
    /// the queue length only.
    #[inspect(with = "|x| x.len()")]
    queued: VecDeque<OutgoingMessage>,
    /// Whether messages are currently being sent, or the queue is pausing or
    /// paused.
    state: OutgoingMessageState,
}
1637
/// Pause state of the outgoing message queue.
#[derive(Inspect, PartialEq, Eq, Debug)]
enum OutgoingMessageState {
    /// Messages are posted directly when possible and queued messages are
    /// sent on flush.
    Running,
    /// A pause was requested; the next flush sends the pause message and then
    /// moves to `Paused`.
    SendingPauseMessage,
    /// Nothing is sent; new messages are only queued.
    Paused,
}
1644
impl OutgoingMessages {
    /// Sends a message with no trailing payload. See
    /// [`Self::send_with_data`].
    fn send<T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout>(
        &mut self,
        msg: &T,
    ) {
        self.send_with_data(msg, &[])
    }

    /// Sends `msg` followed by `data` to the host, or queues it if it cannot
    /// be posted right now.
    ///
    /// An immediate post is attempted only when the queue is empty (so that
    /// a new message never overtakes a queued one) and the state is
    /// `Running`. Otherwise, or if the poster is not ready, the message is
    /// appended to the queue for [`Self::flush_messages`] to send later.
    fn send_with_data<
        T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
    >(
        &mut self,
        msg: &T,
        data: &[u8],
    ) {
        tracing::trace!(typ = ?T::MESSAGE_TYPE, "Sending message to host");
        let msg = OutgoingMessage::with_data(msg, data);
        if self.queued.is_empty() && self.state == OutgoingMessageState::Running {
            // Poll exactly once with a no-op waker: this is not an async
            // context, and a message that is not accepted is queued below.
            let r = self.poster.poll_post_message(
                &mut Context::from_waker(std::task::Waker::noop()),
                protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                1,
                msg.data(),
            );
            if let Poll::Ready(()) = r {
                return;
            }
        }
        tracing::trace!("queueing message");
        self.queued.push_back(msg);
    }

    /// Sends whatever the current state allows:
    ///
    /// * `Running`: posts every queued message in order, waiting for the
    ///   poster as needed.
    /// * `SendingPauseMessage`: posts the pause message, then transitions to
    ///   `Paused`.
    /// * `Paused`: does nothing and returns immediately.
    async fn flush_messages(&mut self) {
        // Waits until the poster has accepted `msg`.
        let mut send = async |msg: &OutgoingMessage| {
            poll_fn(|cx| {
                self.poster.poll_post_message(
                    cx,
                    protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                    1,
                    msg.data(),
                )
            })
            .await
        };
        match self.state {
            OutgoingMessageState::Running => {
                // The message is only removed from the queue after it has
                // been posted, so it stays queued while the send is pending.
                while let Some(msg) = self.queued.front() {
                    send(msg).await;
                    tracing::trace!("sent queued message");
                    self.queued.pop_front();
                }
            }
            OutgoingMessageState::SendingPauseMessage => {
                send(&OutgoingMessage::new(&protocol::Pause)).await;
                tracing::trace!("sent pause message");
                self.state = OutgoingMessageState::Paused;
            }
            OutgoingMessageState::Paused => {}
        }
    }

    /// Pause by sending a pause message to the host. This will cause the host
    /// to stop sending messages after sending a pause response.
    ///
    /// The pause message itself is sent by the next call to
    /// [`Self::flush_messages`].
    ///
    /// # Panics
    ///
    /// Panics if the state is not `Running`.
    fn pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::SendingPauseMessage;
        // Queue a resume message to be sent later. It goes at the front of
        // the queue so that it is the first message flushed once the state
        // returns to `Running`.
        self.queued
            .push_front(OutgoingMessage::new(&protocol::Resume));
    }

    /// Force a pause by setting the state to Paused. This is used when the
    /// host does not support in-band pause/resume messages, in which case
    /// the SINT is masked to force the host to stop sending messages.
    ///
    /// Unlike [`Self::pause`], no pause or resume message is sent or queued.
    ///
    /// # Panics
    ///
    /// Panics if the state is not `Running`.
    fn force_pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::Paused;
    }

    /// Returns to the `Running` state so that queued messages can be flushed
    /// and new messages posted again.
    ///
    /// # Panics
    ///
    /// Panics if the state is not `Paused`.
    fn resume(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Paused);
        self.state = OutgoingMessageState::Running;
    }

    /// Returns `true` if no messages are waiting in the queue.
    fn is_empty(&self) -> bool {
        self.queued.is_empty()
    }
}
1733
/// State grouped together for the client task: the outgoing message queue,
/// GPADL teardown tracking, per-channel request streams, and synic state.
#[derive(Inspect)]
struct ClientTaskInner {
    /// Outgoing messages to the host.
    messages: OutgoingMessages,
    /// Maps a GPADL ID to a channel ID.
    // NOTE(review): presumably GPADLs with a teardown in flight and the
    // channel that owns each one; confirm against the teardown handling.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")]
    teardown_gpadls: HashMap<GpadlId, ChannelId>,
    /// Request receivers for all channels, merged into one stream and tagged
    /// with the channel ID each request came from.
    #[inspect(skip)]
    channel_requests: SelectAll<TaggedStream<ChannelId, mesh::Receiver<ChannelRequest>>>,
    /// Event flag allocation and host signaling.
    synic: SynicState,
}
1743
/// Tracks which event flags are in use and provides access to the synic
/// event client.
#[derive(Inspect)]
struct SynicState {
    /// Used to map events to event flags and to signal the host.
    #[inspect(skip)]
    event_client: Arc<dyn SynicEventClient>,
    /// Usage map for event flags: index `i` holds `true` while event flag
    /// `i + 1` is in use. Grows on demand.
    #[inspect(iter_by_index)]
    event_flag_state: Vec<bool>,
}
1751
/// The channels tracked by the client, keyed by channel ID. Inspect output
/// is keyed by the numeric channel ID.
#[derive(Inspect, Default)]
#[inspect(transparent)]
struct ChannelList(
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] HashMap<ChannelId, Channel>,
);
1757
/// A reference to a channel that can be used to remove the channel from the map
/// as well.
///
/// Wraps the occupied map entry for the channel and dereferences to the
/// [`Channel`] itself.
struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>);
1761
/// A tag value used to indicate that [`ChannelRef::try_release`] has been called.
/// This is useful as a return value for methods that might transition a channel
/// into a fully released state.
///
/// The value carries no data; it does not say whether the channel was
/// actually released.
struct TriedRelease(());
1766
1767impl ChannelRef<'_> {
1768    /// If the channel has been fully released (revoked, released by the client,
1769    /// no pending requests), notifies the server and removes this channel from
1770    /// the map.
1771    fn try_release(self, messages: &mut OutgoingMessages) -> TriedRelease {
1772        if self.is_client_released
1773            && matches!(self.state, ChannelState::Revoked)
1774            && self.pending_request().is_none()
1775        {
1776            let channel_id = *self.0.key();
1777            tracelimit::info_ratelimited!(
1778                channel_id = channel_id.0,
1779                key = %OfferKey::from(&self.offer),
1780                "releasing channel"
1781            );
1782
1783            messages.send(&protocol::RelIdReleased { channel_id });
1784            self.0.remove();
1785        }
1786        TriedRelease(())
1787    }
1788}
1789
1790impl Deref for ChannelRef<'_> {
1791    type Target = Channel;
1792
1793    fn deref(&self) -> &Self::Target {
1794        self.0.get()
1795    }
1796}
1797
1798impl DerefMut for ChannelRef<'_> {
1799    fn deref_mut(&mut self) -> &mut Self::Target {
1800        self.0.get_mut()
1801    }
1802}
1803
1804impl ChannelList {
1805    fn revoked_channel_with_pending_request(&self) -> Option<(ChannelId, &'static str)> {
1806        self.0.iter().find_map(|(&id, channel)| {
1807            if !matches!(channel.state, ChannelState::Revoked) {
1808                return None;
1809            }
1810            Some((id, channel.pending_request()?))
1811        })
1812    }
1813
1814    #[track_caller]
1815    fn get_mut(&mut self, channel_id: ChannelId) -> ChannelRef<'_> {
1816        match self.0.entry(channel_id) {
1817            hash_map::Entry::Occupied(entry) => ChannelRef(entry),
1818            hash_map::Entry::Vacant(_) => {
1819                panic!("channel {:?} not found", channel_id);
1820            }
1821        }
1822    }
1823}
1824
1825impl SynicState {
1826    fn guest_to_host_interrupt(&self, connection_id: Arc<AtomicU32>) -> Interrupt {
1827        Interrupt::from_fn({
1828            let event_client = self.event_client.clone();
1829            move || {
1830                let connection_id = connection_id.load(Ordering::Acquire);
1831                if connection_id == 0 {
1832                    tracing::debug!("interrupt signal after close");
1833                    return;
1834                }
1835
1836                if let Err(err) = event_client.signal_event(connection_id, 0) {
1837                    tracelimit::warn_ratelimited!(
1838                        error = &err as &dyn std::error::Error,
1839                        "failed to signal event"
1840                    );
1841                }
1842            }
1843        })
1844    }
1845
1846    const MAX_EVENT_FLAGS: u16 = 2047;
1847
1848    fn allocate_event_flag(&mut self, event: &Event) -> Result<u16> {
1849        let i = self
1850            .event_flag_state
1851            .iter()
1852            .position(|&used| !used)
1853            .ok_or(())
1854            .or_else(|()| {
1855                if self.event_flag_state.len() >= Self::MAX_EVENT_FLAGS as usize {
1856                    anyhow::bail!("out of event flags");
1857                }
1858                self.event_flag_state.push(false);
1859                Ok(self.event_flag_state.len() - 1)
1860            })?;
1861
1862        let event_flag = (i + 1) as u16;
1863        self.event_client
1864            .map_event(event_flag, event)
1865            .context("failed to map event")?;
1866        self.event_flag_state[i] = true;
1867        Ok(event_flag)
1868    }
1869
1870    fn restore_event_flag(&mut self, flag: u16, event: &Event) -> Result<()> {
1871        let i = (flag as usize)
1872            .checked_sub(1)
1873            .context("invalid event flag")?;
1874        if i >= Self::MAX_EVENT_FLAGS as usize {
1875            anyhow::bail!("invalid event flag");
1876        }
1877        if self.event_flag_state.len() <= i {
1878            self.event_flag_state.resize(i + 1, false);
1879        }
1880        if self.event_flag_state[i] {
1881            anyhow::bail!("event flag already in use");
1882        }
1883        self.event_client
1884            .map_event(flag, event)
1885            .context("failed to map event")?;
1886        self.event_flag_state[i] = true;
1887        Ok(())
1888    }
1889
1890    fn free_event_flag(&mut self, flag: u16) {
1891        let i = flag as usize - 1;
1892        assert!(i < self.event_flag_state.len());
1893        self.event_flag_state[i] = false;
1894    }
1895}
1896
1897#[cfg(test)]
1898mod tests {
1899    use super::*;
1900    use futures_concurrency::future::Join;
1901    use guid::Guid;
1902    use pal_async::DefaultDriver;
1903    use pal_async::async_test;
1904    use pal_async::timer::PolledTimer;
1905    use protocol::TargetInfo;
1906    use std::fmt::Debug;
1907    use std::task::ready;
1908    use std::time::Duration;
1909    use test_with_tracing::test;
1910    use vmbus_core::protocol::MessageHeader;
1911    use vmbus_core::protocol::MessageType;
1912    use vmbus_core::protocol::OfferFlags;
1913    use vmbus_core::protocol::UserDefinedData;
1914    use vmbus_core::protocol::VmbusMessage;
1915    use zerocopy::FromBytes;
1916    use zerocopy::FromZeros;
1917    use zerocopy::Immutable;
1918    use zerocopy::IntoBytes;
1919    use zerocopy::KnownLayout;
1920
    /// Arbitrary non-zero client ID, used to check that the ID given in a
    /// connect request shows up in the `InitiateContact2` message.
    const VMBUS_TEST_CLIENT_ID: Guid = guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6");
1922
1923    fn in_msg<T: IntoBytes + Immutable + KnownLayout>(message_type: MessageType, t: T) -> Vec<u8> {
1924        let mut data = Vec::new();
1925        data.extend_from_slice(&message_type.0.to_ne_bytes());
1926        data.extend_from_slice(&0u32.to_ne_bytes());
1927        data.extend_from_slice(t.as_bytes());
1928        data
1929    }
1930
    /// Checks that `msg` is the message `chk` with no trailing data.
    ///
    /// Shorthand for [`check_message_with_data`] with an empty data slice;
    /// panics (attributed to the caller) on mismatch.
    #[track_caller]
    fn check_message<T>(msg: OutgoingMessage, chk: T)
    where
        T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
    {
        check_message_with_data(msg, chk, &[]);
    }
1938
1939    #[track_caller]
1940    fn check_message_with_data<T>(msg: OutgoingMessage, chk: T, data: &[u8])
1941    where
1942        T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1943    {
1944        let chk_data = OutgoingMessage::with_data(&chk, data);
1945        if msg.data() != chk_data.data() {
1946            let (header, rest) = MessageHeader::read_from_prefix(msg.data()).unwrap();
1947            assert_eq!(header.message_type(), <T as VmbusMessage>::MESSAGE_TYPE);
1948            let (msg, rest) = T::read_from_prefix(rest).expect("incorrect message size");
1949            if msg.as_bytes() != chk.as_bytes() {
1950                panic!("mismatched messages, expected {:#?}, got {:#?}", chk, msg);
1951            }
1952            if rest != data {
1953                panic!("mismatched data, expected {:#?}, got {:#?}", data, rest);
1954            }
1955        }
1956    }
1957
    /// The host side of a test: receives the messages the client posts and
    /// injects raw incoming messages for the client to process.
    struct TestServer {
        /// Messages posted by the client (forwarded by `TestServerClient`).
        messages: mesh::Receiver<OutgoingMessage>,
        /// Raw message bytes delivered to the client through
        /// `TestMessageSource`.
        send: mesh::Sender<Vec<u8>>,
    }
1962
    impl TestServer {
        /// Waits for the next message posted by the client. Returns `None`
        /// when the message stream ends.
        async fn next(&mut self) -> Option<OutgoingMessage> {
            self.messages.next().await
        }

        /// Delivers a raw incoming message to the client.
        fn send(&self, msg: Vec<u8>) {
            self.send.send(msg);
        }

        /// Connects the client without offering any channels.
        async fn connect(&mut self, client: &mut VmbusClient) -> ConnectResult {
            self.connect_with_channels(client, |_| {}).await
        }

        /// Drives a full connect handshake between `client` and this server.
        ///
        /// The server accepts the first version request with all supported
        /// feature flags, waits for the offers request, calls `send_offers`
        /// so the caller can send channel offers, and then reports that all
        /// offers have been delivered. Panics if the client does not end up
        /// connected with the Copper version and all supported feature
        /// flags.
        async fn connect_with_channels(
            &mut self,
            client: &mut VmbusClient,
            send_offers: impl FnOnce(&mut Self),
        ) -> ConnectResult {
            let client_connect = client.connect(0, None, Guid::ZERO);

            let server_connect = async {
                // The first message is not validated here (presumably the
                // initiate contact request, which other tests check).
                let _ = self.next().await.unwrap();

                self.send(in_msg(
                    MessageType::VERSION_RESPONSE,
                    protocol::VersionResponse2 {
                        version_response: protocol::VersionResponse {
                            version_supported: 1,
                            connection_state: ConnectionState::SUCCESSFUL,
                            padding: 0,
                            selected_version_or_connection_id: 0,
                        },
                        supported_features: SUPPORTED_FEATURE_FLAGS.into(),
                    },
                ));

                check_message(self.next().await.unwrap(), protocol::RequestOffers {});

                send_offers(self);
                self.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
            };

            // The client and server halves must make progress concurrently.
            let (connection, ()) = (client_connect, server_connect).join().await;

            let connection = connection.unwrap();
            assert_eq!(connection.version.version, Version::Copper);
            assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
            connection
        }

        /// Connects the client with exactly one offered channel and returns
        /// that channel's offer.
        async fn get_channel(&mut self, client: &mut VmbusClient) -> OfferInfo {
            let [channel] = self
                .get_channels(client, 1)
                .await
                .offers
                .try_into()
                .unwrap();
            channel
        }

        /// Connects the client and offers `count` channels with random
        /// interface and instance IDs and channel IDs `0..count`.
        async fn get_channels(&mut self, client: &mut VmbusClient, count: usize) -> ConnectResult {
            self.connect_with_channels(client, |this| {
                for i in 0..count {
                    let offer = protocol::OfferChannel {
                        interface_id: Guid::new_random(),
                        instance_id: Guid::new_random(),
                        rsvd: [0; 4],
                        flags: OfferFlags::new(),
                        mmio_megabytes: 0,
                        user_defined: UserDefinedData::new_zeroed(),
                        subchannel_index: 0,
                        mmio_megabytes_optional: 0,
                        channel_id: ChannelId(i as u32),
                        monitor_id: 0,
                        monitor_allocated: 0,
                        is_dedicated: 0,
                        connection_id: 0,
                    };

                    this.send(in_msg(MessageType::OFFER_CHANNEL, offer));
                }
            })
            .await
        }

        /// Stops the client, answering the pause message it sends with a
        /// pause response.
        async fn stop_client(&mut self, client: &mut VmbusClient) {
            let client_stop = client.stop();
            let server_stop = async {
                check_message(self.next().await.unwrap(), protocol::Pause);
                self.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
            };
            (client_stop, server_stop).join().await;
        }

        /// Starts the client and checks that it sends a resume message.
        async fn start_client(&mut self, client: &mut VmbusClient) {
            client.start();
            check_message(self.next().await.unwrap(), protocol::Resume);
        }
    }
2062
    /// Message poster used by the client under test. Forwards posted
    /// messages to the [`TestServer`], sometimes after an artificial delay.
    struct TestServerClient {
        /// Destination for posted messages.
        sender: mesh::Sender<OutgoingMessage>,
        /// Timer used to wait out an artificial delay.
        timer: PolledTimer,
        /// End of the delay currently in progress, if any.
        deadline: Option<pal_async::timer::Instant>,
    }
2068
    impl PollPostMessage for TestServerClient {
        /// Forwards `msg` to the test server, randomly delaying first.
        ///
        /// Each attempt has a one-in-four chance of starting a 10ms delay
        /// instead of sending, so a message can be delayed more than once.
        /// While a delay is in progress this returns `Poll::Pending`. The
        /// connection ID and message type are ignored.
        fn poll_post_message(
            &mut self,
            cx: &mut Context<'_>,
            _connection_id: u32,
            _typ: u32,
            msg: &[u8],
        ) -> Poll<()> {
            loop {
                // Finish any delay in progress before trying to send.
                if let Some(deadline) = self.deadline {
                    ready!(self.timer.poll_until(cx, deadline));
                    self.deadline = None;
                }
                // Randomly choose whether to delay the message.
                //
                // FUTURE: use some kind of deterministic test framework for this to
                // allow for reproducible tests.
                let mut b = [0];
                getrandom::fill(&mut b).unwrap();
                if b[0] % 4 == 0 {
                    // Loop around to poll the timer, which registers the
                    // waker for the new deadline.
                    self.deadline =
                        Some(pal_async::timer::Instant::now() + Duration::from_millis(10));
                } else {
                    let msg = OutgoingMessage::from_message(msg).unwrap();
                    tracing::info!(
                        msg = ?MessageHeader::read_from_prefix(msg.data()),
                        "sending message"
                    );
                    self.sender.send(msg);
                    break Poll::Ready(());
                }
            }
        }
    }
2103
    /// Synic event client for tests that does not map or signal anything.
    struct NoopSynicEvents;
2105
    impl SynicEventClient for NoopSynicEvents {
        /// Accepts every mapping request without doing anything.
        fn map_event(&self, _event_flag: u16, _event: &Event) -> std::io::Result<()> {
            Ok(())
        }

        /// Does nothing.
        fn unmap_event(&self, _event_flag: u16) {}

        /// Always fails with `Unsupported`; there is no host to signal.
        fn signal_event(&self, _connection_id: u32, _event_flag: u16) -> std::io::Result<()> {
            Err(std::io::ErrorKind::Unsupported.into())
        }
    }
2117
    /// Incoming message source for the client under test, fed by the
    /// [`TestServer`].
    struct TestMessageSource {
        /// Raw message bytes sent by the test server.
        msg_recv: mesh::Receiver<Vec<u8>>,
        /// While set, a read with no message available completes with zero
        /// bytes instead of waiting.
        paused: bool,
    }
2122
2123    impl AsyncRecv for TestMessageSource {
2124        fn poll_recv(
2125            &mut self,
2126            cx: &mut Context<'_>,
2127            mut bufs: &mut [std::io::IoSliceMut<'_>],
2128        ) -> Poll<std::io::Result<usize>> {
2129            let value = match self.msg_recv.poll_recv(cx) {
2130                Poll::Ready(v) => v.unwrap(),
2131                Poll::Pending => {
2132                    if self.paused {
2133                        return Poll::Ready(Ok(0));
2134                    } else {
2135                        return Poll::Pending;
2136                    }
2137                }
2138            };
2139            let mut remaining = value.as_slice();
2140            let mut total_size = 0;
2141            while !remaining.is_empty() && !bufs.is_empty() {
2142                let size = bufs[0].len().min(remaining.len());
2143                bufs[0][..size].copy_from_slice(&remaining[..size]);
2144                remaining = &remaining[size..];
2145                bufs = &mut bufs[1..];
2146                total_size += size;
2147            }
2148
2149            Ok(total_size).into()
2150        }
2151    }
2152
    impl VmbusMessageSource for TestMessageSource {
        /// Marks the source as paused, so that a read with no message
        /// available completes with zero bytes rather than waiting.
        fn pause_message_stream(&mut self) {
            self.paused = true;
        }

        /// Clears the paused flag, so that reads wait for messages again.
        fn resume_message_stream(&mut self) {
            self.paused = false;
        }
    }
2162
2163    fn test_init(driver: &DefaultDriver) -> (TestServer, VmbusClient) {
2164        let (msg_send, msg_recv) = mesh::channel();
2165        let (synic_send, synic_recv) = mesh::channel();
2166        let server = TestServer {
2167            messages: synic_recv,
2168            send: msg_send,
2169        };
2170        let mut client = VmbusClientBuilder::new(
2171            NoopSynicEvents,
2172            TestMessageSource {
2173                msg_recv,
2174                paused: false,
2175            },
2176            TestServerClient {
2177                sender: synic_send,
2178                deadline: None,
2179                timer: PolledTimer::new(driver),
2180            },
2181        )
2182        .build(driver);
2183        client.start();
2184        (server, client)
2185    }
2186
2187    #[async_test]
2188    async fn test_initiate_contact_success(driver: DefaultDriver) {
2189        let (mut server, client) = test_init(&driver);
2190        let _recv = client
2191            .access
2192            .client_request_send
2193            .call(ClientRequest::Connect, ConnectRequest::default());
2194        check_message(
2195            server.next().await.unwrap(),
2196            protocol::InitiateContact2 {
2197                initiate_contact: protocol::InitiateContact {
2198                    version_requested: Version::Copper as u32,
2199                    target_message_vp: 0,
2200                    interrupt_page_or_target_info: TargetInfo::new()
2201                        .with_sint(2)
2202                        .with_vtl(0)
2203                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2204                        .into(),
2205                    parent_to_child_monitor_page_gpa: 0,
2206                    child_to_parent_monitor_page_gpa: 0,
2207                },
2208                ..FromZeros::new_zeroed()
2209            },
2210        );
2211    }
2212
    /// Full connect handshake against a server that accepts the first
    /// requested version and all requested feature flags.
    #[async_test]
    async fn test_connect_success(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            // The client asks for the newest version first.
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: Version::Copper as u32,
                        target_message_vp: 0,
                        interrupt_page_or_target_info: TargetInfo::new()
                            .with_sint(2)
                            .with_vtl(0)
                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                            .into(),
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    ..FromZeros::new_zeroed()
                },
            );

            // Accept the version with every supported feature.
            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse2 {
                    version_response: protocol::VersionResponse {
                        version_supported: 1,
                        connection_state: ConnectionState::SUCCESSFUL,
                        padding: 0,
                        selected_version_or_connection_id: 0,
                    },
                    supported_features: SUPPORTED_FEATURE_FLAGS.into_bits(),
                },
            ));

            // The client then requests offers; report that there are none.
            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        // Run both halves of the handshake concurrently.
        let (connection, ()) = (client_connect, server_connect).join().await;
        let connection = connection.unwrap();

        assert_eq!(connection.version.version, Version::Copper);
        assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
    }
2260
    /// When the server supports only a subset of the requested feature
    /// flags, the negotiated version reports just that subset.
    #[async_test]
    async fn test_feature_flags(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: Version::Copper as u32,
                        target_message_vp: 0,
                        interrupt_page_or_target_info: TargetInfo::new()
                            .with_sint(2)
                            .with_vtl(0)
                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                            .into(),
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    ..FromZeros::new_zeroed()
                },
            );

            // Report the server doesn't support some of the feature flags, and make
            // sure this is reflected in the returned version.
            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse2 {
                    version_response: protocol::VersionResponse {
                        version_supported: 1,
                        connection_state: ConnectionState::SUCCESSFUL,
                        padding: 0,
                        selected_version_or_connection_id: 0,
                    },
                    // Raw feature bits; the assertion below expects this to
                    // decode as channel interrupt redirection only.
                    supported_features: 2,
                },
            ));

            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        // Run both halves of the handshake concurrently.
        let (connection, ()) = (client_connect, server_connect).join().await;
        let connection = connection.unwrap();

        assert_eq!(connection.version.version, Version::Copper);
        assert_eq!(
            connection.version.feature_flags,
            FeatureFlags::new().with_channel_interrupt_redirection(true)
        );
    }
2313
2314    #[async_test]
2315    async fn test_client_id(driver: DefaultDriver) {
2316        let (mut server, client) = test_init(&driver);
2317        let initiate_contact = ConnectRequest {
2318            client_id: VMBUS_TEST_CLIENT_ID,
2319            ..Default::default()
2320        };
2321        let _recv = client
2322            .access
2323            .client_request_send
2324            .call(ClientRequest::Connect, initiate_contact);
2325
2326        check_message(
2327            server.next().await.unwrap(),
2328            protocol::InitiateContact2 {
2329                initiate_contact: protocol::InitiateContact {
2330                    version_requested: Version::Copper as u32,
2331                    target_message_vp: 0,
2332                    interrupt_page_or_target_info: TargetInfo::new()
2333                        .with_sint(2)
2334                        .with_vtl(0)
2335                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2336                        .into(),
2337                    parent_to_child_monitor_page_gpa: 0,
2338                    child_to_parent_monitor_page_gpa: 0,
2339                },
2340                client_id: VMBUS_TEST_CLIENT_ID,
2341            },
2342        );
2343    }
2344
    /// When the server rejects the first requested version, the client
    /// retries with an older version and connects with that one.
    #[async_test]
    async fn test_version_negotiation(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            // First attempt: Copper with all supported feature flags.
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: Version::Copper as u32,
                        target_message_vp: 0,
                        interrupt_page_or_target_info: TargetInfo::new()
                            .with_sint(2)
                            .with_vtl(0)
                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                            .into(),
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    ..FromZeros::new_zeroed()
                },
            );

            // Reject it.
            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse {
                    version_supported: 0,
                    connection_state: ConnectionState::SUCCESSFUL,
                    padding: 0,
                    selected_version_or_connection_id: 0,
                },
            ));

            // Second attempt: Iron, sent as the shorter `InitiateContact`
            // message and with no feature flags.
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact {
                    version_requested: Version::Iron as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(FeatureFlags::new().into())
                        .into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
            );

            // Accept it.
            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse {
                    version_supported: 1,
                    connection_state: ConnectionState::SUCCESSFUL,
                    padding: 0,
                    selected_version_or_connection_id: 0,
                },
            ));

            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        // Run both halves of the handshake concurrently.
        let (connection, ()) = (client_connect, server_connect).join().await;
        let connection = connection.unwrap();

        assert_eq!(connection.version.version, Version::Iron);
        assert_eq!(connection.version.feature_flags, FeatureFlags::new());
    }
2414
2415    #[async_test]
2416    async fn test_open_channel_success(driver: DefaultDriver) {
2417        let (mut server, mut client) = test_init(&driver);
2418        let channel = server.get_channel(&mut client).await;
2419
2420        let recv = channel.request_send.call(
2421            ChannelRequest::Open,
2422            OpenRequest {
2423                open_data: OpenData {
2424                    target_vp: 0,
2425                    ring_offset: 0,
2426                    ring_gpadl_id: GpadlId(0),
2427                    event_flag: 0,
2428                    connection_id: 0,
2429                    user_data: UserDefinedData::new_zeroed(),
2430                },
2431                incoming_event: None,
2432                use_vtl2_connection_id: false,
2433            },
2434        );
2435
2436        check_message(
2437            server.next().await.unwrap(),
2438            protocol::OpenChannel2 {
2439                open_channel: protocol::OpenChannel {
2440                    channel_id: ChannelId(0),
2441                    open_id: 0,
2442                    ring_buffer_gpadl_id: GpadlId(0),
2443                    target_vp: 0,
2444                    downstream_ring_buffer_page_offset: 0,
2445                    user_data: UserDefinedData::new_zeroed(),
2446                },
2447                connection_id: 0,
2448                event_flag: 0,
2449                flags: Default::default(),
2450            },
2451        );
2452
2453        server.send(in_msg(
2454            MessageType::OPEN_CHANNEL_RESULT,
2455            protocol::OpenResult {
2456                channel_id: ChannelId(0),
2457                open_id: 0,
2458                status: protocol::STATUS_SUCCESS as u32,
2459            },
2460        ));
2461
2462        recv.await.unwrap().unwrap();
2463    }
2464
2465    #[async_test]
2466    async fn test_open_channel_fail(driver: DefaultDriver) {
2467        let (mut server, mut client) = test_init(&driver);
2468        let channel = server.get_channel(&mut client).await;
2469
2470        let recv = channel.request_send.call(
2471            ChannelRequest::Open,
2472            OpenRequest {
2473                open_data: OpenData {
2474                    target_vp: 0,
2475                    ring_offset: 0,
2476                    ring_gpadl_id: GpadlId(0),
2477                    event_flag: 0,
2478                    connection_id: 0,
2479                    user_data: UserDefinedData::new_zeroed(),
2480                },
2481                incoming_event: None,
2482                use_vtl2_connection_id: false,
2483            },
2484        );
2485
2486        check_message(
2487            server.next().await.unwrap(),
2488            protocol::OpenChannel2 {
2489                open_channel: protocol::OpenChannel {
2490                    channel_id: ChannelId(0),
2491                    open_id: 0,
2492                    ring_buffer_gpadl_id: GpadlId(0),
2493                    target_vp: 0,
2494                    downstream_ring_buffer_page_offset: 0,
2495                    user_data: UserDefinedData::new_zeroed(),
2496                },
2497                connection_id: 0,
2498                event_flag: 0,
2499                flags: Default::default(),
2500            },
2501        );
2502
2503        server.send(in_msg(
2504            MessageType::OPEN_CHANNEL_RESULT,
2505            protocol::OpenResult {
2506                channel_id: ChannelId(0),
2507                open_id: 0,
2508                status: protocol::STATUS_UNSUCCESSFUL as u32,
2509            },
2510        ));
2511
2512        recv.await.unwrap().unwrap_err();
2513    }
2514
2515    #[async_test]
2516    async fn test_modify_channel(driver: DefaultDriver) {
2517        let (mut server, mut client) = test_init(&driver);
2518        let channel = server.get_channel(&mut client).await;
2519
2520        // N.B. A real server requires the channel to be open before sending this, but the test
2521        //      server doesn't care.
2522        let recv = channel.request_send.call(
2523            ChannelRequest::Modify,
2524            ModifyRequest::TargetVp { target_vp: 1 },
2525        );
2526
2527        check_message(
2528            server.next().await.unwrap(),
2529            protocol::ModifyChannel {
2530                channel_id: ChannelId(0),
2531                target_vp: 1,
2532            },
2533        );
2534
2535        server.send(in_msg(
2536            MessageType::MODIFY_CHANNEL_RESPONSE,
2537            protocol::ModifyChannelResponse {
2538                channel_id: ChannelId(0),
2539                status: protocol::STATUS_SUCCESS,
2540            },
2541        ));
2542
2543        let status = recv.await.unwrap();
2544        assert_eq!(status, protocol::STATUS_SUCCESS);
2545    }
2546
2547    #[async_test]
2548    async fn test_save_restore_connected(driver: DefaultDriver) {
2549        let (mut server, mut client) = test_init(&driver);
2550        server.connect(&mut client).await;
2551        server.stop_client(&mut client).await;
2552        let s0 = client.save().await;
2553        let builder = client.sever().await;
2554        let mut client = builder.build(&driver);
2555        client.restore(s0.clone()).await.unwrap();
2556
2557        let s1 = client.save().await;
2558
2559        assert_eq!(s0, s1);
2560    }
2561
2562    #[async_test]
2563    async fn test_save_restore_connected_with_channel(driver: DefaultDriver) {
2564        let (mut server, mut client) = test_init(&driver);
2565        let c0 = server.get_channel(&mut client).await;
2566        server.stop_client(&mut client).await;
2567        let s0 = client.save().await;
2568        let builder = client.sever().await;
2569        let mut client = builder.build(&driver);
2570        let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2571        let s1 = client.save().await;
2572        assert_eq!(s0, s1);
2573        assert_eq!(connection.offers[0].offer, c0.offer);
2574    }
2575
2576    #[async_test]
2577    async fn test_save_restore_connected_with_revoked_channel(driver: DefaultDriver) {
2578        let (mut server, mut client) = test_init(&driver);
2579        let c0 = server.get_channel(&mut client).await;
2580        server.send(in_msg(
2581            MessageType::RESCIND_CHANNEL_OFFER,
2582            protocol::RescindChannelOffer {
2583                channel_id: ChannelId(0),
2584            },
2585        ));
2586        c0.revoke_recv.await.unwrap();
2587        let rpc = c0.request_send.call(
2588            ChannelRequest::Modify,
2589            ModifyRequest::TargetVp { target_vp: 1 },
2590        );
2591
2592        check_message(
2593            server.next().await.unwrap(),
2594            protocol::ModifyChannel {
2595                channel_id: ChannelId(0),
2596                target_vp: 1,
2597            },
2598        );
2599
2600        let client_stop = client.stop();
2601        let server_stop = async {
2602            server.send(in_msg(
2603                MessageType::MODIFY_CHANNEL_RESPONSE,
2604                protocol::ModifyChannelResponse {
2605                    channel_id: ChannelId(0),
2606                    status: protocol::STATUS_SUCCESS,
2607                },
2608            ));
2609            check_message(server.next().await.unwrap(), protocol::Pause);
2610            server.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
2611        };
2612        (client_stop, server_stop).join().await;
2613
2614        rpc.await.unwrap();
2615
2616        let s0 = client.save().await;
2617        let builder = client.sever().await;
2618        let mut client = builder.build(&driver);
2619        let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2620        let s1 = client.save().await;
2621        assert_eq!(s0, s1);
2622        assert!(connection.offers.is_empty());
2623        server.start_client(&mut client).await;
2624        check_message(
2625            server.next().await.unwrap(),
2626            protocol::RelIdReleased {
2627                channel_id: ChannelId(0),
2628            },
2629        );
2630    }
2631
2632    #[async_test]
2633    async fn test_connect_fails_on_incorrect_state(driver: DefaultDriver) {
2634        let (mut server, mut client) = test_init(&driver);
2635        server.connect(&mut client).await;
2636        let err = client.connect(0, None, Guid::ZERO).await.unwrap_err();
2637        assert!(matches!(err, ConnectError::InvalidState), "{:?}", err);
2638    }
2639
2640    #[async_test]
2641    async fn test_hot_add_remove(driver: DefaultDriver) {
2642        let (mut server, mut client) = test_init(&driver);
2643
2644        let mut connection = server.connect(&mut client).await;
2645        let offer = protocol::OfferChannel {
2646            interface_id: Guid::new_random(),
2647            instance_id: Guid::new_random(),
2648            rsvd: [0; 4],
2649            flags: OfferFlags::new(),
2650            mmio_megabytes: 0,
2651            user_defined: UserDefinedData::new_zeroed(),
2652            subchannel_index: 0,
2653            mmio_megabytes_optional: 0,
2654            channel_id: ChannelId(5),
2655            monitor_id: 0,
2656            monitor_allocated: 0,
2657            is_dedicated: 0,
2658            connection_id: 0,
2659        };
2660
2661        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2662        let info = connection.offer_recv.next().await.unwrap();
2663
2664        assert_eq!(offer, info.offer);
2665
2666        server.send(in_msg(
2667            MessageType::RESCIND_CHANNEL_OFFER,
2668            protocol::RescindChannelOffer {
2669                channel_id: ChannelId(5),
2670            },
2671        ));
2672
2673        info.revoke_recv.await.unwrap();
2674        drop(info.request_send);
2675
2676        check_message(
2677            server.next().await.unwrap(),
2678            protocol::RelIdReleased {
2679                channel_id: ChannelId(5),
2680            },
2681        );
2682    }
2683
2684    #[async_test]
2685    async fn test_gpadl_success(driver: DefaultDriver) {
2686        let (mut server, mut client) = test_init(&driver);
2687        let channel = server.get_channel(&mut client).await;
2688        let recv = channel.request_send.call(
2689            ChannelRequest::Gpadl,
2690            GpadlRequest {
2691                id: GpadlId(1),
2692                count: 1,
2693                buf: vec![5],
2694            },
2695        );
2696
2697        check_message_with_data(
2698            server.next().await.unwrap(),
2699            protocol::GpadlHeader {
2700                channel_id: ChannelId(0),
2701                gpadl_id: GpadlId(1),
2702                len: 8,
2703                count: 1,
2704            },
2705            0x5u64.as_bytes(),
2706        );
2707
2708        server.send(in_msg(
2709            MessageType::GPADL_CREATED,
2710            protocol::GpadlCreated {
2711                channel_id: ChannelId(0),
2712                gpadl_id: GpadlId(1),
2713                status: protocol::STATUS_SUCCESS,
2714            },
2715        ));
2716
2717        recv.await.unwrap().unwrap();
2718
2719        let rpc = channel
2720            .request_send
2721            .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2722
2723        check_message(
2724            server.next().await.unwrap(),
2725            protocol::GpadlTeardown {
2726                channel_id: ChannelId(0),
2727                gpadl_id: GpadlId(1),
2728            },
2729        );
2730
2731        server.send(in_msg(
2732            MessageType::GPADL_TORNDOWN,
2733            protocol::GpadlTorndown {
2734                gpadl_id: GpadlId(1),
2735            },
2736        ));
2737
2738        rpc.await.unwrap();
2739    }
2740
2741    #[async_test]
2742    async fn test_gpadl_fail(driver: DefaultDriver) {
2743        let (mut server, mut client) = test_init(&driver);
2744        let channel = server.get_channel(&mut client).await;
2745        let recv = channel.request_send.call(
2746            ChannelRequest::Gpadl,
2747            GpadlRequest {
2748                id: GpadlId(1),
2749                count: 1,
2750                buf: vec![7],
2751            },
2752        );
2753
2754        check_message_with_data(
2755            server.next().await.unwrap(),
2756            protocol::GpadlHeader {
2757                channel_id: ChannelId(0),
2758                gpadl_id: GpadlId(1),
2759                len: 8,
2760                count: 1,
2761            },
2762            0x7u64.as_bytes(),
2763        );
2764
2765        server.send(in_msg(
2766            MessageType::GPADL_CREATED,
2767            protocol::GpadlCreated {
2768                channel_id: ChannelId(0),
2769                gpadl_id: GpadlId(1),
2770                status: protocol::STATUS_UNSUCCESSFUL,
2771            },
2772        ));
2773
2774        recv.await.unwrap().unwrap_err();
2775    }
2776
2777    #[async_test]
2778    async fn test_gpadl_with_revoke(driver: DefaultDriver) {
2779        let (mut server, mut client) = test_init(&driver);
2780        let channel = server.get_channel(&mut client).await;
2781        let channel_id = ChannelId(0);
2782        for gpadl_id in [1, 2, 3].map(GpadlId) {
2783            let recv = channel.request_send.call(
2784                ChannelRequest::Gpadl,
2785                GpadlRequest {
2786                    id: gpadl_id,
2787                    count: 1,
2788                    buf: vec![3],
2789                },
2790            );
2791
2792            check_message_with_data(
2793                server.next().await.unwrap(),
2794                protocol::GpadlHeader {
2795                    channel_id,
2796                    gpadl_id,
2797                    len: 8,
2798                    count: 1,
2799                },
2800                0x3u64.as_bytes(),
2801            );
2802
2803            server.send(in_msg(
2804                MessageType::GPADL_CREATED,
2805                protocol::GpadlCreated {
2806                    channel_id,
2807                    gpadl_id,
2808                    status: protocol::STATUS_SUCCESS,
2809                },
2810            ));
2811
2812            recv.await.unwrap().unwrap();
2813        }
2814
2815        let rpc = channel
2816            .request_send
2817            .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2818
2819        check_message(
2820            server.next().await.unwrap(),
2821            protocol::GpadlTeardown {
2822                channel_id,
2823                gpadl_id: GpadlId(1),
2824            },
2825        );
2826
2827        server.send(in_msg(
2828            MessageType::RESCIND_CHANNEL_OFFER,
2829            protocol::RescindChannelOffer { channel_id },
2830        ));
2831
2832        let recv = channel.request_send.call_failable(
2833            ChannelRequest::Gpadl,
2834            GpadlRequest {
2835                id: GpadlId(4),
2836                count: 1,
2837                buf: vec![3],
2838            },
2839        );
2840
2841        check_message_with_data(
2842            server.next().await.unwrap(),
2843            protocol::GpadlHeader {
2844                channel_id,
2845                gpadl_id: GpadlId(4),
2846                len: 8,
2847                count: 1,
2848            },
2849            0x3u64.as_bytes(),
2850        );
2851
2852        server.send(in_msg(
2853            MessageType::GPADL_CREATED,
2854            protocol::GpadlCreated {
2855                channel_id,
2856                gpadl_id: GpadlId(4),
2857                status: protocol::STATUS_UNSUCCESSFUL,
2858            },
2859        ));
2860
2861        server.send(in_msg(
2862            MessageType::GPADL_TORNDOWN,
2863            protocol::GpadlTorndown {
2864                gpadl_id: GpadlId(1),
2865            },
2866        ));
2867
2868        rpc.await.unwrap();
2869        recv.await.unwrap_err();
2870
2871        channel.revoke_recv.await.unwrap();
2872
2873        let rpc = channel
2874            .request_send
2875            .call(ChannelRequest::TeardownGpadl, GpadlId(2));
2876        drop(channel.request_send);
2877
2878        check_message(
2879            server.next().await.unwrap(),
2880            protocol::GpadlTeardown {
2881                channel_id,
2882                gpadl_id: GpadlId(2),
2883            },
2884        );
2885
2886        server.send(in_msg(
2887            MessageType::GPADL_TORNDOWN,
2888            protocol::GpadlTorndown {
2889                gpadl_id: GpadlId(2),
2890            },
2891        ));
2892
2893        rpc.await.unwrap();
2894
2895        check_message(
2896            server.next().await.unwrap(),
2897            protocol::RelIdReleased { channel_id },
2898        );
2899    }
2900
2901    #[async_test]
2902    async fn test_modify_connection(driver: DefaultDriver) {
2903        let (mut server, mut client) = test_init(&driver);
2904        server.connect(&mut client).await;
2905        let call = client.access.client_request_send.call(
2906            ClientRequest::Modify,
2907            ModifyConnectionRequest {
2908                monitor_page: Some(MonitorPageGpas {
2909                    child_to_parent: 5,
2910                    parent_to_child: 6,
2911                }),
2912            },
2913        );
2914
2915        check_message(
2916            server.next().await.unwrap(),
2917            protocol::ModifyConnection {
2918                child_to_parent_monitor_page_gpa: 5,
2919                parent_to_child_monitor_page_gpa: 6,
2920            },
2921        );
2922
2923        server.send(in_msg(
2924            MessageType::MODIFY_CONNECTION_RESPONSE,
2925            protocol::ModifyConnectionResponse {
2926                connection_state: ConnectionState::FAILED_LOW_RESOURCES,
2927            },
2928        ));
2929
2930        let result = call.await.unwrap();
2931        assert_eq!(ConnectionState::FAILED_LOW_RESOURCES, result);
2932    }
2933
2934    #[async_test]
2935    async fn test_hvsock(driver: DefaultDriver) {
2936        let (mut server, mut client) = test_init(&driver);
2937        server.connect(&mut client).await;
2938        let request = HvsockConnectRequest {
2939            service_id: Guid::new_random(),
2940            endpoint_id: Guid::new_random(),
2941            silo_id: Guid::new_random(),
2942            hosted_silo_unaware: false,
2943        };
2944
2945        let resp = client.access().connect_hvsock(request);
2946        check_message(
2947            server.next().await.unwrap(),
2948            protocol::TlConnectRequest2 {
2949                base: protocol::TlConnectRequest {
2950                    service_id: request.service_id,
2951                    endpoint_id: request.endpoint_id,
2952                },
2953                silo_id: request.silo_id,
2954            },
2955        );
2956
2957        // Now send a failure result.
2958        server.send(in_msg(
2959            MessageType::TL_CONNECT_REQUEST_RESULT,
2960            protocol::TlConnectResult {
2961                service_id: request.service_id,
2962                endpoint_id: request.endpoint_id,
2963                status: protocol::STATUS_CONNECTION_REFUSED,
2964            },
2965        ));
2966
2967        let result = resp.await;
2968        assert!(result.is_none());
2969    }
2970
2971    #[async_test]
2972    async fn test_synic_event_flags(driver: DefaultDriver) {
2973        let (mut server, mut client) = test_init(&driver);
2974        let connection = server.get_channels(&mut client, 5).await;
2975        let event = Event::new();
2976
2977        for _ in 0..5 {
2978            for (i, channel) in connection.offers.iter().enumerate() {
2979                let recv = channel.request_send.call(
2980                    ChannelRequest::Open,
2981                    OpenRequest {
2982                        open_data: OpenData {
2983                            target_vp: 0,
2984                            ring_offset: 0,
2985                            ring_gpadl_id: GpadlId(0),
2986                            event_flag: 0,
2987                            connection_id: 0,
2988                            user_data: UserDefinedData::new_zeroed(),
2989                        },
2990                        incoming_event: Some(event.clone()),
2991                        use_vtl2_connection_id: false,
2992                    },
2993                );
2994
2995                let expected_event_flag = i as u16 + 1;
2996
2997                check_message(
2998                    server.next().await.unwrap(),
2999                    protocol::OpenChannel2 {
3000                        open_channel: protocol::OpenChannel {
3001                            channel_id: channel.offer.channel_id,
3002                            open_id: 0,
3003                            ring_buffer_gpadl_id: GpadlId(0),
3004                            target_vp: 0,
3005                            downstream_ring_buffer_page_offset: 0,
3006                            user_data: UserDefinedData::new_zeroed(),
3007                        },
3008                        connection_id: 0,
3009                        event_flag: expected_event_flag,
3010                        flags: OpenChannelFlags::new().with_redirect_interrupt(true),
3011                    },
3012                );
3013
3014                server.send(in_msg(
3015                    MessageType::OPEN_CHANNEL_RESULT,
3016                    protocol::OpenResult {
3017                        channel_id: channel.offer.channel_id,
3018                        open_id: 0,
3019                        status: protocol::STATUS_SUCCESS as u32,
3020                    },
3021                ));
3022
3023                let output = recv.await.unwrap().unwrap();
3024                assert_eq!(output.redirected_event_flag, Some(expected_event_flag));
3025            }
3026
3027            for (i, channel) in connection.offers.iter().enumerate() {
3028                // Close the channel to prepare for the next iteration of the loop.
3029                // The event flag should be the same each time.
3030                channel
3031                    .request_send
3032                    .call(ChannelRequest::Close, ())
3033                    .await
3034                    .unwrap();
3035
3036                check_message(
3037                    server.next().await.unwrap(),
3038                    protocol::CloseChannel {
3039                        channel_id: ChannelId(i as u32),
3040                    },
3041                );
3042            }
3043        }
3044    }
3045
3046    #[async_test]
3047    async fn test_revoke(driver: DefaultDriver) {
3048        let (mut server, mut client) = test_init(&driver);
3049        let channel = server.get_channel(&mut client).await;
3050
3051        server.send(in_msg(
3052            MessageType::RESCIND_CHANNEL_OFFER,
3053            protocol::RescindChannelOffer {
3054                channel_id: ChannelId(0),
3055            },
3056        ));
3057
3058        channel.revoke_recv.await.unwrap();
3059
3060        channel
3061            .request_send
3062            .call_failable(
3063                ChannelRequest::Open,
3064                OpenRequest {
3065                    open_data: OpenData {
3066                        target_vp: 0,
3067                        ring_offset: 0,
3068                        ring_gpadl_id: GpadlId(0),
3069                        event_flag: 0,
3070                        connection_id: 0,
3071                        user_data: UserDefinedData::new_zeroed(),
3072                    },
3073                    incoming_event: None,
3074                    use_vtl2_connection_id: false,
3075                },
3076            )
3077            .await
3078            .unwrap_err();
3079    }
3080
3081    #[async_test]
3082    #[should_panic(expected = "channel should not exist")]
3083    async fn test_reoffer_in_use_rel_id(driver: DefaultDriver) {
3084        let (mut server, mut client) = test_init(&driver);
3085        let mut connection = server.get_channels(&mut client, 1).await;
3086        let [channel] = connection.offers.try_into().unwrap();
3087
3088        server.send(in_msg(
3089            MessageType::RESCIND_CHANNEL_OFFER,
3090            protocol::RescindChannelOffer {
3091                channel_id: ChannelId(0),
3092            },
3093        ));
3094
3095        channel.revoke_recv.await.unwrap();
3096
3097        // This offer will cause a panic since the rel id is still in use.
3098        let offer = protocol::OfferChannel {
3099            interface_id: Guid::new_random(),
3100            instance_id: Guid::new_random(),
3101            rsvd: [0; 4],
3102            flags: OfferFlags::new(),
3103            mmio_megabytes: 0,
3104            user_defined: UserDefinedData::new_zeroed(),
3105            subchannel_index: 0,
3106            mmio_megabytes_optional: 0,
3107            channel_id: ChannelId(0),
3108            monitor_id: 0,
3109            monitor_allocated: 0,
3110            is_dedicated: 0,
3111            connection_id: 0,
3112        };
3113
3114        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3115
3116        connection.offer_recv.next().await;
3117    }
3118
3119    #[async_test]
3120    async fn test_revoke_release_and_reoffer(driver: DefaultDriver) {
3121        let (mut server, mut client) = test_init(&driver);
3122        let mut connection = server.get_channels(&mut client, 1).await;
3123        let [channel] = connection.offers.try_into().unwrap();
3124
3125        server.send(in_msg(
3126            MessageType::RESCIND_CHANNEL_OFFER,
3127            protocol::RescindChannelOffer {
3128                channel_id: ChannelId(0),
3129            },
3130        ));
3131
3132        channel.revoke_recv.await.unwrap();
3133        drop(channel.request_send);
3134
3135        check_message(
3136            server.next().await.unwrap(),
3137            protocol::RelIdReleased {
3138                channel_id: ChannelId(0),
3139            },
3140        );
3141
3142        let offer = protocol::OfferChannel {
3143            interface_id: Guid::new_random(),
3144            instance_id: Guid::new_random(),
3145            rsvd: [0; 4],
3146            flags: OfferFlags::new(),
3147            mmio_megabytes: 0,
3148            user_defined: UserDefinedData::new_zeroed(),
3149            subchannel_index: 0,
3150            mmio_megabytes_optional: 0,
3151            channel_id: ChannelId(0),
3152            monitor_id: 0,
3153            monitor_allocated: 0,
3154            is_dedicated: 0,
3155            connection_id: 0,
3156        };
3157
3158        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3159
3160        connection.offer_recv.next().await.unwrap();
3161    }
3162
3163    #[async_test]
3164    async fn test_release_revoke_and_reoffer(driver: DefaultDriver) {
3165        let (mut server, mut client) = test_init(&driver);
3166        let mut connection = server.get_channels(&mut client, 1).await;
3167        let [channel] = connection.offers.try_into().unwrap();
3168
3169        let open = channel.request_send.call_failable(
3170            ChannelRequest::Open,
3171            OpenRequest {
3172                open_data: OpenData {
3173                    target_vp: 0,
3174                    ring_offset: 0,
3175                    ring_gpadl_id: GpadlId(0),
3176                    event_flag: 0,
3177                    connection_id: 0,
3178                    user_data: UserDefinedData::new_zeroed(),
3179                },
3180                incoming_event: None,
3181                use_vtl2_connection_id: false,
3182            },
3183        );
3184
3185        let server_open = async {
3186            check_message(
3187                server.next().await.unwrap(),
3188                protocol::OpenChannel2 {
3189                    open_channel: protocol::OpenChannel {
3190                        channel_id: ChannelId(0),
3191                        open_id: 0,
3192                        ring_buffer_gpadl_id: GpadlId(0),
3193                        target_vp: 0,
3194                        downstream_ring_buffer_page_offset: 0,
3195                        user_data: UserDefinedData::new_zeroed(),
3196                    },
3197                    connection_id: 0,
3198                    event_flag: 0,
3199                    flags: Default::default(),
3200                },
3201            );
3202            server.send(in_msg(
3203                MessageType::OPEN_CHANNEL_RESULT,
3204                protocol::OpenResult {
3205                    channel_id: ChannelId(0),
3206                    open_id: 0,
3207                    status: protocol::STATUS_SUCCESS as u32,
3208                },
3209            ));
3210        };
3211
3212        (open, server_open).join().await.0.unwrap();
3213
3214        // This will close the channel but won't release it yet.
3215        drop(channel);
3216
3217        check_message(
3218            server.next().await.unwrap(),
3219            protocol::CloseChannel {
3220                channel_id: ChannelId(0),
3221            },
3222        );
3223
3224        server.send(in_msg(
3225            MessageType::RESCIND_CHANNEL_OFFER,
3226            protocol::RescindChannelOffer {
3227                channel_id: ChannelId(0),
3228            },
3229        ));
3230
3231        // Should be released.
3232        check_message(
3233            server.next().await.unwrap(),
3234            protocol::RelIdReleased {
3235                channel_id: ChannelId(0),
3236            },
3237        );
3238
3239        let offer = protocol::OfferChannel {
3240            interface_id: Guid::new_random(),
3241            instance_id: Guid::new_random(),
3242            rsvd: [0; 4],
3243            flags: OfferFlags::new(),
3244            mmio_megabytes: 0,
3245            user_defined: UserDefinedData::new_zeroed(),
3246            subchannel_index: 0,
3247            mmio_megabytes_optional: 0,
3248            channel_id: ChannelId(0),
3249            monitor_id: 0,
3250            monitor_allocated: 0,
3251            is_dedicated: 0,
3252            connection_id: 0,
3253        };
3254
3255        server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3256
3257        // New offer should come through.
3258        connection.offer_recv.next().await.unwrap();
3259    }
3260}