vmbus_server/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8pub mod channels;
9pub mod event;
10pub mod hvsock;
11mod monitor;
12mod proxyintegration;
13
/// The GUID type used for vmbus channel identifiers.
///
/// Public alias for [`guid::Guid`], so users of this crate can name the type through
/// `vmbus_server::Guid`.
pub type Guid = guid::Guid;
16
17use anyhow::Context;
18use async_trait::async_trait;
19use channel_bitmap::ChannelBitmap;
20use channels::ConnectionTarget;
21pub use channels::InitiateContactRequest;
22use channels::MessageTarget;
23pub use channels::MnfUsage;
24use channels::ModifyConnectionRequest;
25pub use channels::ModifyConnectionResponse;
26use channels::Notifier;
27use channels::OfferId;
28pub use channels::OfferParamsInternal;
29use channels::OpenParams;
30use channels::RestoreError;
31pub use channels::Update;
32use futures::FutureExt;
33use futures::StreamExt;
34use futures::channel::mpsc;
35use futures::channel::mpsc::SendError;
36use futures::future::OptionFuture;
37use futures::future::poll_fn;
38use futures::stream::SelectAll;
39use guestmem::GuestMemory;
40use hvdef::Vtl;
41use inspect::Inspect;
42use mesh::payload::Protobuf;
43use mesh::rpc::FailableRpc;
44use mesh::rpc::Rpc;
45use mesh::rpc::RpcError;
46use mesh::rpc::RpcSend;
47use pal_async::task::Spawn;
48use pal_async::task::Task;
49use pal_event::Event;
50#[cfg(windows)]
51pub use proxyintegration::ProxyIntegration;
52#[cfg(windows)]
53pub use proxyintegration::ProxyServerInfo;
54use ring::PAGE_SIZE;
55use std::collections::HashMap;
56use std::future;
57use std::future::Future;
58use std::pin::Pin;
59use std::sync::Arc;
60use std::task::Poll;
61use std::task::ready;
62use std::time::Duration;
63use unicycle::FuturesUnordered;
64use vmbus_channel::bus::ChannelRequest;
65use vmbus_channel::bus::ChannelServerRequest;
66use vmbus_channel::bus::GpadlRequest;
67use vmbus_channel::bus::ModifyRequest;
68use vmbus_channel::bus::OfferInput;
69use vmbus_channel::bus::OfferKey;
70use vmbus_channel::bus::OfferResources;
71use vmbus_channel::bus::OpenData;
72use vmbus_channel::bus::OpenRequest;
73use vmbus_channel::bus::OpenResult;
74use vmbus_channel::bus::ParentBus;
75use vmbus_channel::bus::RestoreResult;
76use vmbus_channel::gpadl::GpadlMap;
77use vmbus_channel::gpadl_ring::AlignedGpadlView;
78use vmbus_channel::gpadl_ring::GpadlRingMem;
79use vmbus_core::HvsockConnectRequest;
80use vmbus_core::HvsockConnectResult;
81use vmbus_core::MaxVersionInfo;
82use vmbus_core::OutgoingMessage;
83use vmbus_core::TaggedStream;
84use vmbus_core::VersionInfo;
85use vmbus_core::protocol;
86pub use vmbus_core::protocol::GpadlId;
87#[cfg(windows)]
88use vmbus_proxy::ProxyHandle;
89use vmbus_ring as ring;
90use vmbus_ring::gparange::MultiPagedRangeBuf;
91use vmcore::interrupt::Interrupt;
92use vmcore::save_restore::SavedStateRoot;
93use vmcore::synic::EventPort;
94use vmcore::synic::GuestEventPort;
95use vmcore::synic::GuestMessagePort;
96use vmcore::synic::MessagePort;
97use vmcore::synic::MonitorPageGpas;
98use vmcore::synic::SynicPortAccess;
99
/// Synic SINT used for vmbus messages when the control plane is not redirected.
const SINT: u8 = 2;
/// Synic SINT used for vmbus messages when the control plane is redirected to Underhill.
pub const REDIRECT_SINT: u8 = 7;
/// VTL that vmbus messages target when the control plane is redirected to Underhill.
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
/// Connection ID for the shared event port.
/// NOTE(review): not referenced in the visible code; presumably backs
/// `ServerTaskInner::shared_event_port` — confirm.
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
/// Base port ID for per-channel event ports; combined with the VTL, channel ID and SINT by
/// `VmbusServer::get_child_event_port_id`.
const EVENT_PORT_ID: u32 = 2;
/// Synic message type used for vmbus messages.
/// NOTE(review): referenced outside the visible code — confirm usage.
const VMBUS_MESSAGE_TYPE: u32 = 1;

/// Upper bound on concurrently outstanding hvsocket connect requests.
/// NOTE(review): presumably compared against `ServerTaskInner::hvsock_requests` — confirm.
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
108
/// A vmbus server instance, created by [`VmbusServerBuilder::build`].
///
/// Holds the synic message port registration(s) and the background task that processes vmbus
/// requests. Control operations are forwarded to that task over `task_send`.
pub struct VmbusServer {
    /// Sends control requests (save, restore, start, stop, reset, inspect) to the server task.
    task_send: mesh::Sender<VmbusRequest>,
    /// Handle used to offer channels; cloned out by `VmbusServer::control`.
    control: Arc<VmbusServerControl>,
    /// Primary message port registration; never read, presumably held to keep it registered.
    _message_port: Box<dyn Sync + Send>,
    /// Multiclient message port registration, present only when this server created one.
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    /// The spawned server task; awaited by `VmbusServer::shutdown`.
    task: Task<ServerTask>,
}
116
/// Builder for [`VmbusServer`].
///
/// Create one with [`VmbusServer::builder`] or [`VmbusServerBuilder::new`], adjust it with the
/// setter methods, then call [`VmbusServerBuilder::build`].
pub struct VmbusServerBuilder<'a, T: Spawn> {
    /// Spawner used to run the server task.
    spawner: &'a T,
    /// Synic used to register message ports and create the guest message port.
    synic: Arc<dyn SynicPortAccess>,
    /// Guest memory.
    gm: GuestMemory,
    /// Optional separate guest memory for confidential channels.
    private_gm: Option<GuestMemory>,
    /// VTL served by this instance (VTL0 unless overridden).
    vtl: Vtl,
    /// Optional relay for hvsocket connect requests.
    hvsock_notify: Option<HvsockServerChannelHalf>,
    /// Optional relay notified of connection modification requests.
    server_relay: Option<VmbusServerChannelHalf>,
    /// Optional sender used to update the proxy's saved state cache.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Optional sender that forwards unhandled connect requests to another server.
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    /// Optional receiver for connect requests forwarded from another server.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Whether the vmbus control plane is redirected to Underhill.
    use_message_redirect: bool,
    /// Offset added when generating channel IDs (0 or 1024).
    channel_id_offset: u16,
    /// Optional cap on the protocol version offered to the guest.
    max_version: Option<MaxVersionInfo>,
    /// Whether `max_version` only takes effect after the first `Unload` message.
    delay_max_version: bool,
    /// Whether MnF support is enabled.
    enable_mnf: bool,
    /// Whether non-relay channels are forced to use encrypted external memory (testing only).
    force_confidential_external_memory: bool,
    /// Whether messages are sent to the partition even while stopped.
    send_messages_while_stopped: bool,
}
136
#[derive(mesh::MeshPayload)]
/// The request to send to the proxy to set or clear its saved state cache.
pub enum SavedStateRequest {
    /// Replaces the proxy's saved state cache with the supplied channel server state.
    Set(FailableRpc<Box<channels::SavedState>, ()>),
    /// Clears the proxy's saved state cache.
    Clear(Rpc<(), ()>),
}
143
/// The server side of the connection between a vmbus server and a relay.
pub struct ServerChannelHalf<Request, Response> {
    /// Sends requests to the relay.
    request_send: mesh::Sender<Request>,
    /// Receives the relay's responses.
    response_receive: mesh::Receiver<Response>,
}
149
/// The relay side of a connection between a vmbus server and a relay.
pub struct RelayChannelHalf<Request, Response> {
    /// Receives requests sent by the server.
    pub request_receive: mesh::Receiver<Request>,
    /// Sends responses back to the server.
    pub response_send: mesh::Sender<Response>,
}
155
/// A connection between a vmbus server and a relay.
pub struct RelayChannel<Request, Response> {
    /// The half to hand to the relay.
    pub relay_half: RelayChannelHalf<Request, Response>,
    /// The half to hand to the vmbus server (for example through `VmbusServerBuilder`).
    pub server_half: ServerChannelHalf<Request, Response>,
}
161
162impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
163    /// Creates a new channel between the vmbus server and a relay.
164    pub fn new() -> Self {
165        let (request_send, request_receive) = mesh::channel();
166        let (response_send, response_receive) = mesh::channel();
167        Self {
168            relay_half: RelayChannelHalf {
169                request_receive,
170                response_send,
171            },
172            server_half: ServerChannelHalf {
173                request_send,
174                response_receive,
175            },
176        }
177    }
178}
179
/// Server half of the channel that carries [`ModifyRelayRequest`]s to the relay and its responses
/// back.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
/// Relay half of the channel that carries [`ModifyRelayRequest`]s.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
/// Both halves of the [`ModifyRelayRequest`] channel.
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyConnectionResponse>;
/// Server half of the channel that forwards hvsocket connect requests and receives results.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay half of the hvsocket connect request channel.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Both halves of the hvsocket connect request channel.
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
186
/// A request from the server to the relay to modify connection state.
///
/// The `version` and `use_interrupt_page` fields can only be present if this request was sent for
/// an InitiateContact message from the guest.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// Protocol version carried by the originating connection request, if any.
    pub version: Option<u32>,
    /// Update to the monitor page GPAs.
    pub monitor_page: Update<MonitorPageGpas>,
    /// `Some(true)` if an interrupt page was set, `Some(false)` if it was reset, `None` if
    /// unchanged.
    pub use_interrupt_page: Option<bool>,
}
197
198impl From<ModifyConnectionRequest> for ModifyRelayRequest {
199    fn from(value: ModifyConnectionRequest) -> Self {
200        Self {
201            version: value.version,
202            monitor_page: value.monitor_page,
203            use_interrupt_page: match value.interrupt_page {
204                Update::Unchanged => None,
205                Update::Reset => Some(false),
206                Update::Set(_) => Some(true),
207            },
208        }
209    }
210}
211
/// Control requests sent from [`VmbusServer`] to the server task.
#[derive(Debug)]
enum VmbusRequest {
    /// Reset the vmbus channel state (`VmbusServer::reset`).
    Reset(Rpc<(), ()>),
    /// Answer a deferred inspect request.
    Inspect(inspect::Deferred),
    /// Produce the server's saved state (`VmbusServer::save`).
    Save(Rpc<(), SavedState>),
    /// Restore from a saved state (`VmbusServer::restore`).
    Restore(Rpc<Box<SavedState>, Result<(), RestoreError>>),
    /// Start the control plane (`VmbusServer::start`).
    Start,
    /// Stop the control plane (`VmbusServer::stop`).
    Stop(Rpc<(), ()>),
}
221
/// Everything the server needs to offer a channel.
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    /// Offer parameters passed to the channel server.
    pub params: OfferParamsInternal,
    /// Sender the server task uses to send channel requests to the device.
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receiver for requests the device sends to the server.
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
228
/// Requests received by the server task on its offer channel.
#[expect(clippy::large_enum_variant)]
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offer a new channel.
    Offer(FailableRpc<OfferInfo, ()>),
    /// Force a reset of the server.
    /// NOTE(review): handled outside the visible code — confirm exact semantics.
    ForceReset(Rpc<(), ()>),
}
235
236impl Inspect for VmbusServer {
237    fn inspect(&self, req: inspect::Request<'_>) {
238        self.task_send.send(VmbusRequest::Inspect(req.defer()));
239    }
240}
241
/// Guest-to-host event port for a channel. Signaling it delivers the wrapped [`Interrupt`].
struct ChannelEvent(Interrupt);
243
244impl EventPort for ChannelEvent {
245    fn handle_event(&self, _flag: u16) {
246        self.0.deliver();
247    }
248
249    fn os_event(&self) -> Option<&Event> {
250        self.0.event()
251    }
252}
253
/// Saved state of the vmbus server, produced by `VmbusServer::save` and consumed by
/// `VmbusServer::restore`.
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    // Saved state of the channel server.
    #[mesh(1)]
    server: channels::SavedState,
    // Indicates whether the lost synic bug is fixed. It defaults to false. During restore, if
    // this field is not true, unstick_channels() is called to mitigate the issue.
    #[mesh(2)]
    lost_synic_bug_fixed: bool,
}
265
/// Standard message connection ID, used by a VTL0 server whose control plane is not redirected.
const MESSAGE_CONNECTION_ID: u32 = 1;
/// Connection ID of the multiclient message port, which a VTL0 server without redirection
/// registers. It is also the base value for child message connection IDs
/// (`VmbusServer::get_child_message_connection_id`).
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
268
269impl<'a, T: Spawn> VmbusServerBuilder<'a, T> {
270    /// Creates a new builder for `VmbusServer` with the default options.
271    pub fn new(spawner: &'a T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
272        Self {
273            spawner,
274            synic,
275            gm,
276            private_gm: None,
277            vtl: Vtl::Vtl0,
278            hvsock_notify: None,
279            server_relay: None,
280            saved_state_notify: None,
281            external_server: None,
282            external_requests: None,
283            use_message_redirect: false,
284            channel_id_offset: 0,
285            max_version: None,
286            delay_max_version: false,
287            enable_mnf: false,
288            force_confidential_external_memory: false,
289            send_messages_while_stopped: false,
290        }
291    }
292
293    /// Sets a separate guest memory instance to use for channels that are confidential (non-relay
294    /// channels in Underhill on a hardware isolated VM). This is not relevant for a non-Underhill
295    /// VmBus server.
296    pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
297        self.private_gm = private_gm;
298        self
299    }
300
301    /// Sets the VTL that this instance will serve.
302    pub fn vtl(mut self, vtl: Vtl) -> Self {
303        self.vtl = vtl;
304        self
305    }
306
307    /// Sets a send/receive pair used to handle hvsocket requests.
308    pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
309        self.hvsock_notify = hvsock_notify;
310        self
311    }
312
313    /// Sets a send channel used to enlighten ProxyIntegration about saved channels.
314    pub fn saved_state_notify(
315        mut self,
316        saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
317    ) -> Self {
318        self.saved_state_notify = saved_state_notify;
319        self
320    }
321
322    /// Sets a send/receive pair that will be notified of server requests. This is used by the
323    /// Underhill relay.
324    pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
325        self.server_relay = server_relay;
326        self
327    }
328
329    /// Sets a receiver that receives requests from another server.
330    pub fn external_requests(
331        mut self,
332        external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
333    ) -> Self {
334        self.external_requests = external_requests;
335        self
336    }
337
338    /// Sets a sender used to forward unhandled connect requests (which used a different VTL)
339    /// to another server.
340    pub fn external_server(
341        mut self,
342        external_server: Option<mesh::Sender<InitiateContactRequest>>,
343    ) -> Self {
344        self.external_server = external_server;
345        self
346    }
347
348    /// Sets a value which indicates whether the vmbus control plane is redirected to Underhill.
349    pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
350        self.use_message_redirect = use_message_redirect;
351        self
352    }
353
354    /// Tells the server to use an offset when generating channel IDs to void collisions with
355    /// another vmbus server.
356    ///
357    /// N.B. This should only be used by the Underhill vmbus server.
358    pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
359        self.channel_id_offset = if enable { 1024 } else { 0 };
360        self
361    }
362
363    /// Tells the server to limit the protocol version offered to the guest.
364    ///
365    /// N.B. This is used for testing older protocols without requiring a specific guest OS.
366    pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
367        self.max_version = max_version;
368        self
369    }
370
371    /// Delay limiting the maximum version until after the first `Unload` message.
372    ///
373    /// N.B. This is used to enable the use of versions older than `Version::Win10` with Uefi boot,
374    ///      since that's the oldest version the Uefi client supports.
375    pub fn delay_max_version(mut self, delay: bool) -> Self {
376        self.delay_max_version = delay;
377        self
378    }
379
380    /// Enable MNF support in the server.
381    ///
382    /// N.B. Enabling this has no effect if the synic does not support mapping monitor pages.
383    pub fn enable_mnf(mut self, enable: bool) -> Self {
384        self.enable_mnf = enable;
385        self
386    }
387
388    /// Force all non-relay channels to use encrypted external memory. Used for testing purposes
389    /// only.
390    pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
391        self.force_confidential_external_memory = force;
392        self
393    }
394
395    /// Send messages to the partition even while stopped, which can cause
396    /// corrupted synic states across VM reset.
397    ///
398    /// This option is used to prevent messages from getting into the queue, for
399    /// saved state compatibility with release/2411. It can be removed once that
400    /// release is no longer supported.
401    pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
402        self.send_messages_while_stopped = send;
403        self
404    }
405
406    /// Creates a new instance of the server.
407    ///
408    /// When the object is dropped, all channels will be closed and revoked
409    /// automatically.
410    pub fn build(self) -> anyhow::Result<VmbusServer> {
411        #[expect(clippy::disallowed_methods)] // TODO
412        let (message_send, message_recv) = mpsc::channel(64);
413        let message_sender = Arc::new(MessageSender {
414            send: message_send.clone(),
415            multiclient: self.use_message_redirect,
416        });
417
418        let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
419            (REDIRECT_VTL, REDIRECT_SINT)
420        } else {
421            (self.vtl, SINT)
422        };
423
424        // If this server is not for VTL2, use a server-specific connection ID rather than the
425        // standard one.
426        let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
427            MESSAGE_CONNECTION_ID
428        } else {
429            // TODO: This ID should be using the correct target VP, but that is not known until
430            //       InitiateContact.
431            VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
432        };
433
434        let _message_port = self
435            .synic
436            .add_message_port(connection_id, redirect_vtl, message_sender)
437            .context("failed to create vmbus synic ports")?;
438
439        // If this server is for VTL0, it is also responsible for the multiclient message port.
440        // N.B. If control plane redirection is enabled, the redirected message port is used for
441        //      multiclient and no separate multiclient port is created.
442        let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
443            let multiclient_message_sender = Arc::new(MessageSender {
444                send: message_send,
445                multiclient: true,
446            });
447
448            Some(
449                self.synic
450                    .add_message_port(
451                        MULTICLIENT_MESSAGE_CONNECTION_ID,
452                        self.vtl,
453                        multiclient_message_sender,
454                    )
455                    .context("failed to create vmbus synic ports")?,
456            )
457        } else {
458            None
459        };
460
461        let (offer_send, offer_recv) = mesh::mpsc_channel();
462        let control = Arc::new(VmbusServerControl {
463            mem: self.gm.clone(),
464            private_mem: self.private_gm.clone(),
465            send: offer_send,
466            use_event: self.synic.prefer_os_events(),
467            force_confidential_external_memory: self.force_confidential_external_memory,
468        });
469
470        let mut server = channels::Server::new(self.vtl, connection_id, self.channel_id_offset);
471
472        // If requested, limit the maximum protocol version and feature flags.
473        if let Some(version) = self.max_version {
474            server.set_compatibility_version(version, self.delay_max_version);
475        }
476        let (relay_request_send, relay_response_recv) =
477            if let Some(server_relay) = self.server_relay {
478                let r = server_relay.response_receive.boxed().fuse();
479                (server_relay.request_send, r)
480            } else {
481                let (req_send, req_recv) = mesh::channel();
482                let resp_recv = req_recv
483                    .map(|_| {
484                        ModifyConnectionResponse::Supported(
485                            protocol::ConnectionState::SUCCESSFUL,
486                            protocol::FeatureFlags::from_bits(u32::MAX),
487                        )
488                    })
489                    .boxed()
490                    .fuse();
491                (req_send, resp_recv)
492            };
493
494        // If no hvsock notifier was specified, use a default one that always sends an error response.
495        let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
496            let r = hvsock_notify.response_receive.boxed().fuse();
497            (hvsock_notify.request_send, r)
498        } else {
499            let (req_send, req_recv) = mesh::channel();
500            let resp_recv = req_recv
501                .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
502                .boxed()
503                .fuse();
504            (req_send, resp_recv)
505        };
506
507        let inner = ServerTaskInner {
508            running: false,
509            send_messages_while_stopped: self.send_messages_while_stopped,
510            gm: self.gm,
511            private_gm: self.private_gm,
512            vtl: self.vtl,
513            redirect_vtl,
514            redirect_sint,
515            message_port: self
516                .synic
517                .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
518            synic: self.synic,
519            hvsock_requests: 0,
520            hvsock_send,
521            saved_state_notify: self.saved_state_notify,
522            channels: HashMap::new(),
523            channel_responses: FuturesUnordered::new(),
524            relay_send: relay_request_send,
525            external_server_send: self.external_server,
526            channel_bitmap: None,
527            shared_event_port: None,
528            reset_done: Vec::new(),
529            enable_mnf: self.enable_mnf,
530        };
531
532        let (task_send, task_recv) = mesh::channel();
533        let mut server_task = ServerTask {
534            server,
535            task_recv,
536            offer_recv,
537            message_recv,
538            server_request_recv: SelectAll::new(),
539            inner,
540            external_requests: self.external_requests,
541            next_seq: 0,
542            unstick_on_start: false,
543        };
544
545        let task = self.spawner.spawn("vmbus server", async move {
546            server_task.run(relay_response_recv, hvsock_recv).await;
547            server_task
548        });
549
550        Ok(VmbusServer {
551            task_send,
552            control,
553            _message_port,
554            _multiclient_message_port,
555            task,
556        })
557    }
558}
559
560impl VmbusServer {
561    /// Creates a new builder for `VmbusServer` with the default options.
562    pub fn builder<T: Spawn>(
563        spawner: &T,
564        synic: Arc<dyn SynicPortAccess>,
565        gm: GuestMemory,
566    ) -> VmbusServerBuilder<'_, T> {
567        VmbusServerBuilder::new(spawner, synic, gm)
568    }
569
570    pub async fn save(&self) -> SavedState {
571        self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
572    }
573
574    pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
575        self.task_send
576            .call(VmbusRequest::Restore, Box::new(state))
577            .await
578            .unwrap()
579    }
580
581    /// Stop the control plane.
582    pub async fn stop(&self) {
583        self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
584    }
585
586    /// Starts the control plane.
587    pub fn start(&self) {
588        self.task_send.send(VmbusRequest::Start);
589    }
590
591    /// Resets the vmbus channel state.
592    pub async fn reset(&self) {
593        tracing::debug!("resetting channel state");
594        self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
595    }
596
597    /// Tears down the vmbus control plane.
598    pub async fn shutdown(self) {
599        drop(self.task_send);
600        let _ = self.task.await;
601    }
602
603    /// Returns an object that can be used to offer channels.
604    pub fn control(&self) -> Arc<VmbusServerControl> {
605        self.control.clone()
606    }
607
608    /// Returns the message connection ID to use for a communication from the guest for servers
609    /// that use a non-standard SINT or VTL.
610    fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
611        MULTICLIENT_MESSAGE_CONNECTION_ID
612            | (vtl as u32) << 22
613            | vp_index << 8
614            | (sint_index as u32) << 4
615    }
616
617    fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
618        EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
619    }
620}
621
/// Per-channel information used when restoring a channel.
/// NOTE(review): constructed and consumed outside the visible code — confirm details.
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    /// Open data, presumably present only if the channel was open.
    open_data: Option<OpenData>,
    /// GPADLs as `(id, u16, page list)` tuples.
    /// NOTE(review): the `u16` is presumably the range count — confirm.
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    /// Interrupt for the channel, if any.
    interrupt: Option<Interrupt>,
}
628
/// A synic message delivered to the server task over its message channel.
#[derive(Default)]
pub struct SynicMessage {
    /// Raw message bytes.
    data: Vec<u8>,
    /// Whether the message arrived through a multiclient port (presumably copied from
    /// `MessageSender::multiclient` — confirm).
    multiclient: bool,
    /// Whether the sender is trusted.
    /// NOTE(review): set outside the visible code — confirm meaning.
    trusted: bool,
}
635
/// State owned by the background vmbus server task.
struct ServerTask {
    /// The channel server state machine.
    server: channels::Server,
    /// Control requests from [`VmbusServer`].
    task_recv: mesh::Receiver<VmbusRequest>,
    /// Channel offer requests from [`VmbusServerControl`].
    offer_recv: mesh::Receiver<OfferRequest>,
    /// Synic messages received from the guest.
    message_recv: mpsc::Receiver<SynicMessage>,
    /// Requests from devices, tagged with the offer ID they belong to.
    server_request_recv: SelectAll<TaggedStream<OfferId, mesh::Receiver<ChannelServerRequest>>>,
    /// State accessed through the channel server's notifier.
    inner: ServerTaskInner,
    /// Connect requests forwarded from another server, if configured.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Next value for [`Channel::seq`].
    next_seq: u64,
    /// Whether channels should be unstuck when the server starts.
    unstick_on_start: bool,
}
648
/// Server task state passed to the channel server as its notifier.
struct ServerTaskInner {
    /// Whether the control plane is running (initially false).
    running: bool,
    /// Whether messages are sent to the partition even while stopped.
    send_messages_while_stopped: bool,
    /// Guest memory.
    gm: GuestMemory,
    /// Optional separate guest memory for confidential channels.
    private_gm: Option<GuestMemory>,
    /// Synic used to create ports.
    synic: Arc<dyn SynicPortAccess>,
    /// VTL served by this server.
    vtl: Vtl,
    /// VTL that server messages target (`REDIRECT_VTL` when redirected, otherwise `vtl`).
    redirect_vtl: Vtl,
    /// SINT that server messages target (`REDIRECT_SINT` when redirected, otherwise `SINT`).
    redirect_sint: u8,
    /// Guest message port created for `redirect_vtl`, VP 0 and `redirect_sint`.
    message_port: Box<dyn GuestMessagePort>,
    /// Number of hvsocket requests in flight (presumably bounded by
    /// `MAX_CONCURRENT_HVSOCK_REQUESTS` — confirm).
    hvsock_requests: usize,
    /// Sends hvsocket connect requests to the relay or to the default rejecting handler.
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    /// Optional sender used to update the proxy's saved state cache.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Offered channels, keyed by offer ID.
    channels: HashMap<OfferId, Channel>,
    /// Pending device responses, tagged with the offer ID and the channel's `seq` so that
    /// responses for revoked channels can be detected.
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    /// Forwards unhandled connect requests to another server, if configured.
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    /// Sends connection modification requests to the relay or to the default handler.
    relay_send: mesh::Sender<ModifyRelayRequest>,
    /// Channel bitmap, if one is in use (initially `None`).
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    /// Shared event port registration, if any (initially `None`).
    shared_event_port: Option<Box<dyn Send>>,
    /// Reset requests awaiting completion.
    /// NOTE(review): completed outside the visible code — confirm.
    reset_done: Vec<Rpc<(), ()>>,
    /// Whether MnF support is enabled.
    enable_mnf: bool,
}
673
/// A device's response to a channel request, dispatched by `ServerTask::handle_response`.
#[derive(Debug)]
enum ChannelResponse {
    /// Result of an open request; `None` means the open failed.
    Open(Option<OpenResult>),
    /// The device finished closing the channel.
    Close,
    /// Result of a GPADL creation (`true` on success).
    Gpadl(GpadlId, bool),
    /// The device finished tearing down the GPADL.
    TeardownGpadl(GpadlId),
    /// Status of a modify-channel request.
    Modify(i32),
}
682
/// Server-side bookkeeping for an offered channel.
struct Channel {
    /// The offer key, used for logging.
    key: OfferKey,
    /// Sends requests to the device.
    send: mesh::Sender<ChannelRequest>,
    /// Sequence number assigned at offer time; responses carrying a different value are stale.
    seq: u64,
    /// Current lifecycle state.
    state: ChannelState,
    /// GPADLs registered for this channel.
    gpadls: Arc<GpadlMap>,
    /// Offer flags.
    flags: protocol::OfferFlags,
    // A channel can be reserved no matter what state it is in. This allows the message port for a
    // reserved channel to remain available even if the channel is closed, so the guest can read the
    // close reserved channel response. The reserved state is cleared when the channel is revoked,
    // reopened, or the guest sends an unload message.
    reserved_state: ReservedState,
}
696
/// Messaging state kept for a reserved channel.
struct ReservedState {
    /// Message port for the reserved channel, if one exists.
    message_port: Option<Box<dyn GuestMessagePort>>,
    /// Target VP and SINT for the reserved channel's messages.
    target: ConnectionTarget,
}
701
/// Lifecycle state of an offered channel.
enum ChannelState {
    /// Not open. This is the state of a newly offered channel.
    Closed,
    /// An open is in progress (presumably waiting for the device's open response — confirm).
    Opening {
        open_params: OpenParams,
        guest_event_port: Box<dyn GuestEventPort>,
        host_to_guest_interrupt: Interrupt,
    },
    /// Open, with the ports and interrupts used for signaling.
    Open {
        open_params: OpenParams,
        _event_port: Box<dyn Send>,
        guest_event_port: Box<dyn GuestEventPort>,
        host_to_guest_interrupt: Interrupt,
        guest_to_host_event: Arc<ChannelEvent>,
    },
    /// Waiting for the device's close response; `ServerTask::handle_close` then moves the
    /// channel to `Closed`.
    Closing,
    /// Completing an open failed and the device was asked to close the channel. When the device
    /// responds, `ServerTask::handle_close` reports the failed open to the guest.
    FailedOpen,
}
719
720impl ServerTask {
721    fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
722        let key = info.params.key();
723        let flags = info.params.flags;
724
725        if self.inner.enable_mnf && self.inner.synic.monitor_support().is_some() {
726            // If this server is handling MnF, ignore any relayed monitor IDs but still enable MnF
727            // for those channels.
728            // N.B. Since this can only happen in OpenHCL, which emulates MnF, the latency is
729            //      ignored.
730            if info.params.use_mnf.is_relayed() {
731                info.params.use_mnf = MnfUsage::Enabled {
732                    latency: Duration::ZERO,
733                }
734            }
735        } else if info.params.use_mnf.is_enabled() {
736            // If the server is not handling MnF, disable it for the channel. This does not affect
737            // channels with a relayed monitor ID.
738            info.params.use_mnf = MnfUsage::Disabled;
739        }
740
741        let offer_id = self
742            .server
743            .with_notifier(&mut self.inner)
744            .offer_channel(info.params)
745            .context("channel offer failed")?;
746
747        tracing::debug!(?offer_id, %key, "offered channel");
748
749        let id = self.next_seq;
750        self.next_seq += 1;
751        self.inner.channels.insert(
752            offer_id,
753            Channel {
754                key,
755                send: info.request_send,
756                state: ChannelState::Closed,
757                gpadls: GpadlMap::new(),
758                seq: id,
759                flags,
760                reserved_state: ReservedState {
761                    message_port: None,
762                    target: ConnectionTarget { vp: 0, sint: 0 },
763                },
764            },
765        );
766
767        self.server_request_recv
768            .push(TaggedStream::new(offer_id, info.server_request_recv));
769
770        Ok(())
771    }
772
773    fn handle_revoke(&mut self, offer_id: OfferId) {
774        // The channel may or may not exist in the map depending on whether it's been explicitly
775        // revoked before being dropped.
776        if self.inner.channels.remove(&offer_id).is_some() {
777            tracing::info!(?offer_id, "revoking channel");
778            self.server
779                .with_notifier(&mut self.inner)
780                .revoke_channel(offer_id);
781        }
782    }
783
784    fn handle_response(
785        &mut self,
786        offer_id: OfferId,
787        seq: u64,
788        response: Result<ChannelResponse, RpcError>,
789    ) {
790        // Validate the sequence to ensure the response is not for a revoked channel.
791        let channel = self
792            .inner
793            .channels
794            .get(&offer_id)
795            .filter(|channel| channel.seq == seq);
796
797        if let Some(channel) = channel {
798            match response {
799                Ok(response) => match response {
800                    ChannelResponse::Open(result) => self.handle_open(offer_id, result),
801                    ChannelResponse::Close => self.handle_close(offer_id),
802                    ChannelResponse::Gpadl(gpadl_id, ok) => {
803                        self.handle_gpadl_create(offer_id, gpadl_id, ok)
804                    }
805                    ChannelResponse::TeardownGpadl(gpadl_id) => {
806                        self.handle_gpadl_teardown(offer_id, gpadl_id)
807                    }
808                    ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
809                },
810                Err(err) => {
811                    tracing::error!(
812                        key = %channel.key,
813                        error = &err as &dyn std::error::Error,
814                        "channel response failure, channel is in inconsistent state until revoked"
815                    );
816                }
817            }
818        } else {
819            tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
820        }
821    }
822
823    fn handle_open(&mut self, offer_id: OfferId, result: Option<OpenResult>) {
824        let status = if result.is_some() {
825            0
826        } else {
827            protocol::STATUS_UNSUCCESSFUL
828        };
829        if let Err(err) = self.inner.complete_open(offer_id, result) {
830            tracelimit::error_ratelimited!(
831                error = err.as_ref() as &dyn std::error::Error,
832                "failed to complete open"
833            );
834            // If complete_open failed, the channel is now in FailedOpen state and the device needs
835            // to notified to close it. Calling open_complete is postponed until the device responds
836            // to the close request.
837            self.inner.notify(offer_id, channels::Action::Close);
838        } else {
839            self.server
840                .with_notifier(&mut self.inner)
841                .open_complete(offer_id, status);
842        }
843    }
844
845    fn handle_close(&mut self, offer_id: OfferId) {
846        let channel = self
847            .inner
848            .channels
849            .get_mut(&offer_id)
850            .expect("channel still exists");
851
852        match &mut channel.state {
853            ChannelState::Closing => {
854                channel.state = ChannelState::Closed;
855                self.server
856                    .with_notifier(&mut self.inner)
857                    .close_complete(offer_id);
858            }
859            ChannelState::FailedOpen => {
860                // Now that the device has processed the close request after open failed, we can
861                // finish handling the failed open and send an open result to the guest.
862                channel.state = ChannelState::Closed;
863                self.server
864                    .with_notifier(&mut self.inner)
865                    .open_complete(offer_id, protocol::STATUS_UNSUCCESSFUL);
866            }
867            _ => {
868                tracing::error!(?offer_id, "invalid close channel response");
869            }
870        };
871    }
872
873    fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
874        let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
875        self.server
876            .with_notifier(&mut self.inner)
877            .gpadl_create_complete(offer_id, gpadl_id, status);
878    }
879
880    fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
881        self.server
882            .with_notifier(&mut self.inner)
883            .gpadl_teardown_complete(offer_id, gpadl_id);
884    }
885
886    fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
887        self.server
888            .with_notifier(&mut self.inner)
889            .modify_channel_complete(offer_id, status);
890    }
891
892    fn handle_restore_channel(
893        &mut self,
894        offer_id: OfferId,
895        open: Option<OpenResult>,
896    ) -> anyhow::Result<RestoreResult> {
897        let gpadls = self.server.channel_gpadls(offer_id);
898
899        // If the channel is opened, handle that before calling into channels so that failure can
900        // be handled before the channel is marked restored.
901        let open_request = open
902            .map(|result| -> anyhow::Result<_> {
903                let params = self.server.get_restore_open_params(offer_id)?;
904                let (_, interrupt) = self.inner.open_channel(offer_id, &params)?;
905                let channel = self.inner.complete_open(offer_id, Some(result))?;
906                Ok(OpenRequest::new(
907                    params.open_data,
908                    interrupt,
909                    self.server
910                        .get_version()
911                        .expect("must be connected")
912                        .feature_flags,
913                    channel.flags,
914                ))
915            })
916            .transpose()?;
917
918        self.server
919            .with_notifier(&mut self.inner)
920            .restore_channel(offer_id, open_request.is_some())?;
921
922        let channel = self.inner.channels.get_mut(&offer_id).unwrap();
923        for gpadl in &gpadls {
924            if let Ok(buf) =
925                MultiPagedRangeBuf::new(gpadl.request.count.into(), gpadl.request.buf.clone())
926            {
927                channel.gpadls.add(gpadl.request.id, buf);
928            }
929        }
930
931        let result = RestoreResult {
932            open_request,
933            gpadls,
934        };
935        Ok(result)
936    }
937
    /// Dispatches one control request received from `task_recv`: reset, inspect, save,
    /// restore, stop, or start.
    async fn handle_request(&mut self, request: VmbusRequest) {
        tracing::debug!(?request, "handle_request");
        match request {
            VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
            VmbusRequest::Inspect(deferred) => {
                deferred.respond(|resp| {
                    resp.field("message_port", &self.inner.message_port)
                        .field("running", self.inner.running)
                        .field("hvsock_requests", self.inner.hvsock_requests)
                        // Writable knob. "force" wakes host and guest for every open channel
                        // unconditionally. "true" wakes only rings that need a signal.
                        // "false" does nothing. Reads always report false.
                        .field_mut_with("unstick_channels", |v| {
                            let v: inspect::ValueKind = if let Some(v) = v {
                                if v == "force" {
                                    self.unstick_channels(true);
                                    v.into()
                                } else {
                                    let v =
                                        v.parse().ok().context("expected false, true, or force")?;
                                    if v {
                                        self.unstick_channels(false);
                                    }
                                    v.into()
                                }
                            } else {
                                false.into()
                            };
                            anyhow::Ok(v)
                        })
                        .merge(&self.server.with_notifier(&mut self.inner));
                });
            }
            // State saved by this version always records that the lost synic bug is fixed.
            VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
                server: self.server.save(),
                lost_synic_bug_fixed: true,
            }),
            VmbusRequest::Restore(rpc) => {
                rpc.handle(async |state| {
                    // State saved without the lost synic fix causes channels to be unstuck on
                    // the next Start (see the Start arm below).
                    self.unstick_on_start = !state.lost_synic_bug_fixed;
                    // If a proxy is listening for saved state, hand it a copy first; failing
                    // to do so aborts the restore before the server state is touched.
                    if let Some(sender) = &self.inner.saved_state_notify {
                        tracing::trace!("sending saved state to proxy");
                        if let Err(err) = sender
                            .call_failable(SavedStateRequest::Set, Box::new(state.server.clone()))
                            .await
                        {
                            tracing::error!(
                                err = &err as &dyn std::error::Error,
                                "failed to restore proxy saved state"
                            );
                            return Err(RestoreError::ServerError(err.into()));
                        }
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .restore(state.server)
                })
                .await
            }
            // Stopping only clears the running flag; `run` uses it to stop servicing
            // guest-facing queues.
            VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
                if self.inner.running {
                    self.inner.running = false;
                }
            }),
            // Starting an already-running server is a no-op.
            VmbusRequest::Start => {
                if !self.inner.running {
                    self.inner.running = true;
                    if let Some(sender) = self.inner.saved_state_notify.as_ref() {
                        // Indicate to the proxy that the server is starting and that it should
                        // clear its saved state cache.
                        // NOTE: this panics if the proxy call fails.
                        tracing::trace!("sending clear saved state message to proxy");
                        sender
                            .call(SavedStateRequest::Clear, ())
                            .await
                            .expect("failed to clear proxy saved state");
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .revoke_unclaimed_channels();
                    // One-shot mitigation requested by a Restore of state saved without the
                    // lost synic fix.
                    if self.unstick_on_start {
                        tracing::info!(
                            "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
                        );
                        self.unstick_channels(false);
                        self.unstick_on_start = false;
                    }
                }
            }
        }
    }
1027
1028    fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
1029        let needs_reset = self.inner.reset_done.is_empty();
1030        self.inner.reset_done.push(rpc);
1031        if needs_reset {
1032            self.server.with_notifier(&mut self.inner).reset();
1033        }
1034    }
1035
1036    fn handle_relay_response(&mut self, response: ModifyConnectionResponse) {
1037        self.server
1038            .with_notifier(&mut self.inner)
1039            .complete_modify_connection(response);
1040    }
1041
1042    fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
1043        assert_ne!(self.inner.hvsock_requests, 0);
1044        self.inner.hvsock_requests -= 1;
1045
1046        self.server
1047            .with_notifier(&mut self.inner)
1048            .send_tl_connect_result(result);
1049    }
1050
1051    fn handle_synic_message(&mut self, message: SynicMessage) {
1052        match self
1053            .server
1054            .with_notifier(&mut self.inner)
1055            .handle_synic_message(message)
1056        {
1057            Ok(()) => {}
1058            Err(err) => {
1059                tracing::warn!(
1060                    error = &err as &dyn std::error::Error,
1061                    "synic message error"
1062                );
1063            }
1064        }
1065    }
1066
1067    /// Handles a request forwarded by a different vmbus server. This is used to forward requests
1068    /// for different VTLs to different servers.
1069    ///
1070    /// N.B. This uses the same mechanism as the HCL server relay, so all requests, even the ones
1071    ///      meant for the primary server, are forwarded. In that case the primary server depends
1072    ///      on this server to send back a response so it can continue handling it.
1073    fn handle_external_request(&mut self, request: InitiateContactRequest) {
1074        self.server
1075            .with_notifier(&mut self.inner)
1076            .initiate_contact(request);
1077    }
1078
    /// Main loop of the server task.
    ///
    /// Each iteration builds the set of event sources that may be serviced right now, then
    /// handles whichever becomes ready first.
    ///
    /// - Control requests, offers, server requests and relay responses are always serviced.
    /// - Incoming synic messages, queued outgoing messages, hvsock connect results and
    ///   forwarded external requests are serviced only while the VM is running and no reset
    ///   is in progress.
    /// - Channel responses are accepted while running or while a reset is in progress.
    ///
    /// Returns when receiving from `task_recv` fails, or when every `select!` branch has
    /// completed.
    async fn run(
        &mut self,
        mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyConnectionResponse>
        + Unpin,
        mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
    ) {
        loop {
            // Create an OptionFuture for each event that should only be handled
            // while the VM is running. In other cases, leave the events in
            // their respective queues.
            // An OptionFuture built from `None` reports itself as terminated, so `select!`
            // skips it. The branches below that `unwrap()` therefore only fire when the
            // inner future exists.

            let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
            let mut external_requests = OptionFuture::from(
                running_not_resetting
                    .then(|| {
                        self.external_requests
                            .as_mut()
                            .map(|r| r.select_next_some())
                    })
                    .flatten(),
            );

            // Try to send any pending messages while the VM is running.
            let has_pending_messages = self.server.has_pending_messages();
            let message_port = self.inner.message_port.as_mut();
            let mut flush_pending_messages =
                OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
                    poll_fn(|cx| {
                        self.server.poll_flush_pending_messages(|msg| {
                            message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
                        })
                    })
                    .fuse()
                }));

            // Only handle new incoming messages if there are no outgoing messages pending, and not
            // too many hvsock requests outstanding. This puts a bound on the resources used by the
            // guest.
            let mut message_recv = OptionFuture::from(
                (running_not_resetting
                    && !has_pending_messages
                    && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
                    .then(|| self.message_recv.select_next_some()),
            );

            // Accept channel responses until stopped or when resetting.
            let mut channel_response = OptionFuture::from(
                (self.inner.running || !self.inner.reset_done.is_empty())
                    .then(|| self.inner.channel_responses.select_next_some()),
            );

            // Accept hvsock connect responses while the VM is running.
            let mut hvsock_response =
                OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));

            futures::select! { // merge semantics
                // Control requests; a receive error ends the task.
                r = self.task_recv.recv().fuse() => {
                    if let Ok(request) = r {
                        self.handle_request(request).await;
                    } else {
                        break;
                    }
                }
                r = self.offer_recv.select_next_some() => {
                    match r {
                        OfferRequest::Offer(rpc) => {
                            rpc.handle_failable_sync(|request| { self.handle_offer(request) })
                        },
                        OfferRequest::ForceReset(rpc) => {
                            self.handle_reset(rpc);
                        }
                    }
                }
                // Per-channel requests from devices; `None` is handled as a revoke.
                r = self.server_request_recv.select_next_some() => {
                    match r {
                        (id, Some(request)) => match request {
                            ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
                                self.handle_restore_channel(id, open)
                            }),
                            ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
                                self.handle_revoke(id);
                            })
                        },
                        (id, None) => self.handle_revoke(id),
                    }
                }
                r = channel_response => {
                    let (id, seq, response) = r.unwrap();
                    self.handle_response(id, seq, response);
                }
                r = relay_response_recv.select_next_some() => {
                    self.handle_relay_response(r);
                },
                r = hvsock_response => {
                    self.handle_tl_connect_result(r.unwrap());
                }
                data = message_recv => {
                    let data = data.unwrap();
                    self.handle_synic_message(data);
                }
                r = external_requests => {
                    let r = r.unwrap();
                    self.handle_external_request(r);
                }
                // Progress on flushing queued messages; the next iteration re-evaluates
                // whether anything is still pending.
                _r = flush_pending_messages => {}
                complete => break,
            }
        }
    }
1188
1189    /// Wakes the host and guest for every open channel. If `force`, always
1190    /// wakes both the host and guest. If `!force`, only wake for rings that are
1191    /// in the state where a notification is expected.
1192    fn unstick_channels(&self, force: bool) {
1193        for channel in self.inner.channels.values() {
1194            if let Err(err) = self.unstick_channel(channel, force) {
1195                tracing::warn!(
1196                    channel = %channel.key,
1197                    error = err.as_ref() as &dyn std::error::Error,
1198                    "could not unstick channel"
1199                );
1200            }
1201        }
1202    }
1203
1204    fn unstick_channel(&self, channel: &Channel, force: bool) -> anyhow::Result<()> {
1205        if let ChannelState::Open {
1206            open_params,
1207            host_to_guest_interrupt,
1208            guest_to_host_event,
1209            ..
1210        } = &channel.state
1211        {
1212            if force {
1213                tracing::info!(channel = %channel.key, "waking host and guest");
1214                guest_to_host_event.0.deliver();
1215                host_to_guest_interrupt.deliver();
1216                return Ok(());
1217            }
1218
1219            let gpadl = channel
1220                .gpadls
1221                .clone()
1222                .view()
1223                .map(open_params.open_data.ring_gpadl_id)
1224                .context("couldn't find ring gpadl")?;
1225
1226            let aligned = AlignedGpadlView::new(gpadl)
1227                .ok()
1228                .context("ring not aligned")?;
1229            let (in_gpadl, out_gpadl) = aligned
1230                .split(open_params.open_data.ring_offset)
1231                .ok()
1232                .context("couldn't split ring")?;
1233
1234            if let Err(err) = self.unstick_incoming_ring(
1235                channel,
1236                in_gpadl,
1237                guest_to_host_event,
1238                host_to_guest_interrupt,
1239            ) {
1240                tracing::warn!(
1241                    channel = %channel.key,
1242                    error = err.as_ref() as &dyn std::error::Error,
1243                    "could not unstick incoming ring"
1244                );
1245            }
1246            if let Err(err) = self.unstick_outgoing_ring(
1247                channel,
1248                out_gpadl,
1249                guest_to_host_event,
1250                host_to_guest_interrupt,
1251            ) {
1252                tracing::warn!(
1253                    channel = %channel.key,
1254                    error = err.as_ref() as &dyn std::error::Error,
1255                    "could not unstick outgoing ring"
1256                );
1257            }
1258        }
1259        Ok(())
1260    }
1261
1262    fn unstick_incoming_ring(
1263        &self,
1264        channel: &Channel,
1265        in_gpadl: AlignedGpadlView,
1266        guest_to_host_event: &ChannelEvent,
1267        host_to_guest_interrupt: &Interrupt,
1268    ) -> Result<(), anyhow::Error> {
1269        let incoming_mem = GpadlRingMem::new(in_gpadl, &self.inner.gm)?;
1270        if ring::reader_needs_signal(&incoming_mem) {
1271            tracing::info!(channel = %channel.key, "waking host for incoming ring");
1272            guest_to_host_event.0.deliver();
1273        }
1274        if ring::writer_needs_signal(&incoming_mem) {
1275            tracing::info!(channel = %channel.key, "waking guest for incoming ring");
1276            host_to_guest_interrupt.deliver();
1277        }
1278        Ok(())
1279    }
1280
1281    fn unstick_outgoing_ring(
1282        &self,
1283        channel: &Channel,
1284        out_gpadl: AlignedGpadlView,
1285        guest_to_host_event: &ChannelEvent,
1286        host_to_guest_interrupt: &Interrupt,
1287    ) -> Result<(), anyhow::Error> {
1288        let outgoing_mem = GpadlRingMem::new(out_gpadl, &self.inner.gm)?;
1289        if ring::reader_needs_signal(&outgoing_mem) {
1290            tracing::info!(channel = %channel.key, "waking guest for outgoing ring");
1291            host_to_guest_interrupt.deliver();
1292        }
1293        if ring::writer_needs_signal(&outgoing_mem) {
1294            tracing::info!(channel = %channel.key, "waking host for outgoing ring");
1295            guest_to_host_event.0.deliver();
1296        }
1297        Ok(())
1298    }
1299}
1300
1301impl Notifier for ServerTaskInner {
1302    fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
1303        let channel = self
1304            .channels
1305            .get_mut(&offer_id)
1306            .expect("channel does not exist");
1307
1308        fn handle<I: 'static + Send, R: 'static + Send>(
1309            offer_id: OfferId,
1310            channel: &Channel,
1311            req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
1312            input: I,
1313            f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
1314        ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
1315        {
1316            let recv = channel.send.call(req, input);
1317            let seq = channel.seq;
1318            Box::pin(async move {
1319                let r = recv.await.map(f);
1320                (offer_id, seq, r)
1321            })
1322        }
1323
1324        let response = match action {
1325            channels::Action::Open(open_params, version) => {
1326                let seq = channel.seq;
1327                match self.open_channel(offer_id, &open_params) {
1328                    Ok((channel, interrupt)) => handle(
1329                        offer_id,
1330                        channel,
1331                        ChannelRequest::Open,
1332                        OpenRequest::new(
1333                            open_params.open_data,
1334                            interrupt,
1335                            version.feature_flags,
1336                            channel.flags,
1337                        ),
1338                        ChannelResponse::Open,
1339                    ),
1340                    Err(err) => {
1341                        tracelimit::error_ratelimited!(
1342                            err = err.as_ref() as &dyn std::error::Error,
1343                            ?offer_id,
1344                            "could not open channel",
1345                        );
1346
1347                        // Return an error response to the channels module if the open_channel call
1348                        // failed.
1349                        Box::pin(future::ready((
1350                            offer_id,
1351                            seq,
1352                            Ok(ChannelResponse::Open(None)),
1353                        )))
1354                    }
1355                }
1356            }
1357            channels::Action::Close => {
1358                if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1359                    if let ChannelState::Open { open_params, .. } = channel.state {
1360                        channel_bitmap.unregister_channel(open_params.event_flag);
1361                    }
1362                }
1363
1364                channel.state = ChannelState::Closing;
1365                handle(offer_id, channel, ChannelRequest::Close, (), |()| {
1366                    ChannelResponse::Close
1367                })
1368            }
1369            channels::Action::Gpadl(gpadl_id, count, buf) => {
1370                channel.gpadls.add(
1371                    gpadl_id,
1372                    MultiPagedRangeBuf::new(count.into(), buf.clone()).unwrap(),
1373                );
1374                handle(
1375                    offer_id,
1376                    channel,
1377                    ChannelRequest::Gpadl,
1378                    GpadlRequest {
1379                        id: gpadl_id,
1380                        count,
1381                        buf,
1382                    },
1383                    move |r| ChannelResponse::Gpadl(gpadl_id, r),
1384                )
1385            }
1386            channels::Action::TeardownGpadl {
1387                gpadl_id,
1388                post_restore,
1389            } => {
1390                if !post_restore {
1391                    channel.gpadls.remove(gpadl_id, Box::new(|| ()));
1392                }
1393
1394                handle(
1395                    offer_id,
1396                    channel,
1397                    ChannelRequest::TeardownGpadl,
1398                    gpadl_id,
1399                    move |()| ChannelResponse::TeardownGpadl(gpadl_id),
1400                )
1401            }
1402            channels::Action::Modify { target_vp } => {
1403                if let ChannelState::Open {
1404                    guest_event_port, ..
1405                } = &mut channel.state
1406                {
1407                    if let Err(err) = guest_event_port.set_target_vp(target_vp) {
1408                        tracelimit::error_ratelimited!(
1409                            error = &err as &dyn std::error::Error,
1410                            channel = %channel.key,
1411                            "could not modify channel",
1412                        );
1413                        let seq = channel.seq;
1414                        Box::pin(async move {
1415                            (
1416                                offer_id,
1417                                seq,
1418                                Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
1419                            )
1420                        })
1421                    } else {
1422                        handle(
1423                            offer_id,
1424                            channel,
1425                            ChannelRequest::Modify,
1426                            ModifyRequest::TargetVp { target_vp },
1427                            ChannelResponse::Modify,
1428                        )
1429                    }
1430                } else {
1431                    unreachable!();
1432                }
1433            }
1434        };
1435        self.channel_responses.push(response);
1436    }
1437
1438    fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
1439        self.map_interrupt_page(request.interrupt_page)
1440            .context("Failed to map interrupt page.")?;
1441
1442        self.set_monitor_page(request.monitor_page)
1443            .context("Failed to map monitor page.")?;
1444
1445        if let Some(vp) = request.target_message_vp {
1446            self.message_port.set_target_vp(vp)?;
1447        }
1448
1449        if request.notify_relay {
1450            // If this server is handling MNF, the monitor pages should not be relayed.
1451            // N.B. Since the relay is being asked not to update the monitor pages, rather than
1452            //      reset them, this is only safe because the value of enable_mnf won't change after
1453            //      the server has been created.
1454            if self.enable_mnf {
1455                request.monitor_page = Update::Unchanged;
1456            }
1457
1458            self.relay_send.send(request.into());
1459        }
1460
1461        Ok(())
1462    }
1463
1464    fn forward_unhandled(&mut self, request: InitiateContactRequest) {
1465        if let Some(external_server) = &self.external_server_send {
1466            external_server.send(request);
1467        } else {
1468            tracing::warn!(?request, "nowhere to forward unhandled request")
1469        }
1470    }
1471
1472    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
1473        let channel = self.channels.get(&offer_id).expect("should exist");
1474        let mut resp = req.respond();
1475        if let ChannelState::Open { open_params, .. } = &channel.state {
1476            let mem = if self.private_gm.is_some()
1477                && channel.flags.confidential_ring_buffer()
1478                && version
1479                    .expect("must be connected")
1480                    .feature_flags
1481                    .confidential_channels()
1482            {
1483                self.private_gm.as_ref().unwrap()
1484            } else {
1485                &self.gm
1486            };
1487
1488            inspect_rings(
1489                &mut resp,
1490                mem,
1491                channel.gpadls.clone(),
1492                &open_params.open_data,
1493            );
1494        }
1495    }
1496
1497    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
1498        // If the server is paused, queue all messages, to avoid affecting synic
1499        // state during/after it has been saved or reset.
1500        //
1501        // Note that messages to reserved channels or custom targets will be
1502        // dropped. However, such messages should only be sent in response to
1503        // guest requests, which should not be processed while the server is
1504        // paused.
1505        //
1506        // FUTURE: it would be better to ensure that no messages are generated
1507        // by operations that run while the server is paused. E.g., defer
1508        // sending offer or revoke messages for new or revoked offers. This
1509        // would prevent the queue from growing without bound.
1510        if !self.running && !self.send_messages_while_stopped {
1511            if !matches!(target, MessageTarget::Default) {
1512                tracelimit::error_ratelimited!(?target, "dropping message while paused");
1513            }
1514            return false;
1515        }
1516
1517        let mut port_storage;
1518        let port = match target {
1519            MessageTarget::Default => self.message_port.as_mut(),
1520            MessageTarget::ReservedChannel(offer_id, target) => {
1521                if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
1522                    port.as_mut()
1523                } else {
1524                    // Updating the port failed, so there is no way to send the message.
1525                    return true;
1526                }
1527            }
1528            MessageTarget::Custom(target) => {
1529                port_storage = match self.synic.new_guest_message_port(
1530                    self.redirect_vtl,
1531                    target.vp,
1532                    target.sint,
1533                ) {
1534                    Ok(port) => port,
1535                    Err(err) => {
1536                        tracing::error!(
1537                            ?err,
1538                            ?self.redirect_vtl,
1539                            ?target,
1540                            "could not create message port"
1541                        );
1542
1543                        // There is no way to send the message.
1544                        return true;
1545                    }
1546                };
1547                port_storage.as_mut()
1548            }
1549        };
1550
1551        // If this returns Pending, the channels module will queue the message and the ServerTask
1552        // main loop will try to send it again later.
1553        matches!(
1554            port.poll_post_message(
1555                &mut std::task::Context::from_waker(std::task::Waker::noop()),
1556                VMBUS_MESSAGE_TYPE,
1557                message.data()
1558            ),
1559            Poll::Ready(())
1560        )
1561    }
1562
1563    fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
1564        self.hvsock_requests += 1;
1565        self.hvsock_send.send(*request);
1566    }
1567
1568    fn reset_complete(&mut self) {
1569        if let Some(monitor) = self.synic.monitor_support() {
1570            if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
1571                tracing::warn!(?err, "resetting monitor page failed")
1572            }
1573        }
1574
1575        self.unreserve_channels();
1576        for done in self.reset_done.drain(..) {
1577            done.complete(());
1578        }
1579    }
1580
1581    fn unload_complete(&mut self) {
1582        self.unreserve_channels();
1583    }
1584}
1585
1586impl ServerTaskInner {
1587    fn open_channel(
1588        &mut self,
1589        offer_id: OfferId,
1590        open_params: &OpenParams,
1591    ) -> anyhow::Result<(&mut Channel, Interrupt)> {
1592        let channel = self
1593            .channels
1594            .get_mut(&offer_id)
1595            .expect("channel does not exist");
1596
1597        // For pre-Win8 guests, the host-to-guest event always targets vp 0 and the channel
1598        // bitmap is used instead of the event flag.
1599        let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
1600            (0, 0)
1601        } else {
1602            (open_params.open_data.target_vp, open_params.event_flag)
1603        };
1604        let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
1605            (self.redirect_vtl, self.redirect_sint)
1606        } else {
1607            (self.vtl, SINT)
1608        };
1609
1610        let guest_event_port = self.synic.new_guest_event_port(
1611            VmbusServer::get_child_event_port_id(open_params.channel_id, SINT, self.vtl),
1612            target_vtl,
1613            target_vp,
1614            target_sint,
1615            event_flag,
1616            open_params.monitor_info,
1617        )?;
1618
1619        let interrupt = ChannelBitmap::create_interrupt(
1620            &self.channel_bitmap,
1621            guest_event_port.interrupt(),
1622            open_params.event_flag,
1623        );
1624
1625        // Delete any previously reserved state.
1626        channel.reserved_state.message_port = None;
1627
1628        // If the channel is reserved, create a message port for it.
1629        if let Some(target) = open_params.reserved_target {
1630            channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
1631                self.redirect_vtl,
1632                target.vp,
1633                target.sint,
1634            )?);
1635
1636            channel.reserved_state.target = target;
1637        }
1638
1639        channel.state = ChannelState::Opening {
1640            open_params: *open_params,
1641            guest_event_port,
1642            host_to_guest_interrupt: interrupt.clone(),
1643        };
1644        Ok((channel, interrupt))
1645    }
1646
1647    fn complete_open(
1648        &mut self,
1649        offer_id: OfferId,
1650        result: Option<OpenResult>,
1651    ) -> anyhow::Result<&mut Channel> {
1652        let channel = self
1653            .channels
1654            .get_mut(&offer_id)
1655            .expect("channel does not exist");
1656
1657        channel.state = if let Some(result) = result {
1658            // The channel will be left in the FailedOpen state only if an error occurs in the match
1659            // arm.
1660            match std::mem::replace(&mut channel.state, ChannelState::FailedOpen) {
1661                ChannelState::Opening {
1662                    open_params,
1663                    guest_event_port,
1664                    host_to_guest_interrupt,
1665                } => {
1666                    let guest_to_host_event =
1667                        Arc::new(ChannelEvent(result.guest_to_host_interrupt));
1668                    // Always register with the channel bitmap; if Win7, this may be unnecessary.
1669                    if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1670                        channel_bitmap.register_channel(
1671                            open_params.event_flag,
1672                            guest_to_host_event.0.clone(),
1673                        );
1674                    }
1675                    // Always set up an event port; if V1, this will be unused.
1676                    let event_port = self
1677                        .synic
1678                        .add_event_port(
1679                            open_params.connection_id,
1680                            self.vtl,
1681                            guest_to_host_event.clone(),
1682                            open_params.monitor_info,
1683                        )
1684                        .with_context(|| {
1685                            format!(
1686                                "failed to create event port for VTL {:?}, connection ID {:#x}",
1687                                self.vtl, open_params.connection_id
1688                            )
1689                        })?;
1690
1691                    ChannelState::Open {
1692                        open_params,
1693                        _event_port: event_port,
1694                        guest_event_port,
1695                        host_to_guest_interrupt,
1696                        guest_to_host_event,
1697                    }
1698                }
1699                s => {
1700                    tracing::error!("attempting to complete open of open or closed channel");
1701                    // Restore the original state
1702                    s
1703                }
1704            }
1705        } else {
1706            ChannelState::Closed
1707        };
1708        Ok(channel)
1709    }
1710
1711    /// If the client specified an interrupt page, map it into host memory and
1712    /// set up the shared event port.
1713    fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
1714        let interrupt_page = match interrupt_page {
1715            Update::Unchanged => return Ok(()),
1716            Update::Reset => {
1717                self.channel_bitmap = None;
1718                self.shared_event_port = None;
1719                return Ok(());
1720            }
1721            Update::Set(interrupt_page) => interrupt_page,
1722        };
1723
1724        assert_ne!(interrupt_page, 0);
1725
1726        if interrupt_page % PAGE_SIZE as u64 != 0 {
1727            anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
1728        }
1729
1730        // Use a subrange to access the interrupt page to give GuestMemory's without a full mapping
1731        // a chance to create one.
1732        let interrupt_page = self
1733            .gm
1734            .lockable_subrange(interrupt_page, PAGE_SIZE as u64)?
1735            .lock_gpns(false, &[0])?;
1736
1737        let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
1738        self.channel_bitmap = Some(channel_bitmap.clone());
1739
1740        // Create the shared event port for pre-Win8 guests.
1741        let interrupt = Interrupt::from_fn(move || {
1742            channel_bitmap.handle_shared_interrupt();
1743        });
1744
1745        self.shared_event_port = Some(self.synic.add_event_port(
1746            SHARED_EVENT_CONNECTION_ID,
1747            self.vtl,
1748            Arc::new(ChannelEvent(interrupt)),
1749            None,
1750        )?);
1751
1752        Ok(())
1753    }
1754
1755    fn set_monitor_page(&mut self, monitor_page: Update<MonitorPageGpas>) -> anyhow::Result<()> {
1756        let monitor_page = match monitor_page {
1757            Update::Unchanged => return Ok(()),
1758            Update::Reset => None,
1759            Update::Set(value) => Some(value),
1760        };
1761
1762        // TODO: can this check be moved into channels.rs?
1763        if self.channels.iter().any(|(_, c)| {
1764            matches!(
1765                &c.state,
1766                ChannelState::Open {
1767                    open_params,
1768                    ..
1769                } | ChannelState::Opening {
1770                    open_params,
1771                    ..
1772                } if open_params.monitor_info.is_some()
1773            )
1774        }) {
1775            anyhow::bail!("attempt to change monitor page while open channels using mnf");
1776        }
1777
1778        if self.enable_mnf {
1779            if let Some(monitor) = self.synic.monitor_support() {
1780                if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
1781                    anyhow::bail!(
1782                        "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
1783                    );
1784                }
1785            }
1786        }
1787
1788        Ok(())
1789    }
1790
1791    fn get_reserved_channel_message_port(
1792        &mut self,
1793        offer_id: OfferId,
1794        new_target: ConnectionTarget,
1795    ) -> Option<&mut Box<dyn GuestMessagePort>> {
1796        let channel = self
1797            .channels
1798            .get_mut(&offer_id)
1799            .expect("channel does not exist");
1800
1801        assert!(
1802            channel.reserved_state.message_port.is_some(),
1803            "channel is not reserved"
1804        );
1805
1806        // On close, the guest may have changed the message target it wants to use for the close
1807        // response. If so, update the message port.
1808        if channel.reserved_state.target.sint != new_target.sint {
1809            // Destroy the old port before creating the new one.
1810            channel.reserved_state.message_port = None;
1811            let message_port = self
1812                .synic
1813                .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
1814                .inspect_err(|err| {
1815                    tracing::error!(
1816                        ?err,
1817                        ?self.redirect_vtl,
1818                        ?new_target,
1819                        "could not create reserved channel message port"
1820                    )
1821                })
1822                .ok()?;
1823
1824            channel.reserved_state.message_port = Some(message_port);
1825            channel.reserved_state.target = new_target;
1826        } else if channel.reserved_state.target.vp != new_target.vp {
1827            let message_port = channel.reserved_state.message_port.as_mut().unwrap();
1828
1829            // The vp has changed, but the SINT is the same. Just update the vp. If this fails,
1830            // ignore it and just send to the old vp.
1831            if let Err(err) = message_port.set_target_vp(new_target.vp) {
1832                tracing::error!(
1833                    ?err,
1834                    ?self.redirect_vtl,
1835                    ?new_target,
1836                    "could not update reserved channel message port"
1837                );
1838            }
1839
1840            channel.reserved_state.target = new_target;
1841            return Some(message_port);
1842        }
1843
1844        Some(channel.reserved_state.message_port.as_mut().unwrap())
1845    }
1846
1847    fn unreserve_channels(&mut self) {
1848        // Unreserve all closed channels.
1849        for channel in self.channels.values_mut() {
1850            if let ChannelState::Closed = channel.state {
1851                channel.reserved_state.message_port = None;
1852            }
1853        }
1854    }
1855}
1856
/// Control point for [`VmbusServer`], allowing callers to offer channels.
#[derive(Clone)]
pub struct VmbusServerControl {
    /// Guest memory placed in the `OfferResources` of every successfully offered channel.
    mem: GuestMemory,
    /// Memory added to `OfferResources` only for offers whose flags request a confidential ring
    /// buffer or confidential external memory. May be `None`.
    private_mem: Option<GuestMemory>,
    /// Channel over which `OfferRequest`s (offers, forced resets) are sent to the server.
    send: mesh::Sender<OfferRequest>,
    /// Value reported through `ParentBus::use_event`.
    use_event: bool,
    /// When set, channels offered through `offer` (and therefore `ParentBus::add_child`) have
    /// the confidential-external-memory flag forced on. `offer_core` does not apply this.
    force_confidential_external_memory: bool,
}
1866
1867impl VmbusServerControl {
1868    /// Offers a channel to the vmbus server, where the flags and user_defined data are already set.
1869    /// This is used by the relay to forward the host's parameters.
1870    pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
1871        let flags = offer_info.params.flags;
1872        self.send
1873            .call_failable(OfferRequest::Offer, offer_info)
1874            .await?;
1875        Ok(OfferResources::new(
1876            self.mem.clone(),
1877            if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
1878                self.private_mem.clone()
1879            } else {
1880                None
1881            },
1882        ))
1883    }
1884
1885    /// Force reset all channels and protocol state, without requiring the
1886    /// server to be paused.
1887    pub async fn force_reset(&self) -> anyhow::Result<()> {
1888        self.send
1889            .call(OfferRequest::ForceReset, ())
1890            .await
1891            .context("vmbus server is gone")
1892    }
1893
1894    async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
1895        let mut offer_info = OfferInfo {
1896            params: request.params.into(),
1897            request_send: request.request_send,
1898            server_request_recv: request.server_request_recv,
1899        };
1900
1901        if self.force_confidential_external_memory {
1902            tracing::warn!(
1903                key = %offer_info.params.key(),
1904                "forcing confidential external memory for channel"
1905            );
1906
1907            offer_info
1908                .params
1909                .flags
1910                .set_confidential_external_memory(true);
1911        }
1912
1913        self.offer_core(offer_info).await
1914    }
1915}
1916
1917/// Inspects the specified ring buffer state by directly accessing guest memory.
1918fn inspect_rings(
1919    resp: &mut inspect::Response<'_>,
1920    gm: &GuestMemory,
1921    gpadl_map: Arc<GpadlMap>,
1922    open_data: &OpenData,
1923) -> Option<()> {
1924    let gpadl = gpadl_map
1925        .view()
1926        .map(GpadlId(open_data.ring_gpadl_id.0))
1927        .ok()?;
1928    let aligned = AlignedGpadlView::new(gpadl).ok()?;
1929    let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
1930    if let Ok(incoming_mem) = GpadlRingMem::new(in_gpadl, gm) {
1931        resp.child("incoming_ring", |req| ring::inspect_ring(incoming_mem, req));
1932    }
1933    if let Ok(outgoing_mem) = GpadlRingMem::new(out_gpadl, gm) {
1934        resp.child("outgoing_ring", |req| ring::inspect_ring(outgoing_mem, req));
1935    }
1936    Some(())
1937}
1938
/// A [`MessagePort`] that forwards incoming synic messages into an mpsc channel as
/// `SynicMessage`s.
pub(crate) struct MessageSender {
    /// Channel on which forwarded messages are queued.
    send: mpsc::Sender<SynicMessage>,
    /// Copied into the `multiclient` field of every forwarded message.
    multiclient: bool,
}
1943
1944impl MessageSender {
1945    fn poll_handle_message(
1946        &self,
1947        cx: &mut std::task::Context<'_>,
1948        msg: &[u8],
1949        trusted: bool,
1950    ) -> Poll<Result<(), SendError>> {
1951        let mut send = self.send.clone();
1952        ready!(send.poll_ready(cx))?;
1953        send.start_send(SynicMessage {
1954            data: msg.to_vec(),
1955            multiclient: self.multiclient,
1956            trusted,
1957        })?;
1958
1959        Poll::Ready(Ok(()))
1960    }
1961}
1962
1963impl MessagePort for MessageSender {
1964    fn poll_handle_message(
1965        &self,
1966        cx: &mut std::task::Context<'_>,
1967        msg: &[u8],
1968        trusted: bool,
1969    ) -> Poll<()> {
1970        if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
1971            tracelimit::error_ratelimited!(
1972                error = &err as &dyn std::error::Error,
1973                "failed to send message"
1974            );
1975        }
1976
1977        Poll::Ready(())
1978    }
1979}
1980
1981#[async_trait]
1982impl ParentBus for VmbusServerControl {
1983    async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
1984        self.offer(request).await
1985    }
1986
1987    fn clone_bus(&self) -> Box<dyn ParentBus> {
1988        Box::new(self.clone())
1989    }
1990
1991    fn use_event(&self) -> bool {
1992        self.use_event
1993    }
1994}
1995
1996#[cfg(test)]
1997mod tests {
1998    use super::*;
1999    use inspect::InspectMut;
2000    use mesh::CancelReason;
2001    use pal_async::DefaultDriver;
2002    use pal_async::async_test;
2003    use pal_async::driver::SpawnDriver;
2004    use pal_async::timer::Instant;
2005    use pal_async::timer::PolledTimer;
2006    use parking_lot::Mutex;
2007    use protocol::UserDefinedData;
2008    use std::time::Duration;
2009    use test_with_tracing::test;
2010    use vmbus_channel::bus::OfferParams;
2011    use vmbus_channel::channel::ChannelOpenError;
2012    use vmbus_channel::channel::DeviceResources;
2013    use vmbus_channel::channel::SaveRestoreVmbusDevice;
2014    use vmbus_channel::channel::VmbusDevice;
2015    use vmbus_channel::channel::offer_channel;
2016    use vmbus_core::protocol::ChannelId;
2017    use vmbus_core::protocol::VmbusMessage;
2018    use vmcore::synic::MonitorInfo;
2019    use vmcore::synic::SynicPortAccess;
2020    use zerocopy::FromBytes;
2021    use zerocopy::Immutable;
2022    use zerocopy::IntoBytes;
2023    use zerocopy::KnownLayout;
2024
    /// Mutable state of `MockSynic`, guarded by its mutex.
    struct MockSynicInner {
        /// The most recently registered message port. `MockSynic::send_message_core` delivers
        /// test messages to it.
        message_port: Option<Arc<dyn MessagePort>>,
    }
2028
    /// Minimal `SynicPortAccess` implementation used to drive a `VmbusServer` in tests.
    struct MockSynic {
        /// Port registration state updated by `add_message_port`.
        inner: Mutex<MockSynicInner>,
        /// Cloned into every guest message port it creates; receives the posted payloads.
        message_send: mesh::Sender<Vec<u8>>,
        /// Driver handed to guest message ports for their delay timers.
        spawner: Arc<dyn SpawnDriver>,
    }
2034
2035    impl MockSynic {
2036        fn new(message_send: mesh::Sender<Vec<u8>>, spawner: Arc<dyn SpawnDriver>) -> Self {
2037            Self {
2038                inner: Mutex::new(MockSynicInner { message_port: None }),
2039                message_send,
2040                spawner,
2041            }
2042        }
2043
2044        fn send_message(&self, msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout) {
2045            self.send_message_core(OutgoingMessage::new(&msg), false);
2046        }
2047
2048        fn send_message_trusted(
2049            &self,
2050            msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout,
2051        ) {
2052            self.send_message_core(OutgoingMessage::new(&msg), true);
2053        }
2054
2055        fn send_message_core(&self, msg: OutgoingMessage, trusted: bool) {
2056            assert_eq!(
2057                self.inner
2058                    .lock()
2059                    .message_port
2060                    .as_ref()
2061                    .unwrap()
2062                    .poll_handle_message(
2063                        &mut std::task::Context::from_waker(std::task::Waker::noop()),
2064                        msg.data(),
2065                        trusted,
2066                    ),
2067                Poll::Ready(())
2068            );
2069        }
2070    }
2071
    /// Guest event port that never signals anything.
    #[derive(Debug)]
    struct MockGuestPort {}
2074
2075    impl GuestEventPort for MockGuestPort {
2076        fn interrupt(&self) -> Interrupt {
2077            Interrupt::null()
2078        }
2079
2080        fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
2081            Ok(())
2082        }
2083    }
2084
    /// Guest message port that forwards posted payloads over a mesh channel, sometimes
    /// returning `Poll::Pending` first to simulate a busy guest message queue.
    struct MockGuestMessagePort {
        /// Receives every payload posted through this port.
        send: mesh::Sender<Vec<u8>>,
        /// Driver used to create delay timers.
        spawner: Arc<dyn SpawnDriver>,
        /// In-progress delay (timer and its deadline) carried across polls, if any.
        timer: Option<(PolledTimer, Instant)>,
    }
2090
2091    impl GuestMessagePort for MockGuestMessagePort {
2092        fn poll_post_message(
2093            &mut self,
2094            cx: &mut std::task::Context<'_>,
2095            _typ: u32,
2096            payload: &[u8],
2097        ) -> Poll<()> {
2098            if let Some((timer, deadline)) = self.timer.as_mut() {
2099                ready!(timer.sleep_until(*deadline).poll_unpin(cx));
2100                self.timer = None;
2101            }
2102
2103            // Return pending 25% of the time.
2104            let mut pending_chance = [0; 1];
2105            getrandom::fill(&mut pending_chance).unwrap();
2106            if pending_chance[0] % 4 == 0 {
2107                let mut timer = PolledTimer::new(self.spawner.as_ref());
2108                let deadline = Instant::now() + Duration::from_millis(10);
2109                match timer.sleep_until(deadline).poll_unpin(cx) {
2110                    Poll::Ready(_) => {}
2111                    Poll::Pending => {
2112                        self.timer = Some((timer, deadline));
2113                        return Poll::Pending;
2114                    }
2115                }
2116            }
2117
2118            self.send.send(payload.into());
2119            Poll::Ready(())
2120        }
2121
2122        fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
2123            Ok(())
2124        }
2125    }
2126
2127    impl Inspect for MockGuestMessagePort {
2128        fn inspect(&self, _req: inspect::Request<'_>) {}
2129    }
2130
2131    impl SynicPortAccess for MockSynic {
2132        fn add_message_port(
2133            &self,
2134            connection_id: u32,
2135            _minimum_vtl: Vtl,
2136            port: Arc<dyn MessagePort>,
2137        ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
2138            self.inner.lock().message_port = Some(port);
2139            Ok(Box::new(connection_id))
2140        }
2141
2142        fn add_event_port(
2143            &self,
2144            connection_id: u32,
2145            _minimum_vtl: Vtl,
2146            _port: Arc<dyn EventPort>,
2147            _monitor_info: Option<MonitorInfo>,
2148        ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
2149            Ok(Box::new(connection_id))
2150        }
2151
2152        fn new_guest_message_port(
2153            &self,
2154            _vtl: Vtl,
2155            _vp: u32,
2156            _sint: u8,
2157        ) -> Result<Box<(dyn GuestMessagePort)>, vmcore::synic::HypervisorError> {
2158            Ok(Box::new(MockGuestMessagePort {
2159                send: self.message_send.clone(),
2160                spawner: Arc::clone(&self.spawner),
2161                timer: None,
2162            }))
2163        }
2164
2165        fn new_guest_event_port(
2166            &self,
2167            _port_id: u32,
2168            _vtl: Vtl,
2169            _vp: u32,
2170            _sint: u8,
2171            _flag: u16,
2172            _monitor_info: Option<MonitorInfo>,
2173        ) -> Result<Box<(dyn GuestEventPort)>, vmcore::synic::HypervisorError> {
2174            Ok(Box::new(MockGuestPort {}))
2175        }
2176
2177        fn prefer_os_events(&self) -> bool {
2178            false
2179        }
2180    }
2181
    /// Device side of a channel offered through `TestEnv::offer`.
    struct TestChannel {
        /// Requests sent by the server to the device (GPADL, open, teardown, ...).
        request_recv: mesh::Receiver<ChannelRequest>,
        /// Used to send server requests such as `Restore`.
        server_request_send: mesh::Sender<ChannelServerRequest>,
        /// Held for the lifetime of the channel; not otherwise used in the visible tests.
        _resources: OfferResources,
    }
2187
2188    impl TestChannel {
2189        async fn next_request(&mut self) -> ChannelRequest {
2190            self.request_recv.next().await.unwrap()
2191        }
2192
2193        async fn handle_gpadl(&mut self) {
2194            let ChannelRequest::Gpadl(rpc) = self.next_request().await else {
2195                panic!("Wrong request");
2196            };
2197
2198            rpc.complete(true);
2199        }
2200
2201        async fn handle_open(&mut self, f: fn(&OpenRequest)) {
2202            let ChannelRequest::Open(rpc) = self.next_request().await else {
2203                panic!("Wrong request");
2204            };
2205
2206            f(rpc.input());
2207            rpc.complete(Some(OpenResult {
2208                guest_to_host_interrupt: Interrupt::null(),
2209            }));
2210        }
2211
2212        async fn handle_gpadl_teardown(&mut self) {
2213            let rpc = self.get_gpadl_teardown().await;
2214            rpc.complete(());
2215        }
2216
2217        async fn get_gpadl_teardown(&mut self) -> Rpc<GpadlId, ()> {
2218            let ChannelRequest::TeardownGpadl(rpc) = self.next_request().await else {
2219                panic!("Wrong request");
2220            };
2221
2222            rpc
2223        }
2224
2225        async fn restore(&self) {
2226            self.server_request_send
2227                .call(ChannelServerRequest::Restore, None)
2228                .await
2229                .unwrap()
2230                .unwrap();
2231        }
2232    }
2233
    /// A vmbus server wired to a `MockSynic`, plus the guest-side plumbing used to talk to it.
    struct TestEnv {
        vmbus: VmbusServer,
        /// Used to inject guest-to-host messages.
        synic: Arc<MockSynic>,
        /// Receives the host-to-guest message payloads the server posts.
        message_recv: mesh::Receiver<Vec<u8>>,
        /// Trusted flag set by `initiate_contact`; `gpadl` and `open_channel` send with it.
        trusted: bool,
    }
2240
2241    impl TestEnv {
2242        fn new(spawner: DefaultDriver) -> Self {
2243            let spawner: Arc<dyn SpawnDriver> = Arc::new(spawner);
2244            let (message_send, message_recv) = mesh::channel();
2245            let synic = Arc::new(MockSynic::new(message_send, Arc::clone(&spawner)));
2246            let gm = GuestMemory::empty();
2247            let vmbus = VmbusServerBuilder::new(&spawner, synic.clone(), gm)
2248                .build()
2249                .unwrap();
2250
2251            Self {
2252                vmbus,
2253                synic,
2254                message_recv,
2255                trusted: false,
2256            }
2257        }
2258
2259        async fn offer(&self, id: u32, allow_confidential_external_memory: bool) -> TestChannel {
2260            let guid = Guid {
2261                data1: id,
2262                ..Guid::ZERO
2263            };
2264            let (request_send, request_recv) = mesh::channel();
2265            let (server_request_send, server_request_recv) = mesh::channel();
2266            let offer = OfferInput {
2267                request_send,
2268                server_request_recv,
2269                params: OfferParams {
2270                    interface_name: "test".into(),
2271                    instance_id: guid,
2272                    interface_id: guid,
2273                    mmio_megabytes: 0,
2274                    mmio_megabytes_optional: 0,
2275                    channel_type: vmbus_channel::bus::ChannelType::Device {
2276                        pipe_packets: false,
2277                    },
2278                    subchannel_index: 0,
2279                    mnf_interrupt_latency: None,
2280                    offer_order: None,
2281                    allow_confidential_external_memory,
2282                },
2283            };
2284
2285            let control = self.vmbus.control();
2286            let _resources = control.add_child(offer).await.unwrap();
2287
2288            TestChannel {
2289                request_recv,
2290                server_request_send,
2291                _resources,
2292            }
2293        }
2294
2295        async fn gpadl(&mut self, channel_id: u32, gpadl_id: u32, channel: &mut TestChannel) {
2296            self.synic.send_message_core(
2297                OutgoingMessage::with_data(
2298                    &protocol::GpadlHeader {
2299                        channel_id: ChannelId(channel_id),
2300                        gpadl_id: GpadlId(gpadl_id),
2301                        count: 1,
2302                        len: 16,
2303                    },
2304                    [1u64, 0u64].as_bytes(),
2305                ),
2306                self.trusted,
2307            );
2308
2309            channel.handle_gpadl().await;
2310            self.expect_response(protocol::MessageType::GPADL_CREATED)
2311                .await;
2312        }
2313
2314        async fn open_channel(
2315            &mut self,
2316            channel_id: u32,
2317            ring_gpadl_id: u32,
2318            channel: &mut TestChannel,
2319            f: fn(&OpenRequest),
2320        ) {
2321            self.gpadl(channel_id, ring_gpadl_id, channel).await;
2322            self.synic.send_message_core(
2323                OutgoingMessage::new(&protocol::OpenChannel {
2324                    channel_id: ChannelId(channel_id),
2325                    open_id: 0,
2326                    ring_buffer_gpadl_id: GpadlId(ring_gpadl_id),
2327                    target_vp: 0,
2328                    downstream_ring_buffer_page_offset: 0,
2329                    user_data: UserDefinedData::default(),
2330                }),
2331                self.trusted,
2332            );
2333
2334            channel.handle_open(f).await;
2335            self.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT)
2336                .await;
2337        }
2338
2339        async fn expect_response(&mut self, expected: protocol::MessageType) {
2340            let data = self.message_recv.next().await.unwrap();
2341            let header = protocol::MessageHeader::read_from_prefix(&data).unwrap().0; // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
2342            assert_eq!(expected, header.message_type())
2343        }
2344
2345        async fn get_response<T: VmbusMessage + FromBytes + Immutable + KnownLayout>(
2346            &mut self,
2347        ) -> T {
2348            let data = self.message_recv.next().await.unwrap();
2349            let (header, message) = protocol::MessageHeader::read_from_prefix(&data).unwrap(); // TODO: zerocopy: unwrap (https://github.com/microsoft/openvmm/issues/759)
2350            assert_eq!(T::MESSAGE_TYPE, header.message_type());
2351            T::read_from_prefix(message).unwrap().0 // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
2352        }
2353
2354        fn initiate_contact(
2355            &mut self,
2356            version: protocol::Version,
2357            feature_flags: protocol::FeatureFlags,
2358            trusted: bool,
2359        ) {
2360            self.synic.send_message_core(
2361                OutgoingMessage::new(&protocol::InitiateContact {
2362                    version_requested: version as u32,
2363                    target_message_vp: 0,
2364                    child_to_parent_monitor_page_gpa: 0,
2365                    parent_to_child_monitor_page_gpa: 0,
2366                    interrupt_page_or_target_info: protocol::TargetInfo::new()
2367                        .with_sint(2)
2368                        .with_vtl(0)
2369                        .with_feature_flags(feature_flags.into())
2370                        .into(),
2371                }),
2372                trusted,
2373            );
2374
2375            self.trusted = trusted;
2376        }
2377
2378        async fn connect(
2379            &mut self,
2380            offer_count: u32,
2381            feature_flags: protocol::FeatureFlags,
2382            trusted: bool,
2383        ) {
2384            self.initiate_contact(protocol::Version::Copper, feature_flags, trusted);
2385
2386            self.expect_response(protocol::MessageType::VERSION_RESPONSE)
2387                .await;
2388
2389            self.synic
2390                .send_message_core(OutgoingMessage::new(&protocol::RequestOffers {}), trusted);
2391
2392            for _ in 0..offer_count {
2393                self.expect_response(protocol::MessageType::OFFER_CHANNEL)
2394                    .await;
2395            }
2396
2397            self.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
2398                .await;
2399        }
2400    }
2401
    /// Saves the server while a GPADL teardown is still pending on the device.
    ///
    /// It then finishes and resets the live state, restores the saved state, and checks that
    /// the server repeats the teardown request without panicking.
    #[async_test]
    async fn test_save_restore(spawner: DefaultDriver) {
        // Most save/restore state is tested in mod channels::tests; this test specifically checks
        // that ServerTaskInner correctly handles some aspects of the save/restore.
        //
        // If this test fails, it is more likely to hang than panic.
        let mut env = TestEnv::new(spawner);
        // Offer channel 1, then connect untrusted and expect exactly that one offer.
        let mut channel = env.offer(1, false).await;
        env.vmbus.start();
        env.connect(1, protocol::FeatureFlags::new(), false).await;

        // Create a GPADL for the channel.
        env.gpadl(1, 10, &mut channel).await;

        // Start tearing it down.
        env.synic.send_message(protocol::GpadlTeardown {
            channel_id: ChannelId(1),
            gpadl_id: GpadlId(10),
        });

        // Wait for the teardown request here to make sure the server has processed the teardown
        // message, but do not complete it before saving.
        let rpc = channel.get_gpadl_teardown().await;
        // Snapshot the server while the teardown is outstanding, then let it run again.
        env.vmbus.stop().await;
        let saved_state = env.vmbus.save().await;
        env.vmbus.start();

        // Finish tearing down the gpadl and release the channel so the server can reset.
        rpc.complete(());
        env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
            .await;

        env.synic.send_message(protocol::RelIdReleased {
            channel_id: ChannelId(1),
        });

        env.vmbus.reset().await;
        env.vmbus.stop().await;

        // When restoring with a gpadl in the TearingDown state, the teardown request for the device
        // will be repeated. This must not panic.
        env.vmbus.restore(saved_state).await.unwrap();
        // The device side must also be told to restore before the server resumes.
        channel.restore().await;
        env.vmbus.start();

        // Handle the teardown after restore.
        channel.handle_gpadl_teardown().await;
        env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
            .await;

        env.synic.send_message(protocol::RelIdReleased {
            channel_id: ChannelId(1),
        });
    }
2456
2457    struct TestDeviceState {
2458        id: u32,
2459        started: bool,
2460        resources: Option<DeviceResources>,
2461        open_requests: HashMap<u16, OpenRequest>,
2462        target_vps: HashMap<u16, u32>,
2463    }
2464
2465    impl TestDeviceState {
2466        pub fn id(this: &Arc<Mutex<Self>>) -> u32 {
2467            this.lock().id
2468        }
2469
2470        pub fn started(this: &Arc<Mutex<Self>>) -> bool {
2471            this.lock().started
2472        }
2473        pub fn set_started(this: &Arc<Mutex<Self>>, started: bool) {
2474            this.lock().started = started;
2475        }
2476
2477        pub fn open_request(this: &Arc<Mutex<Self>>, channel_idx: u16) -> Option<OpenRequest> {
2478            this.lock().open_requests.get(&channel_idx).cloned()
2479        }
2480        pub fn set_open_request(
2481            this: &Arc<Mutex<Self>>,
2482            channel_idx: u16,
2483            open_request: OpenRequest,
2484        ) {
2485            assert!(
2486                this.lock()
2487                    .open_requests
2488                    .insert(channel_idx, open_request)
2489                    .is_none()
2490            );
2491        }
2492        pub fn remove_open_request(
2493            this: &Arc<Mutex<Self>>,
2494            channel_idx: u16,
2495        ) -> Option<OpenRequest> {
2496            this.lock().open_requests.remove(&channel_idx)
2497        }
2498
2499        pub fn target_vp(this: &Arc<Mutex<Self>>, channel_idx: u16) -> Option<u32> {
2500            this.lock().target_vps.get(&channel_idx).copied()
2501        }
2502        pub fn set_target_vp(this: &Arc<Mutex<Self>>, channel_idx: u16, target_vp: u32) {
2503            let _ = this.lock().target_vps.insert(channel_idx, target_vp);
2504        }
2505    }
2506
2507    #[derive(InspectMut)]
2508    struct TestDevice {
2509        #[inspect(skip)]
2510        pub state: Arc<Mutex<TestDeviceState>>,
2511    }
2512
2513    impl TestDevice {
2514        pub fn new_and_state(id: u32) -> (Self, Arc<Mutex<TestDeviceState>>) {
2515            let state = TestDeviceState {
2516                id,
2517                resources: None,
2518                open_requests: HashMap::new(),
2519                target_vps: HashMap::new(),
2520                started: false,
2521            };
2522            let state = Arc::new(Mutex::new(state));
2523            let this = Self {
2524                state: state.clone(),
2525            };
2526            (this, state)
2527        }
2528    }
2529
2530    #[async_trait]
2531    impl VmbusDevice for TestDevice {
2532        fn offer(&self) -> OfferParams {
2533            let guid = Guid {
2534                data1: TestDeviceState::id(&self.state),
2535                ..Guid::ZERO
2536            };
2537
2538            OfferParams {
2539                interface_name: "test".into(),
2540                instance_id: guid,
2541                interface_id: guid,
2542                channel_type: vmbus_channel::bus::ChannelType::Device {
2543                    pipe_packets: false,
2544                },
2545                ..Default::default()
2546            }
2547        }
2548
2549        fn max_subchannels(&self) -> u16 {
2550            0
2551        }
2552
2553        fn install(&mut self, resources: DeviceResources) {
2554            self.state.lock().resources = Some(resources);
2555        }
2556
2557        async fn open(
2558            &mut self,
2559            channel_idx: u16,
2560            open_request: &OpenRequest,
2561        ) -> Result<(), ChannelOpenError> {
2562            tracing::info!("OPEN");
2563            TestDeviceState::set_open_request(&self.state, channel_idx, open_request.clone());
2564            Ok(())
2565        }
2566
2567        async fn close(&mut self, channel_idx: u16) {
2568            tracing::info!("CLOSE");
2569            assert!(TestDeviceState::remove_open_request(&self.state, channel_idx).is_some());
2570        }
2571
2572        async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
2573            TestDeviceState::set_target_vp(&self.state, channel_idx, target_vp);
2574        }
2575
2576        fn start(&mut self) {
2577            tracing::info!("START");
2578            TestDeviceState::set_started(&self.state, true);
2579        }
2580
2581        async fn stop(&mut self) {
2582            tracing::info!("STOP");
2583            TestDeviceState::set_started(&self.state, false);
2584        }
2585
2586        fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
2587            None
2588        }
2589    }
2590
2591    #[async_test]
2592    async fn test_stopped_child(spawner: DefaultDriver) {
2593        // This is mostly testing vmbus_channel behavior when a channel is
2594        // stopped but vbmus_server is not and continues to receive
2595        // messages.
2596        let mut env = TestEnv::new(spawner.clone());
2597        let (test_device, test_device_state) = TestDevice::new_and_state(1);
2598        let control = env.vmbus.control();
2599        let channel = offer_channel(&spawner, control.as_ref(), test_device)
2600            .await
2601            .expect("test device failed to offer");
2602
2603        env.vmbus.start();
2604        env.connect(1, protocol::FeatureFlags::new(), false).await;
2605
2606        // Stop the channel.
2607        channel.stop().await;
2608
2609        assert_eq!(TestDeviceState::started(&test_device_state), false);
2610
2611        // GPADL processing is currently allowed while the channel is stopped,
2612        // so this should complete.
2613        env.synic.send_message_core(
2614            OutgoingMessage::with_data(
2615                &protocol::GpadlHeader {
2616                    channel_id: ChannelId(1),
2617                    gpadl_id: GpadlId(1),
2618                    count: 1,
2619                    len: 16,
2620                },
2621                [1u64, 0u64].as_bytes(),
2622            ),
2623            false,
2624        );
2625        env.expect_response(protocol::MessageType::GPADL_CREATED)
2626            .await;
2627
2628        // Open will pend while the channel is stopped.
2629        env.synic.send_message_core(
2630            OutgoingMessage::new(&protocol::OpenChannel {
2631                channel_id: ChannelId(1),
2632                open_id: 0,
2633                ring_buffer_gpadl_id: GpadlId(1),
2634                target_vp: 0,
2635                downstream_ring_buffer_page_offset: 0,
2636                user_data: UserDefinedData::default(),
2637            }),
2638            false,
2639        );
2640        let wait_for_response = mesh::CancelContext::new()
2641            .with_timeout(Duration::from_millis(150))
2642            .until_cancelled(env.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT))
2643            .await;
2644        assert!(matches!(
2645            wait_for_response,
2646            Err(CancelReason::DeadlineExceeded)
2647        ));
2648        assert!(TestDeviceState::open_request(&test_device_state, 0).is_none());
2649
2650        // Restart the channel and confirm that open completes.
2651        channel.start();
2652        env.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT)
2653            .await;
2654        assert!(TestDeviceState::open_request(&test_device_state, 0).is_some());
2655
2656        // Stop the channel and send a modify request.
2657        assert!(TestDeviceState::target_vp(&test_device_state, 0).is_none());
2658        channel.stop().await;
2659        env.synic.send_message_core(
2660            OutgoingMessage::new(&protocol::ModifyChannel {
2661                channel_id: ChannelId(1),
2662                target_vp: 2,
2663            }),
2664            false,
2665        );
2666        let wait_for_response = mesh::CancelContext::new()
2667            .with_timeout(Duration::from_millis(150))
2668            .until_cancelled(env.expect_response(protocol::MessageType::MODIFY_CHANNEL_RESPONSE))
2669            .await;
2670        assert!(matches!(
2671            wait_for_response,
2672            Err(CancelReason::DeadlineExceeded)
2673        ));
2674
2675        // Restart the channel and verify the modify request completes.
2676        channel.start();
2677        env.expect_response(protocol::MessageType::MODIFY_CHANNEL_RESPONSE)
2678            .await;
2679        assert_eq!(
2680            TestDeviceState::target_vp(&test_device_state, 0)
2681                .expect("Modify channel request received"),
2682            2
2683        );
2684
2685        // Stop the channel and send a close request. Close is currently
2686        // allowed through in order to support reset of the vmbus
2687        // server, so try that.
2688        channel.stop().await;
2689        env.vmbus.reset().await;
2690        assert!(TestDeviceState::open_request(&test_device_state, 0).is_none());
2691
2692        env.vmbus.stop().await;
2693    }
2694
2695    #[async_test]
2696    async fn test_confidential_connection(spawner: DefaultDriver) {
2697        let mut env = TestEnv::new(spawner);
2698        // Add regular bus child channels, one of which supports confidential external memory.
2699        let mut channel = env.offer(1, false).await;
2700        let mut channel2 = env.offer(2, true).await;
2701
2702        // Add a channel directly, like the relay would do.
2703        let (request_send, request_recv) = mesh::channel();
2704        let (server_request_send, server_request_recv) = mesh::channel();
2705        let id = Guid {
2706            data1: 3,
2707            ..Guid::ZERO
2708        };
2709        let control = env.vmbus.control();
2710        let relay_resources = control
2711            .offer_core(OfferInfo {
2712                params: OfferParamsInternal {
2713                    interface_name: "test".into(),
2714                    instance_id: id,
2715                    interface_id: id,
2716                    mmio_megabytes: 0,
2717                    mmio_megabytes_optional: 0,
2718                    subchannel_index: 0,
2719                    use_mnf: MnfUsage::Disabled,
2720                    offer_order: None,
2721                    flags: protocol::OfferFlags::new().with_enumerate_device_interface(true),
2722                    ..Default::default()
2723                },
2724                request_send,
2725                server_request_recv,
2726            })
2727            .await
2728            .unwrap();
2729
2730        let mut relay_channel = TestChannel {
2731            request_recv,
2732            server_request_send,
2733            _resources: relay_resources,
2734        };
2735
2736        env.vmbus.start();
2737        env.initiate_contact(
2738            protocol::Version::Copper,
2739            protocol::FeatureFlags::new().with_confidential_channels(true),
2740            true,
2741        );
2742
2743        env.expect_response(protocol::MessageType::VERSION_RESPONSE)
2744            .await;
2745
2746        env.synic.send_message_trusted(protocol::RequestOffers {});
2747
2748        // All offers added with add_child have confidential ring support.
2749        let offer = env.get_response::<protocol::OfferChannel>().await;
2750        assert!(offer.flags.confidential_ring_buffer());
2751        assert!(!offer.flags.confidential_external_memory());
2752        let offer = env.get_response::<protocol::OfferChannel>().await;
2753        assert!(offer.flags.confidential_ring_buffer());
2754        assert!(offer.flags.confidential_external_memory());
2755
2756        // The "relay" channel will not have its flags modified.
2757        let offer = env.get_response::<protocol::OfferChannel>().await;
2758        assert!(!offer.flags.confidential_ring_buffer());
2759        assert!(!offer.flags.confidential_external_memory());
2760
2761        env.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
2762            .await;
2763
2764        // Make sure that the correct confidential flags are set in the open request when opening
2765        // the channels.
2766        env.open_channel(1, 1, &mut channel, |request| {
2767            assert!(request.use_confidential_ring);
2768            assert!(!request.use_confidential_external_memory);
2769        })
2770        .await;
2771
2772        env.open_channel(2, 2, &mut channel2, |request| {
2773            assert!(request.use_confidential_ring);
2774            assert!(request.use_confidential_external_memory);
2775        })
2776        .await;
2777
2778        env.open_channel(3, 3, &mut relay_channel, |request| {
2779            assert!(!request.use_confidential_ring);
2780            assert!(!request.use_confidential_external_memory);
2781        })
2782        .await;
2783    }
2784
2785    #[async_test]
2786    async fn test_confidential_channels_unsupported(spawner: DefaultDriver) {
2787        let mut env = TestEnv::new(spawner);
2788        let mut channel = env.offer(1, false).await;
2789        let mut channel2 = env.offer(2, true).await;
2790
2791        env.vmbus.start();
2792        env.connect(2, protocol::FeatureFlags::new(), true).await;
2793
2794        // Make sure that the correct confidential flags are always false when the client doesn't
2795        // support confidential channels.
2796        env.open_channel(1, 1, &mut channel, |request| {
2797            assert!(!request.use_confidential_ring);
2798            assert!(!request.use_confidential_external_memory);
2799        })
2800        .await;
2801
2802        env.open_channel(2, 2, &mut channel2, |request| {
2803            assert!(!request.use_confidential_ring);
2804            assert!(!request.use_confidential_external_memory);
2805        })
2806        .await;
2807    }
2808
2809    #[async_test]
2810    async fn test_confidential_channels_untrusted(spawner: DefaultDriver) {
2811        let mut env = TestEnv::new(spawner);
2812        let mut channel = env.offer(1, false).await;
2813        let mut channel2 = env.offer(2, true).await;
2814
2815        env.vmbus.start();
2816        // Client claims to support confidential channels, but they can't be used because the
2817        // connection is untrusted.
2818        env.connect(
2819            2,
2820            protocol::FeatureFlags::new().with_confidential_channels(true),
2821            false,
2822        )
2823        .await;
2824
2825        // Make sure that the correct confidential flags are always false when the client doesn't
2826        // support confidential channels.
2827        env.open_channel(1, 1, &mut channel, |request| {
2828            assert!(!request.use_confidential_ring);
2829            assert!(!request.use_confidential_external_memory);
2830        })
2831        .await;
2832
2833        env.open_channel(2, 2, &mut channel2, |request| {
2834            assert!(!request.use_confidential_ring);
2835            assert!(!request.use_confidential_external_memory);
2836        })
2837        .await;
2838    }
2839}