vmbus_server/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8pub mod channels;
9pub mod event;
10pub mod hvsock;
11mod monitor;
12mod proxyintegration;
13
/// The GUID type used for vmbus channel identifiers.
///
/// Re-exported from the `guid` crate so consumers of this crate do not need to depend on it
/// directly.
pub type Guid = guid::Guid;
16
17use anyhow::Context;
18use async_trait::async_trait;
19use channel_bitmap::ChannelBitmap;
20use channels::ConnectionTarget;
21pub use channels::InitiateContactRequest;
22use channels::MessageTarget;
23pub use channels::MnfUsage;
24use channels::ModifyConnectionRequest;
25pub use channels::ModifyConnectionResponse;
26use channels::Notifier;
27use channels::OfferId;
28pub use channels::OfferParamsInternal;
29use channels::OpenParams;
30use channels::RestoreError;
31pub use channels::Update;
32use futures::FutureExt;
33use futures::StreamExt;
34use futures::channel::mpsc;
35use futures::channel::mpsc::SendError;
36use futures::future::OptionFuture;
37use futures::future::poll_fn;
38use futures::stream::SelectAll;
39use guestmem::GuestMemory;
40use hvdef::Vtl;
41use inspect::Inspect;
42use mesh::payload::Protobuf;
43use mesh::rpc::FailableRpc;
44use mesh::rpc::Rpc;
45use mesh::rpc::RpcError;
46use mesh::rpc::RpcSend;
47use pal_async::driver::Driver;
48use pal_async::driver::SpawnDriver;
49use pal_async::task::Task;
50use pal_async::timer::PolledTimer;
51use pal_event::Event;
52#[cfg(windows)]
53pub use proxyintegration::ProxyIntegration;
54#[cfg(windows)]
55pub use proxyintegration::ProxyServerInfo;
56use ring::PAGE_SIZE;
57use std::collections::HashMap;
58use std::future;
59use std::future::Future;
60use std::pin::Pin;
61use std::sync::Arc;
62use std::task::Poll;
63use std::task::ready;
64use std::time::Duration;
65use unicycle::FuturesUnordered;
66use vmbus_channel::bus::ChannelRequest;
67use vmbus_channel::bus::ChannelServerRequest;
68use vmbus_channel::bus::GpadlRequest;
69use vmbus_channel::bus::ModifyRequest;
70use vmbus_channel::bus::OfferInput;
71use vmbus_channel::bus::OfferKey;
72use vmbus_channel::bus::OfferResources;
73use vmbus_channel::bus::OpenData;
74use vmbus_channel::bus::OpenRequest;
75use vmbus_channel::bus::ParentBus;
76use vmbus_channel::bus::RestoreResult;
77use vmbus_channel::gpadl::GpadlMap;
78use vmbus_channel::gpadl_ring::AlignedGpadlView;
79use vmbus_core::HvsockConnectRequest;
80use vmbus_core::HvsockConnectResult;
81use vmbus_core::MaxVersionInfo;
82use vmbus_core::OutgoingMessage;
83use vmbus_core::TaggedStream;
84use vmbus_core::VersionInfo;
85use vmbus_core::protocol;
86pub use vmbus_core::protocol::GpadlId;
87#[cfg(windows)]
88use vmbus_proxy::ProxyHandle;
89use vmbus_ring as ring;
90use vmbus_ring::gparange::MultiPagedRangeBuf;
91use vmcore::interrupt::Interrupt;
92use vmcore::save_restore::SavedStateRoot;
93use vmcore::synic::EventPort;
94use vmcore::synic::GuestEventPort;
95use vmcore::synic::GuestMessagePort;
96use vmcore::synic::MessagePort;
97use vmcore::synic::MonitorPageGpas;
98use vmcore::synic::SynicPortAccess;
99
/// SINT used for vmbus messages when control plane redirection is disabled.
const SINT: u8 = 2;
/// SINT used for vmbus messages when control plane redirection is enabled.
pub const REDIRECT_SINT: u8 = 7;
/// VTL whose synic receives vmbus messages when control plane redirection is enabled.
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
// NOTE(review): not referenced in this part of the file; presumably the connection ID of the
// shared (non-per-channel) event port — confirm.
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
/// Base value combined with the channel ID, SINT and VTL by
/// `VmbusServer::get_child_event_port_id`.
const EVENT_PORT_ID: u32 = 2;
// NOTE(review): not referenced in this part of the file; presumably the synic message type used
// for vmbus messages — confirm.
const VMBUS_MESSAGE_TYPE: u32 = 1;

// NOTE(review): presumably caps the number of outstanding hvsocket connect requests tracked by
// `ServerTaskInner::hvsock_requests` — confirm against the code that uses it.
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
108
/// A vmbus server whose control plane runs on a spawned background task.
///
/// Construct one with [`VmbusServer::builder`].
pub struct VmbusServer {
    /// Sends requests (save, restore, start, stop, reset, inspect) to the server task.
    task_send: mesh::Sender<VmbusRequest>,
    /// Handle used to offer channels; cloned out by [`VmbusServer::control`].
    control: Arc<VmbusServerControl>,
    // Never read; owned here so the synic message port stays alive as long as the server.
    _message_port: Box<dyn Sync + Send>,
    // Never read; the multiclient message port, only created for a VTL0 server without
    // control plane redirection.
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    /// The spawned server task; it yields the `ServerTask` back when it completes.
    task: Task<ServerTask>,
}
116
/// Builder for [`VmbusServer`].
///
/// Each option is set through the method of the same name (or `enable_channel_id_offset` for
/// `channel_id_offset`); see those methods for what the option controls.
pub struct VmbusServerBuilder<T: SpawnDriver> {
    spawner: T,
    synic: Arc<dyn SynicPortAccess>,
    gm: GuestMemory,
    private_gm: Option<GuestMemory>,
    vtl: Vtl,
    hvsock_notify: Option<HvsockServerChannelHalf>,
    server_relay: Option<VmbusServerChannelHalf>,
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    use_message_redirect: bool,
    channel_id_offset: u16,
    max_version: Option<MaxVersionInfo>,
    delay_max_version: bool,
    enable_mnf: bool,
    force_confidential_external_memory: bool,
    send_messages_while_stopped: bool,
    channel_unstick_delay: Option<Duration>,
}
137
#[derive(mesh::MeshPayload)]
/// The request to send to the proxy to set or clear its saved state cache.
pub enum SavedStateRequest {
    /// Sets the proxy's saved state cache to the given channel state.
    Set(FailableRpc<Box<channels::SavedState>, ()>),
    /// Clears the proxy's saved state cache.
    Clear(Rpc<(), ()>),
}
144
/// The server side of the connection between a vmbus server and a relay.
pub struct ServerChannelHalf<Request, Response> {
    /// Sends requests to the relay.
    request_send: mesh::Sender<Request>,
    /// Receives the relay's responses.
    response_receive: mesh::Receiver<Response>,
}
150
/// The relay side of a connection between a vmbus server and a relay.
pub struct RelayChannelHalf<Request, Response> {
    /// Receives requests sent by the server.
    pub request_receive: mesh::Receiver<Request>,
    /// Sends responses back to the server.
    pub response_send: mesh::Sender<Response>,
}
156
/// A connection between a vmbus server and a relay.
///
/// The two halves are connected: requests sent on `server_half` arrive on `relay_half`, and
/// responses sent on `relay_half` arrive on `server_half`.
pub struct RelayChannel<Request, Response> {
    /// The half handed to the relay.
    pub relay_half: RelayChannelHalf<Request, Response>,
    /// The half handed to the vmbus server.
    pub server_half: ServerChannelHalf<Request, Response>,
}
162
163impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
164    /// Creates a new channel between the vmbus server and a relay.
165    pub fn new() -> Self {
166        let (request_send, request_receive) = mesh::channel();
167        let (response_send, response_receive) = mesh::channel();
168        Self {
169            relay_half: RelayChannelHalf {
170                request_receive,
171                response_send,
172            },
173            server_half: ServerChannelHalf {
174                request_send,
175                response_receive,
176            },
177        }
178    }
179}
180
/// Server half of the connection used to forward connection-state changes to a relay.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
/// Relay half of the connection used to receive connection-state changes from a server.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
/// Connection used to forward connection-state changes from a server to a relay.
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyConnectionResponse>;
/// Server half of the connection used to forward hvsocket connect requests.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay half of the connection used to answer hvsocket connect requests.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Connection used to forward hvsocket connect requests and their results.
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
187
/// A request from the server to the relay to modify connection state.
///
/// The `version` and `use_interrupt_page` fields can only be present if this request was sent
/// for an InitiateContact message from the guest.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// The negotiated protocol version, if it changed.
    pub version: Option<u32>,
    /// The requested change to the monitor page GPAs.
    pub monitor_page: Update<MonitorPageGpas>,
    /// `Some(true)` if an interrupt page was set, `Some(false)` if it was reset, and `None` if
    /// it was left unchanged.
    pub use_interrupt_page: Option<bool>,
}
198
199impl From<ModifyConnectionRequest> for ModifyRelayRequest {
200    fn from(value: ModifyConnectionRequest) -> Self {
201        Self {
202            version: value.version,
203            monitor_page: value.monitor_page,
204            use_interrupt_page: match value.interrupt_page {
205                Update::Unchanged => None,
206                Update::Reset => Some(false),
207                Update::Set(_) => Some(true),
208            },
209        }
210    }
211}
212
/// Requests sent from [`VmbusServer`] to its background server task.
#[derive(Debug)]
enum VmbusRequest {
    /// Reset the vmbus channel state; sent by [`VmbusServer::reset`].
    Reset(Rpc<(), ()>),
    /// Answer a deferred inspect request.
    Inspect(inspect::Deferred),
    /// Produce a saved state; sent by [`VmbusServer::save`].
    Save(Rpc<(), SavedState>),
    /// Restore from a saved state; sent by [`VmbusServer::restore`].
    Restore(Rpc<Box<SavedState>, Result<(), RestoreError>>),
    /// Start the control plane; sent by [`VmbusServer::start`] without waiting for a reply.
    Start,
    /// Stop the control plane; sent by [`VmbusServer::stop`].
    Stop(Rpc<(), ()>),
}
222
/// Everything the server task needs in order to offer a channel.
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    /// Offer parameters passed to the channel state machine.
    pub params: OfferParamsInternal,
    /// Interrupt delivered when the guest signals the channel.
    pub event: Interrupt,
    /// Sender used by the server to send requests to the channel's device.
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receiver for requests the channel's device makes of the server.
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
230
/// Requests sent to the server task through [`VmbusServerControl`].
#[expect(clippy::large_enum_variant)]
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offer a new channel, replying with whether the offer succeeded.
    Offer(FailableRpc<OfferInfo, ()>),
    // NOTE(review): handled outside this chunk; presumably forces the server to reset its
    // state — confirm.
    ForceReset(Rpc<(), ()>),
}
237
238impl Inspect for VmbusServer {
239    fn inspect(&self, req: inspect::Request<'_>) {
240        self.task_send.send(VmbusRequest::Inspect(req.defer()));
241    }
242}
243
/// Adapts an [`Interrupt`] to the [`EventPort`] trait: handling an event delivers the interrupt.
struct ChannelEvent(Interrupt);
245
246impl EventPort for ChannelEvent {
247    fn handle_event(&self, _flag: u16) {
248        self.0.deliver();
249    }
250
251    fn os_event(&self) -> Option<&Event> {
252        self.0.event()
253    }
254}
255
/// Saved state of the vmbus server, produced by [`VmbusServer::save`] and consumed by
/// [`VmbusServer::restore`].
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    #[mesh(1)]
    server: channels::SavedState,
    // Indicates if the lost synic bug is fixed or not. By default it's false.
    // During the restore process, we check if the field is not true then
    // unstick_channels() function will be called to mitigate the issue.
    #[mesh(2)]
    lost_synic_bug_fixed: bool,
}
267
/// Standard message connection ID, used by a VTL0 server without control plane redirection.
const MESSAGE_CONNECTION_ID: u32 = 1;
/// Connection ID of the multiclient message port. It is also the base value for the
/// server-specific IDs computed by `VmbusServer::get_child_message_connection_id`.
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
270
271impl<T: SpawnDriver + Clone> VmbusServerBuilder<T> {
272    /// Creates a new builder for `VmbusServer` with the default options.
273    pub fn new(spawner: T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
274        Self {
275            spawner,
276            synic,
277            gm,
278            private_gm: None,
279            vtl: Vtl::Vtl0,
280            hvsock_notify: None,
281            server_relay: None,
282            saved_state_notify: None,
283            external_server: None,
284            external_requests: None,
285            use_message_redirect: false,
286            channel_id_offset: 0,
287            max_version: None,
288            delay_max_version: false,
289            enable_mnf: false,
290            force_confidential_external_memory: false,
291            send_messages_while_stopped: false,
292            channel_unstick_delay: Some(Duration::from_millis(100)),
293        }
294    }
295
296    /// Sets a separate guest memory instance to use for channels that are confidential (non-relay
297    /// channels in Underhill on a hardware isolated VM). This is not relevant for a non-Underhill
298    /// VmBus server.
299    pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
300        self.private_gm = private_gm;
301        self
302    }
303
304    /// Sets the VTL that this instance will serve.
305    pub fn vtl(mut self, vtl: Vtl) -> Self {
306        self.vtl = vtl;
307        self
308    }
309
310    /// Sets a send/receive pair used to handle hvsocket requests.
311    pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
312        self.hvsock_notify = hvsock_notify;
313        self
314    }
315
316    /// Sets a send channel used to enlighten ProxyIntegration about saved channels.
317    pub fn saved_state_notify(
318        mut self,
319        saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
320    ) -> Self {
321        self.saved_state_notify = saved_state_notify;
322        self
323    }
324
325    /// Sets a send/receive pair that will be notified of server requests. This is used by the
326    /// Underhill relay.
327    pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
328        self.server_relay = server_relay;
329        self
330    }
331
332    /// Sets a receiver that receives requests from another server.
333    pub fn external_requests(
334        mut self,
335        external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
336    ) -> Self {
337        self.external_requests = external_requests;
338        self
339    }
340
341    /// Sets a sender used to forward unhandled connect requests (which used a different VTL)
342    /// to another server.
343    pub fn external_server(
344        mut self,
345        external_server: Option<mesh::Sender<InitiateContactRequest>>,
346    ) -> Self {
347        self.external_server = external_server;
348        self
349    }
350
351    /// Sets a value which indicates whether the vmbus control plane is redirected to Underhill.
352    pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
353        self.use_message_redirect = use_message_redirect;
354        self
355    }
356
357    /// Tells the server to use an offset when generating channel IDs to void collisions with
358    /// another vmbus server.
359    ///
360    /// N.B. This should only be used by the Underhill vmbus server.
361    pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
362        self.channel_id_offset = if enable { 1024 } else { 0 };
363        self
364    }
365
366    /// Tells the server to limit the protocol version offered to the guest.
367    ///
368    /// N.B. This is used for testing older protocols without requiring a specific guest OS.
369    pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
370        self.max_version = max_version;
371        self
372    }
373
374    /// Delay limiting the maximum version until after the first `Unload` message.
375    ///
376    /// N.B. This is used to enable the use of versions older than `Version::Win10` with Uefi boot,
377    ///      since that's the oldest version the Uefi client supports.
378    pub fn delay_max_version(mut self, delay: bool) -> Self {
379        self.delay_max_version = delay;
380        self
381    }
382
383    /// Enable MNF support in the server.
384    ///
385    /// N.B. Enabling this has no effect if the synic does not support mapping monitor pages.
386    pub fn enable_mnf(mut self, enable: bool) -> Self {
387        self.enable_mnf = enable;
388        self
389    }
390
391    /// Force all non-relay channels to use encrypted external memory. Used for testing purposes
392    /// only.
393    pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
394        self.force_confidential_external_memory = force;
395        self
396    }
397
398    /// Send messages to the partition even while stopped, which can cause
399    /// corrupted synic states across VM reset.
400    ///
401    /// This option is used to prevent messages from getting into the queue, for
402    /// saved state compatibility with release/2411. It can be removed once that
403    /// release is no longer supported.
404    pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
405        self.send_messages_while_stopped = send;
406        self
407    }
408
409    /// Sets the delay before unsticking a vmbus channel after it has been opened.
410    ///
411    /// This option provides a work around for guests that ignore interrupts before they receive the
412    /// OpenResult message, by triggering an interrupt after the channel has been opened.
413    ///
414    /// If not set, the default is 100ms. If set to `None`, no interrupt will be triggered.
415    pub fn channel_unstick_delay(mut self, delay: Option<Duration>) -> Self {
416        self.channel_unstick_delay = delay;
417        self
418    }
419
420    /// Creates a new instance of the server.
421    ///
422    /// When the object is dropped, all channels will be closed and revoked
423    /// automatically.
424    pub fn build(self) -> anyhow::Result<VmbusServer> {
425        #[expect(clippy::disallowed_methods)] // TODO
426        let (message_send, message_recv) = mpsc::channel(64);
427        let message_sender = Arc::new(MessageSender {
428            send: message_send.clone(),
429            multiclient: self.use_message_redirect,
430        });
431
432        let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
433            (REDIRECT_VTL, REDIRECT_SINT)
434        } else {
435            (self.vtl, SINT)
436        };
437
438        // If this server is not for VTL2, use a server-specific connection ID rather than the
439        // standard one.
440        let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
441            MESSAGE_CONNECTION_ID
442        } else {
443            // TODO: This ID should be using the correct target VP, but that is not known until
444            //       InitiateContact.
445            VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
446        };
447
448        let _message_port = self
449            .synic
450            .add_message_port(connection_id, redirect_vtl, message_sender)
451            .context("failed to create vmbus synic ports")?;
452
453        // If this server is for VTL0, it is also responsible for the multiclient message port.
454        // N.B. If control plane redirection is enabled, the redirected message port is used for
455        //      multiclient and no separate multiclient port is created.
456        let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
457            let multiclient_message_sender = Arc::new(MessageSender {
458                send: message_send,
459                multiclient: true,
460            });
461
462            Some(
463                self.synic
464                    .add_message_port(
465                        MULTICLIENT_MESSAGE_CONNECTION_ID,
466                        self.vtl,
467                        multiclient_message_sender,
468                    )
469                    .context("failed to create vmbus synic ports")?,
470            )
471        } else {
472            None
473        };
474
475        let (offer_send, offer_recv) = mesh::mpsc_channel();
476        let control = Arc::new(VmbusServerControl {
477            mem: self.gm.clone(),
478            private_mem: self.private_gm.clone(),
479            send: offer_send,
480            use_event: self.synic.prefer_os_events(),
481            force_confidential_external_memory: self.force_confidential_external_memory,
482        });
483
484        let mut server = channels::Server::new(self.vtl, connection_id, self.channel_id_offset);
485
486        // If requested, limit the maximum protocol version and feature flags.
487        if let Some(version) = self.max_version {
488            server.set_compatibility_version(version, self.delay_max_version);
489        }
490        let (relay_request_send, relay_response_recv) =
491            if let Some(server_relay) = self.server_relay {
492                let r = server_relay.response_receive.boxed().fuse();
493                (server_relay.request_send, r)
494            } else {
495                let (req_send, req_recv) = mesh::channel();
496                let resp_recv = req_recv
497                    .map(|_| {
498                        ModifyConnectionResponse::Supported(
499                            protocol::ConnectionState::SUCCESSFUL,
500                            protocol::FeatureFlags::from_bits(u32::MAX),
501                        )
502                    })
503                    .boxed()
504                    .fuse();
505                (req_send, resp_recv)
506            };
507
508        // If no hvsock notifier was specified, use a default one that always sends an error response.
509        let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
510            let r = hvsock_notify.response_receive.boxed().fuse();
511            (hvsock_notify.request_send, r)
512        } else {
513            let (req_send, req_recv) = mesh::channel();
514            let resp_recv = req_recv
515                .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
516                .boxed()
517                .fuse();
518            (req_send, resp_recv)
519        };
520
521        let inner = ServerTaskInner {
522            running: false,
523            send_messages_while_stopped: self.send_messages_while_stopped,
524            gm: self.gm,
525            private_gm: self.private_gm,
526            vtl: self.vtl,
527            redirect_vtl,
528            redirect_sint,
529            message_port: self
530                .synic
531                .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
532            synic: self.synic,
533            hvsock_requests: 0,
534            hvsock_send,
535            saved_state_notify: self.saved_state_notify,
536            channels: HashMap::new(),
537            channel_responses: FuturesUnordered::new(),
538            relay_send: relay_request_send,
539            external_server_send: self.external_server,
540            channel_bitmap: None,
541            shared_event_port: None,
542            reset_done: Vec::new(),
543            enable_mnf: self.enable_mnf,
544        };
545
546        let (task_send, task_recv) = mesh::channel();
547        let mut server_task = ServerTask {
548            driver: Box::new(self.spawner.clone()),
549            server,
550            task_recv,
551            offer_recv,
552            message_recv,
553            server_request_recv: SelectAll::new(),
554            inner,
555            external_requests: self.external_requests,
556            next_seq: 0,
557            unstick_on_start: false,
558            channel_unstickers: FuturesUnordered::new(),
559            channel_unstick_delay: self.channel_unstick_delay,
560        };
561
562        let task = self.spawner.spawn("vmbus server", async move {
563            server_task.run(relay_response_recv, hvsock_recv).await;
564            server_task
565        });
566
567        Ok(VmbusServer {
568            task_send,
569            control,
570            _message_port,
571            _multiclient_message_port,
572            task,
573        })
574    }
575}
576
577impl VmbusServer {
578    /// Creates a new builder for `VmbusServer` with the default options.
579    pub fn builder<T: SpawnDriver + Clone>(
580        spawner: T,
581        synic: Arc<dyn SynicPortAccess>,
582        gm: GuestMemory,
583    ) -> VmbusServerBuilder<T> {
584        VmbusServerBuilder::new(spawner, synic, gm)
585    }
586
587    pub async fn save(&self) -> SavedState {
588        self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
589    }
590
591    pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
592        self.task_send
593            .call(VmbusRequest::Restore, Box::new(state))
594            .await
595            .unwrap()
596    }
597
598    /// Stop the control plane.
599    pub async fn stop(&self) {
600        self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
601    }
602
603    /// Starts the control plane.
604    pub fn start(&self) {
605        self.task_send.send(VmbusRequest::Start);
606    }
607
608    /// Resets the vmbus channel state.
609    pub async fn reset(&self) {
610        tracing::debug!("resetting channel state");
611        self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
612    }
613
614    /// Tears down the vmbus control plane.
615    pub async fn shutdown(self) {
616        drop(self.task_send);
617        let _ = self.task.await;
618    }
619
620    /// Returns an object that can be used to offer channels.
621    pub fn control(&self) -> Arc<VmbusServerControl> {
622        self.control.clone()
623    }
624
625    /// Returns the message connection ID to use for a communication from the guest for servers
626    /// that use a non-standard SINT or VTL.
627    fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
628        MULTICLIENT_MESSAGE_CONNECTION_ID
629            | (vtl as u32) << 22
630            | vp_index << 8
631            | (sint_index as u32) << 4
632    }
633
634    fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
635        EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
636    }
637}
638
/// Information needed to restore a single channel's state.
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    /// Open parameters, if the channel was open when saved.
    open_data: Option<OpenData>,
    // NOTE(review): the meaning of the `u16` and `Vec<u64>` tuple members is defined by the
    // producer of this value (not visible here) — document once confirmed.
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    /// Interrupt to restore for the channel, if any.
    interrupt: Option<Interrupt>,
}
645
/// A message received from the guest on one of the server's synic message ports.
#[derive(Default)]
pub struct SynicMessage {
    /// Raw message payload.
    data: Vec<u8>,
    // Presumably copied from the receiving `MessageSender::multiclient` flag, i.e. whether
    // the message arrived on a multiclient-capable port — confirm in `MessageSender`.
    multiclient: bool,
    // NOTE(review): set outside this chunk; presumably whether the sender is trusted
    // (e.g. not the untrusted guest of an isolated VM) — confirm.
    trusted: bool,
}
652
/// Disambiguates offer instances that may have reused the same offer ID.
#[derive(Debug, Clone, Copy)]
struct OfferInstanceId {
    /// The (possibly reused) offer ID.
    offer_id: OfferId,
    /// The [`Channel::seq`] of the specific channel instance.
    seq: u64,
}
659
/// State owned by the spawned vmbus server task.
struct ServerTask {
    /// Driver used to create timers, such as the channel unstick timers.
    driver: Box<dyn Driver>,
    /// Channel and connection state, from the `channels` module.
    server: channels::Server,
    /// Requests from [`VmbusServer`].
    task_recv: mesh::Receiver<VmbusRequest>,
    /// Offer requests from [`VmbusServerControl`].
    offer_recv: mesh::Receiver<OfferRequest>,
    /// Messages from the synic message port(s) registered by the builder.
    message_recv: mpsc::Receiver<SynicMessage>,
    /// Requests from every offered channel, tagged with the offer instance they came from.
    server_request_recv:
        SelectAll<TaggedStream<OfferInstanceId, mesh::Receiver<ChannelServerRequest>>>,
    /// State shared with the channel state machine through its notifier.
    inner: ServerTaskInner,
    /// Requests forwarded from another server, if configured.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Next value for [`Channel::seq`].
    next_seq: u64,
    // NOTE(review): used outside this chunk; presumably requests that channels be unstuck once
    // the server starts (e.g. after restore) — confirm.
    unstick_on_start: bool,
    /// Pending unstick timers; each resolves to the channel instance to signal.
    channel_unstickers: FuturesUnordered<Pin<Box<dyn Send + Future<Output = OfferInstanceId>>>>,
    /// Delay before signaling a newly opened channel; `None` disables unsticking.
    channel_unstick_delay: Option<Duration>,
}
676
/// Server task state that is lent to the channel state machine via `with_notifier`.
struct ServerTaskInner {
    // Starts out false; presumably toggled by the Start/Stop requests — confirm.
    running: bool,
    /// See [`VmbusServerBuilder::send_messages_while_stopped`].
    send_messages_while_stopped: bool,
    /// Guest memory.
    gm: GuestMemory,
    /// Separate guest memory for confidential channels, if configured.
    private_gm: Option<GuestMemory>,
    synic: Arc<dyn SynicPortAccess>,
    /// The VTL this server serves.
    vtl: Vtl,
    /// VTL used for messages, which differs from `vtl` when control plane redirection is on.
    redirect_vtl: Vtl,
    /// SINT used for messages, which differs from `SINT` when control plane redirection is on.
    redirect_sint: u8,
    /// Guest message port created by the builder with `(redirect_vtl, 0, redirect_sint)`.
    message_port: Box<dyn GuestMessagePort>,
    // Presumably the number of outstanding hvsocket requests, bounded by
    // `MAX_CONCURRENT_HVSOCK_REQUESTS` — confirm.
    hvsock_requests: usize,
    /// Sends hvsocket connect requests to the configured (or default) handler.
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    /// See [`VmbusServerBuilder::saved_state_notify`].
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Offered channels, keyed by offer ID.
    channels: HashMap<OfferId, Channel>,
    /// In-flight channel responses, tagged with the offer ID and [`Channel::seq`] they belong to.
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    /// See [`VmbusServerBuilder::external_server`].
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    /// Sends connection-state changes to the configured (or default) relay.
    relay_send: mesh::Sender<ModifyRelayRequest>,
    // NOTE(review): starts as `None`; populated outside this chunk — confirm its purpose.
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    // NOTE(review): starts as `None`; presumably holds the shared event port registration
    // while connected — confirm.
    shared_event_port: Option<Box<dyn Send>>,
    // Presumably reset RPCs waiting for the reset to complete — confirm.
    reset_done: Vec<Rpc<(), ()>>,
    /// See [`VmbusServerBuilder::enable_mnf`].
    enable_mnf: bool,
}
701
/// A response from a channel's device to a request the server sent it.
#[derive(Debug)]
enum ChannelResponse {
    /// Result of an open request; `true` on success.
    Open(bool),
    /// The channel finished closing.
    Close,
    /// Result of creating the given GPADL; `true` on success.
    Gpadl(GpadlId, bool),
    /// The given GPADL was torn down.
    TeardownGpadl(GpadlId),
    /// Status of a modify-channel request.
    Modify(i32),
}
710
/// Tracks the delayed "unstick" signal queued after a channel is opened.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ChannelUnstickState {
    /// No unstick timer is pending.
    None,
    /// An unstick timer is pending in `ServerTask::channel_unstickers`.
    Queued,
    /// The channel was opened again while a timer was already pending.
    NeedsRequeue,
}
717
/// Per-channel state kept by the server task.
struct Channel {
    key: OfferKey,
    /// Sends requests to the channel's device.
    send: mesh::Sender<ChannelRequest>,
    /// Sequence number distinguishing this instance from earlier channels with the same offer ID.
    seq: u64,
    state: ChannelState,
    gpadls: Arc<GpadlMap>,
    /// Event port that delivers the channel's interrupt when the guest signals it.
    guest_to_host_event: Arc<ChannelEvent>,
    flags: protocol::OfferFlags,
    // A channel can be reserved no matter what state it is in. This allows the message port for a
    // reserved channel to remain available even if the channel is closed, so the guest can read the
    // close reserved channel response. The reserved state is cleared when the channel is revoked,
    // reopened, or the guest sends an unload message.
    reserved_state: ReservedState,
    /// State of the post-open unstick signal.
    unstick_state: ChannelUnstickState,
}
733
/// Message port and target kept for a reserved channel.
struct ReservedState {
    /// Message port for the reserved target; `None` when not reserved.
    message_port: Option<Box<dyn GuestMessagePort>>,
    /// Target VP and SINT of the reserved channel's messages.
    target: ConnectionTarget,
}
738
/// Resources held while a channel is open.
struct ChannelOpenState {
    open_params: OpenParams,
    // Never read; owned so the event port registration lives as long as the channel is open.
    _event_port: Box<dyn Send>,
    guest_event_port: Box<dyn GuestEventPort>,
    host_to_guest_interrupt: Interrupt,
}
745
/// Open/close lifecycle of a channel, as tracked by the server task.
enum ChannelState {
    Closed,
    Open(Box<ChannelOpenState>),
    /// A close was requested; becomes `Closed` when the device's close response arrives.
    Closing,
}
751
752impl ServerTask {
753    fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
754        let key = info.params.key();
755        let flags = info.params.flags;
756
757        if self.inner.enable_mnf && self.inner.synic.monitor_support().is_some() {
758            // If this server is handling MnF, ignore any relayed monitor IDs but still enable MnF
759            // for those channels.
760            // N.B. Since this can only happen in OpenHCL, which emulates MnF, the latency is
761            //      ignored.
762            if info.params.use_mnf.is_relayed() {
763                info.params.use_mnf = MnfUsage::Enabled {
764                    latency: Duration::ZERO,
765                }
766            }
767        } else if info.params.use_mnf.is_enabled() {
768            // If the server is not handling MnF, disable it for the channel. This does not affect
769            // channels with a relayed monitor ID.
770            info.params.use_mnf = MnfUsage::Disabled;
771        }
772
773        let offer_id = self
774            .server
775            .with_notifier(&mut self.inner)
776            .offer_channel(info.params)
777            .context("channel offer failed")?;
778
779        tracing::debug!(?offer_id, %key, "offered channel");
780
781        let seq = self.next_seq;
782        self.next_seq += 1;
783        self.inner.channels.insert(
784            offer_id,
785            Channel {
786                key,
787                send: info.request_send,
788                state: ChannelState::Closed,
789                gpadls: GpadlMap::new(),
790                guest_to_host_event: Arc::new(ChannelEvent(info.event)),
791                seq,
792                flags,
793                reserved_state: ReservedState {
794                    message_port: None,
795                    target: ConnectionTarget { vp: 0, sint: 0 },
796                },
797                unstick_state: ChannelUnstickState::None,
798            },
799        );
800
801        self.server_request_recv.push(TaggedStream::new(
802            OfferInstanceId { offer_id, seq },
803            info.server_request_recv,
804        ));
805
806        Ok(())
807    }
808
809    fn handle_revoke(&mut self, id: OfferInstanceId) {
810        // The channel may or may not exist in the map depending on whether it's been explicitly
811        // revoked before being dropped.
812        if let Some(channel) = self.inner.channels.get(&id.offer_id) {
813            if channel.seq == id.seq {
814                tracing::info!(?id.offer_id, "revoking channel");
815                self.inner.channels.remove(&id.offer_id);
816                self.server
817                    .with_notifier(&mut self.inner)
818                    .revoke_channel(id.offer_id);
819            }
820        }
821    }
822
823    fn handle_response(
824        &mut self,
825        offer_id: OfferId,
826        seq: u64,
827        response: Result<ChannelResponse, RpcError>,
828    ) {
829        // Validate the sequence to ensure the response is not for a revoked channel.
830        let channel = self
831            .inner
832            .channels
833            .get(&offer_id)
834            .filter(|channel| channel.seq == seq);
835
836        if let Some(channel) = channel {
837            match response {
838                Ok(response) => match response {
839                    ChannelResponse::Open(result) => self.handle_open(offer_id, result),
840                    ChannelResponse::Close => self.handle_close(offer_id),
841                    ChannelResponse::Gpadl(gpadl_id, ok) => {
842                        self.handle_gpadl_create(offer_id, gpadl_id, ok)
843                    }
844                    ChannelResponse::TeardownGpadl(gpadl_id) => {
845                        self.handle_gpadl_teardown(offer_id, gpadl_id)
846                    }
847                    ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
848                },
849                Err(err) => {
850                    tracing::error!(
851                        key = %channel.key,
852                        error = &err as &dyn std::error::Error,
853                        "channel response failure, channel is in inconsistent state until revoked"
854                    );
855                }
856            }
857        } else {
858            tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
859        }
860    }
861
862    fn handle_open(&mut self, offer_id: OfferId, success: bool) {
863        let status = if success {
864            let channel = self
865                .inner
866                .channels
867                .get_mut(&offer_id)
868                .expect("channel exists");
869
870            // Some guests ignore interrupts before they receive the OpenResult message. To avoid
871            // a potential hang, signal the channel after a delay if needed.
872            if let Some(delay) = self.channel_unstick_delay {
873                if channel.unstick_state == ChannelUnstickState::None {
874                    channel.unstick_state = ChannelUnstickState::Queued;
875                    let seq = channel.seq;
876                    let mut timer = PolledTimer::new(&self.driver);
877                    self.channel_unstickers.push(Box::pin(async move {
878                        timer.sleep(delay).await;
879                        OfferInstanceId { offer_id, seq }
880                    }));
881                } else {
882                    channel.unstick_state = ChannelUnstickState::NeedsRequeue;
883                }
884            }
885
886            0
887        } else {
888            protocol::STATUS_UNSUCCESSFUL
889        };
890
891        self.server
892            .with_notifier(&mut self.inner)
893            .open_complete(offer_id, status);
894    }
895
896    fn handle_close(&mut self, offer_id: OfferId) {
897        let channel = self
898            .inner
899            .channels
900            .get_mut(&offer_id)
901            .expect("channel still exists");
902
903        match &mut channel.state {
904            ChannelState::Closing => {
905                channel.state = ChannelState::Closed;
906                self.server
907                    .with_notifier(&mut self.inner)
908                    .close_complete(offer_id);
909            }
910            _ => {
911                tracing::error!(?offer_id, "invalid close channel response");
912            }
913        };
914    }
915
916    fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
917        let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
918        self.server
919            .with_notifier(&mut self.inner)
920            .gpadl_create_complete(offer_id, gpadl_id, status);
921    }
922
923    fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
924        self.server
925            .with_notifier(&mut self.inner)
926            .gpadl_teardown_complete(offer_id, gpadl_id);
927    }
928
929    fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
930        self.server
931            .with_notifier(&mut self.inner)
932            .modify_channel_complete(offer_id, status);
933    }
934
935    fn handle_restore_channel(
936        &mut self,
937        offer_id: OfferId,
938        open: bool,
939    ) -> anyhow::Result<RestoreResult> {
940        let gpadls = self.server.channel_gpadls(offer_id);
941
942        // If the channel is opened, handle that before calling into channels so that failure can
943        // be handled before the channel is marked restored.
944        let open_request = open
945            .then(|| -> anyhow::Result<_> {
946                let params = self.server.get_restore_open_params(offer_id)?;
947                let (channel, interrupt) = self.inner.open_channel(offer_id, &params)?;
948                Ok(OpenRequest::new(
949                    params.open_data,
950                    interrupt,
951                    self.server
952                        .get_version()
953                        .expect("must be connected")
954                        .feature_flags,
955                    channel.flags,
956                ))
957            })
958            .transpose()?;
959
960        self.server
961            .with_notifier(&mut self.inner)
962            .restore_channel(offer_id, open_request.is_some())?;
963
964        let channel = self.inner.channels.get_mut(&offer_id).unwrap();
965        for gpadl in &gpadls {
966            if let Ok(buf) =
967                MultiPagedRangeBuf::new(gpadl.request.count.into(), gpadl.request.buf.clone())
968            {
969                channel.gpadls.add(gpadl.request.id, buf);
970            }
971        }
972
973        let result = RestoreResult {
974            open_request,
975            gpadls,
976        };
977        Ok(result)
978    }
979
    /// Handles a control request received on `task_recv`: reset, inspect, save, restore, stop,
    /// and start of the server.
    async fn handle_request(&mut self, request: VmbusRequest) {
        tracing::debug!(?request, "handle_request");
        match request {
            VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
            VmbusRequest::Inspect(deferred) => {
                // `unstick_channels` is writable: "force" wakes every open channel
                // unconditionally, `true` wakes only rings that appear to be waiting for a
                // notification, and `false` does nothing. Reading it always yields `false`.
                deferred.respond(|resp| {
                    resp.field("message_port", &self.inner.message_port)
                        .field("running", self.inner.running)
                        .field("hvsock_requests", self.inner.hvsock_requests)
                        .field("channel_unstick_delay", self.channel_unstick_delay)
                        .field_mut_with("unstick_channels", |v| {
                            let v: inspect::ValueKind = if let Some(v) = v {
                                if v == "force" {
                                    self.unstick_channels(true);
                                    v.into()
                                } else {
                                    let v =
                                        v.parse().ok().context("expected false, true, or force")?;
                                    if v {
                                        self.unstick_channels(false);
                                    }
                                    v.into()
                                }
                            } else {
                                false.into()
                            };
                            anyhow::Ok(v)
                        })
                        .merge(&self.server.with_notifier(&mut self.inner));
                });
            }
            VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
                server: self.server.save(),
                // State saved by this server always records that the lost synic bug fix is
                // present; see the `Restore` arm for how the flag is consumed.
                lost_synic_bug_fixed: true,
            }),
            VmbusRequest::Restore(rpc) => {
                rpc.handle(async |state| {
                    // State saved without the lost synic bug fix gets its channels unstuck on
                    // the next start (see the `Start` arm).
                    self.unstick_on_start = !state.lost_synic_bug_fixed;
                    // The proxy, if present, receives the saved state first. A failure there
                    // fails the whole restore before the server's own state is restored.
                    if let Some(sender) = &self.inner.saved_state_notify {
                        tracing::trace!("sending saved state to proxy");
                        if let Err(err) = sender
                            .call_failable(SavedStateRequest::Set, Box::new(state.server.clone()))
                            .await
                        {
                            tracing::error!(
                                err = &err as &dyn std::error::Error,
                                "failed to restore proxy saved state"
                            );
                            return Err(RestoreError::ServerError(err.into()));
                        }
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .restore(state.server)
                })
                .await
            }
            VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
                if self.inner.running {
                    self.inner.running = false;
                }
            }),
            VmbusRequest::Start => {
                // Starting an already-running server is a no-op.
                if !self.inner.running {
                    self.inner.running = true;
                    if let Some(sender) = self.inner.saved_state_notify.as_ref() {
                        // Indicate to the proxy that the server is starting and that it should
                        // clear its saved state cache.
                        // N.B. This panics if the proxy call fails.
                        tracing::trace!("sending clear saved state message to proxy");
                        sender
                            .call(SavedStateRequest::Clear, ())
                            .await
                            .expect("failed to clear proxy saved state");
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .revoke_unclaimed_channels();
                    // One-shot mitigation requested by a restore of state saved without the
                    // lost synic bug fix.
                    if self.unstick_on_start {
                        tracing::info!(
                            "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
                        );
                        self.unstick_channels(false);
                        self.unstick_on_start = false;
                    }
                }
            }
        }
    }
1070
1071    fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
1072        let needs_reset = self.inner.reset_done.is_empty();
1073        self.inner.reset_done.push(rpc);
1074        if needs_reset {
1075            self.server.with_notifier(&mut self.inner).reset();
1076        }
1077    }
1078
1079    fn handle_relay_response(&mut self, response: ModifyConnectionResponse) {
1080        self.server
1081            .with_notifier(&mut self.inner)
1082            .complete_modify_connection(response);
1083    }
1084
1085    fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
1086        assert_ne!(self.inner.hvsock_requests, 0);
1087        self.inner.hvsock_requests -= 1;
1088
1089        self.server
1090            .with_notifier(&mut self.inner)
1091            .send_tl_connect_result(result);
1092    }
1093
1094    fn handle_synic_message(&mut self, message: SynicMessage) {
1095        match self
1096            .server
1097            .with_notifier(&mut self.inner)
1098            .handle_synic_message(message)
1099        {
1100            Ok(()) => {}
1101            Err(err) => {
1102                tracing::warn!(
1103                    error = &err as &dyn std::error::Error,
1104                    "synic message error"
1105                );
1106            }
1107        }
1108    }
1109
1110    /// Handles a request forwarded by a different vmbus server. This is used to forward requests
1111    /// for different VTLs to different servers.
1112    ///
1113    /// N.B. This uses the same mechanism as the HCL server relay, so all requests, even the ones
1114    ///      meant for the primary server, are forwarded. In that case the primary server depends
1115    ///      on this server to send back a response so it can continue handling it.
1116    fn handle_external_request(&mut self, request: InitiateContactRequest) {
1117        self.server
1118            .with_notifier(&mut self.inner)
1119            .initiate_contact(request);
1120    }
1121
    /// Runs the server's event loop.
    ///
    /// Each iteration selects over control requests (`task_recv`), new offers, device-side
    /// channel server requests, device channel responses, relay responses
    /// (`relay_response_recv`), hvsock connect results (`hvsock_recv`), incoming synic
    /// messages, forwarded external requests, delayed channel unstick timers, and flushing of
    /// queued outgoing messages. Guest-facing sources are only polled while the server is
    /// running and no reset is in progress. The loop exits when `task_recv` reports an error
    /// (its sender is gone) or when every selected source has completed.
    async fn run(
        &mut self,
        mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyConnectionResponse>
        + Unpin,
        mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
    ) {
        loop {
            // Create an OptionFuture for each event that should only be handled
            // while the VM is running. In other cases, leave the events in
            // their respective queues.

            let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
            let mut external_requests = OptionFuture::from(
                running_not_resetting
                    .then(|| {
                        self.external_requests
                            .as_mut()
                            .map(|r| r.select_next_some())
                    })
                    .flatten(),
            );

            // Try to send any pending messages while the VM is running.
            let has_pending_messages = self.server.has_pending_messages();
            let message_port = self.inner.message_port.as_mut();
            let mut flush_pending_messages =
                OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
                    poll_fn(|cx| {
                        self.server.poll_flush_pending_messages(|msg| {
                            message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
                        })
                    })
                    .fuse()
                }));

            // Only handle new incoming messages if there are no outgoing messages pending, and not
            // too many hvsock requests outstanding. This puts a bound on the resources used by the
            // guest.
            let mut message_recv = OptionFuture::from(
                (running_not_resetting
                    && !has_pending_messages
                    && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
                    .then(|| self.message_recv.select_next_some()),
            );

            // Accept channel responses until stopped or when resetting.
            let mut channel_response = OptionFuture::from(
                (self.inner.running || !self.inner.reset_done.is_empty())
                    .then(|| self.inner.channel_responses.select_next_some()),
            );

            // Accept hvsock connect responses while the VM is running.
            let mut hvsock_response =
                OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));

            // Delayed channel unstick timers (queued by `handle_open`) only fire while the VM
            // is running and not resetting.
            let mut channel_unstickers = OptionFuture::from(
                running_not_resetting.then(|| self.channel_unstickers.select_next_some()),
            );

            futures::select! { // merge semantics
                r = self.task_recv.recv().fuse() => {
                    if let Ok(request) = r {
                        self.handle_request(request).await;
                    } else {
                        break;
                    }
                }
                r = self.offer_recv.select_next_some() => {
                    match r {
                        OfferRequest::Offer(rpc) => {
                            rpc.handle_failable_sync(|request| { self.handle_offer(request) })
                        },
                        OfferRequest::ForceReset(rpc) => {
                            self.handle_reset(rpc);
                        }
                    }
                }
                r = self.server_request_recv.select_next_some() => {
                    // A `None` request from a channel is handled the same as an explicit
                    // revoke.
                    match r {
                        (id, Some(request)) => match request {
                            ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
                                self.handle_restore_channel(id.offer_id, open)
                            }),
                            ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
                                self.handle_revoke(id);
                            })
                        },
                        (id, None) => self.handle_revoke(id),
                    }
                }
                r = channel_response => {
                    let (id, seq, response) = r.unwrap();
                    self.handle_response(id, seq, response);
                }
                r = relay_response_recv.select_next_some() => {
                    self.handle_relay_response(r);
                },
                r = hvsock_response => {
                    self.handle_tl_connect_result(r.unwrap());
                }
                data = message_recv => {
                    let data = data.unwrap();
                    self.handle_synic_message(data);
                }
                r = external_requests => {
                    let r = r.unwrap();
                    self.handle_external_request(r);
                }
                r = channel_unstickers => {
                    self.unstick_channel_by_id(r.unwrap());
                }
                _r = flush_pending_messages => {}
                complete => break,
            }
        }
    }
1238
1239    /// Wakes the guest and optionally the host for every open channel. If `force`, always wakes
1240    /// them. If `!force`, only wake for rings that are in the state where a notification is
1241    /// expected.
1242    fn unstick_channels(&self, force: bool) {
1243        let Some(version) = self.server.get_version() else {
1244            tracing::warn!("cannot unstick when not connected");
1245            return;
1246        };
1247
1248        for channel in self.inner.channels.values() {
1249            let gm = self.inner.get_gm_for_channel(version, channel);
1250            if let Err(err) = Self::unstick_channel(gm, channel, force, true) {
1251                tracing::warn!(
1252                    channel = %channel.key,
1253                    error = err.as_ref() as &dyn std::error::Error,
1254                    "could not unstick channel"
1255                );
1256            }
1257        }
1258    }
1259
1260    /// Wakes the guest for the specified channel if it's open and the rings are in a state where
1261    /// notification is expected.
1262    fn unstick_channel_by_id(&mut self, id: OfferInstanceId) {
1263        let Some(version) = self.server.get_version() else {
1264            tracelimit::warn_ratelimited!("cannot unstick when not connected");
1265            return;
1266        };
1267
1268        if let Some(channel) = self.inner.channels.get_mut(&id.offer_id) {
1269            if channel.seq != id.seq {
1270                // The channel was revoked.
1271                return;
1272            }
1273
1274            // The channel was closed and reopened before the delay expired, so wait again to ensure
1275            // we don't signal too early.
1276            if channel.unstick_state == ChannelUnstickState::NeedsRequeue {
1277                channel.unstick_state = ChannelUnstickState::Queued;
1278                let mut timer = PolledTimer::new(&self.driver);
1279                let delay = self.channel_unstick_delay.unwrap();
1280                self.channel_unstickers.push(Box::pin(async move {
1281                    timer.sleep(delay).await;
1282                    id
1283                }));
1284
1285                return;
1286            }
1287
1288            channel.unstick_state = ChannelUnstickState::None;
1289            let gm = select_gm_for_channel(
1290                &self.inner.gm,
1291                self.inner.private_gm.as_ref(),
1292                version,
1293                channel,
1294            );
1295            if let Err(err) = Self::unstick_channel(gm, channel, false, false) {
1296                tracelimit::warn_ratelimited!(
1297                    channel = %channel.key,
1298                    error = err.as_ref() as &dyn std::error::Error,
1299                    "could not unstick channel"
1300                );
1301            }
1302        }
1303    }
1304
1305    fn unstick_channel(
1306        gm: &GuestMemory,
1307        channel: &Channel,
1308        force: bool,
1309        unstick_host: bool,
1310    ) -> anyhow::Result<()> {
1311        if let ChannelState::Open(state) = &channel.state {
1312            if force {
1313                tracing::info!(channel = %channel.key, "waking host and guest");
1314                if unstick_host {
1315                    channel.guest_to_host_event.0.deliver();
1316                }
1317                state.host_to_guest_interrupt.deliver();
1318                return Ok(());
1319            }
1320
1321            let gpadl = channel
1322                .gpadls
1323                .clone()
1324                .view()
1325                .map(state.open_params.open_data.ring_gpadl_id)
1326                .context("couldn't find ring gpadl")?;
1327
1328            let aligned = AlignedGpadlView::new(gpadl)
1329                .ok()
1330                .context("ring not aligned")?;
1331            let (in_gpadl, out_gpadl) = aligned
1332                .split(state.open_params.open_data.ring_offset)
1333                .ok()
1334                .context("couldn't split ring")?;
1335
1336            if let Err(err) = Self::unstick_incoming_ring(
1337                gm,
1338                channel,
1339                in_gpadl,
1340                unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1341                &state.host_to_guest_interrupt,
1342            ) {
1343                tracelimit::warn_ratelimited!(
1344                    channel = %channel.key,
1345                    error = err.as_ref() as &dyn std::error::Error,
1346                    "could not unstick incoming ring"
1347                );
1348            }
1349            if let Err(err) = Self::unstick_outgoing_ring(
1350                gm,
1351                channel,
1352                out_gpadl,
1353                unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1354                &state.host_to_guest_interrupt,
1355            ) {
1356                tracelimit::warn_ratelimited!(
1357                    channel = %channel.key,
1358                    error = err.as_ref() as &dyn std::error::Error,
1359                    "could not unstick outgoing ring"
1360                );
1361            }
1362        }
1363        Ok(())
1364    }
1365
1366    fn unstick_incoming_ring(
1367        gm: &GuestMemory,
1368        channel: &Channel,
1369        in_gpadl: AlignedGpadlView,
1370        guest_to_host_event: Option<&ChannelEvent>,
1371        host_to_guest_interrupt: &Interrupt,
1372    ) -> anyhow::Result<()> {
1373        let control_page = lock_gpn_with_subrange(gm, in_gpadl.gpns()[0])?;
1374        if let Some(guest_to_host_event) = guest_to_host_event {
1375            if ring::reader_needs_signal(control_page.pages()[0]) {
1376                tracelimit::info_ratelimited!(channel = %channel.key, "waking host for incoming ring");
1377                guest_to_host_event.0.deliver();
1378            }
1379        }
1380
1381        let ring_size = gpadl_ring_size(&in_gpadl).try_into()?;
1382        if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1383            tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for incoming ring");
1384            host_to_guest_interrupt.deliver();
1385        }
1386        Ok(())
1387    }
1388
1389    fn unstick_outgoing_ring(
1390        gm: &GuestMemory,
1391        channel: &Channel,
1392        out_gpadl: AlignedGpadlView,
1393        guest_to_host_event: Option<&ChannelEvent>,
1394        host_to_guest_interrupt: &Interrupt,
1395    ) -> anyhow::Result<()> {
1396        let control_page = lock_gpn_with_subrange(gm, out_gpadl.gpns()[0])?;
1397        if ring::reader_needs_signal(control_page.pages()[0]) {
1398            tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for outgoing ring");
1399            host_to_guest_interrupt.deliver();
1400        }
1401
1402        if let Some(guest_to_host_event) = guest_to_host_event {
1403            let ring_size = gpadl_ring_size(&out_gpadl).try_into()?;
1404            if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1405                tracelimit::info_ratelimited!(channel = %channel.key, "waking host for outgoing ring");
1406                guest_to_host_event.0.deliver();
1407            }
1408        }
1409        Ok(())
1410    }
1411}
1412
1413impl Notifier for ServerTaskInner {
1414    fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
1415        let channel = self
1416            .channels
1417            .get_mut(&offer_id)
1418            .expect("channel does not exist");
1419
1420        fn handle<I: 'static + Send, R: 'static + Send>(
1421            offer_id: OfferId,
1422            channel: &Channel,
1423            req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
1424            input: I,
1425            f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
1426        ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
1427        {
1428            let recv = channel.send.call(req, input);
1429            let seq = channel.seq;
1430            Box::pin(async move {
1431                let r = recv.await.map(f);
1432                (offer_id, seq, r)
1433            })
1434        }
1435
1436        let response = match action {
1437            channels::Action::Open(open_params, version) => {
1438                let seq = channel.seq;
1439                match self.open_channel(offer_id, &open_params) {
1440                    Ok((channel, interrupt)) => handle(
1441                        offer_id,
1442                        channel,
1443                        ChannelRequest::Open,
1444                        OpenRequest::new(
1445                            open_params.open_data,
1446                            interrupt,
1447                            version.feature_flags,
1448                            channel.flags,
1449                        ),
1450                        ChannelResponse::Open,
1451                    ),
1452                    Err(err) => {
1453                        tracelimit::error_ratelimited!(
1454                            err = err.as_ref() as &dyn std::error::Error,
1455                            ?offer_id,
1456                            "could not open channel",
1457                        );
1458
1459                        // Return an error response to the channels module if the open_channel call
1460                        // failed.
1461                        Box::pin(future::ready((
1462                            offer_id,
1463                            seq,
1464                            Ok(ChannelResponse::Open(false)),
1465                        )))
1466                    }
1467                }
1468            }
1469            channels::Action::Close => {
1470                if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1471                    if let ChannelState::Open(ref state) = channel.state {
1472                        channel_bitmap.unregister_channel(state.open_params.event_flag);
1473                    }
1474                }
1475
1476                channel.state = ChannelState::Closing;
1477                handle(offer_id, channel, ChannelRequest::Close, (), |()| {
1478                    ChannelResponse::Close
1479                })
1480            }
1481            channels::Action::Gpadl(gpadl_id, count, buf) => {
1482                channel.gpadls.add(
1483                    gpadl_id,
1484                    MultiPagedRangeBuf::new(count.into(), buf.clone()).unwrap(),
1485                );
1486                handle(
1487                    offer_id,
1488                    channel,
1489                    ChannelRequest::Gpadl,
1490                    GpadlRequest {
1491                        id: gpadl_id,
1492                        count,
1493                        buf,
1494                    },
1495                    move |r| ChannelResponse::Gpadl(gpadl_id, r),
1496                )
1497            }
1498            channels::Action::TeardownGpadl {
1499                gpadl_id,
1500                post_restore,
1501            } => {
1502                if !post_restore {
1503                    channel.gpadls.remove(gpadl_id, Box::new(|| ()));
1504                }
1505
1506                handle(
1507                    offer_id,
1508                    channel,
1509                    ChannelRequest::TeardownGpadl,
1510                    gpadl_id,
1511                    move |()| ChannelResponse::TeardownGpadl(gpadl_id),
1512                )
1513            }
1514            channels::Action::Modify { target_vp } => {
1515                if let ChannelState::Open(state) = &mut channel.state {
1516                    if let Err(err) = state.guest_event_port.set_target_vp(target_vp) {
1517                        tracelimit::error_ratelimited!(
1518                            error = &err as &dyn std::error::Error,
1519                            channel = %channel.key,
1520                            "could not modify channel",
1521                        );
1522                        let seq = channel.seq;
1523                        Box::pin(async move {
1524                            (
1525                                offer_id,
1526                                seq,
1527                                Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
1528                            )
1529                        })
1530                    } else {
1531                        handle(
1532                            offer_id,
1533                            channel,
1534                            ChannelRequest::Modify,
1535                            ModifyRequest::TargetVp { target_vp },
1536                            ChannelResponse::Modify,
1537                        )
1538                    }
1539                } else {
1540                    unreachable!();
1541                }
1542            }
1543        };
1544        self.channel_responses.push(response);
1545    }
1546
1547    fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
1548        self.map_interrupt_page(request.interrupt_page)
1549            .context("Failed to map interrupt page.")?;
1550
1551        self.set_monitor_page(request.monitor_page)
1552            .context("Failed to map monitor page.")?;
1553
1554        if let Some(vp) = request.target_message_vp {
1555            self.message_port.set_target_vp(vp)?;
1556        }
1557
1558        if request.notify_relay {
1559            // If this server is handling MNF, the monitor pages should not be relayed.
1560            // N.B. Since the relay is being asked not to update the monitor pages, rather than
1561            //      reset them, this is only safe because the value of enable_mnf won't change after
1562            //      the server has been created.
1563            if self.enable_mnf {
1564                request.monitor_page = Update::Unchanged;
1565            }
1566
1567            self.relay_send.send(request.into());
1568        }
1569
1570        Ok(())
1571    }
1572
1573    fn forward_unhandled(&mut self, request: InitiateContactRequest) {
1574        if let Some(external_server) = &self.external_server_send {
1575            external_server.send(request);
1576        } else {
1577            tracing::warn!(?request, "nowhere to forward unhandled request")
1578        }
1579    }
1580
1581    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
1582        let channel = self.channels.get(&offer_id).expect("should exist");
1583        let mut resp = req.respond();
1584        if let ChannelState::Open(state) = &channel.state {
1585            let mem = self.get_gm_for_channel(version.expect("must be connected"), channel);
1586            inspect_rings(
1587                &mut resp,
1588                mem,
1589                channel.gpadls.clone(),
1590                &state.open_params.open_data,
1591            );
1592        }
1593    }
1594
    /// Attempts to post `message` to the guest through the port selected by `target`.
    ///
    /// Returns `true` when the message has been consumed: either it was posted, or no port
    /// could be obtained for it (reserved-channel port update failed, or a custom port could
    /// not be created) and it was discarded.
    ///
    /// Returns `false` when the message was not posted. That happens because the server is
    /// stopped and `send_messages_while_stopped` is not set, or because the port returned
    /// `Pending`. Per the comment near the end of this function, the channels module then
    /// queues the message for a later retry.
    /// NOTE(review): the paused path logs "dropping message" for non-default targets yet also
    /// returns `false`; presumably the caller drops those rather than queueing — confirm.
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
        // If the server is paused, queue all messages, to avoid affecting synic
        // state during/after it has been saved or reset.
        //
        // Note that messages to reserved channels or custom targets will be
        // dropped. However, such messages should only be sent in response to
        // guest requests, which should not be processed while the server is
        // paused.
        //
        // FUTURE: it would be better to ensure that no messages are generated
        // by operations that run while the server is paused. E.g., defer
        // sending offer or revoke messages for new or revoked offers. This
        // would prevent the queue from growing without bound.
        if !self.running && !self.send_messages_while_stopped {
            if !matches!(target, MessageTarget::Default) {
                tracelimit::error_ratelimited!(?target, "dropping message while paused");
            }
            return false;
        }

        // Owns the freshly created port for a `Custom` target, so that `port` below can borrow
        // it for the remainder of this function.
        let mut port_storage;
        let port = match target {
            MessageTarget::Default => self.message_port.as_mut(),
            MessageTarget::ReservedChannel(offer_id, target) => {
                if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
                    port.as_mut()
                } else {
                    // Updating the port failed, so there is no way to send the message.
                    return true;
                }
            }
            MessageTarget::Custom(target) => {
                port_storage = match self.synic.new_guest_message_port(
                    self.redirect_vtl,
                    target.vp,
                    target.sint,
                ) {
                    Ok(port) => port,
                    Err(err) => {
                        tracing::error!(
                            ?err,
                            ?self.redirect_vtl,
                            ?target,
                            "could not create message port"
                        );

                        // There is no way to send the message.
                        return true;
                    }
                };
                port_storage.as_mut()
            }
        };

        // If this returns Pending, the channels module will queue the message and the ServerTask
        // main loop will try to send it again later.
        // The port is polled exactly once with a no-op waker; this function never waits.
        matches!(
            port.poll_post_message(
                &mut std::task::Context::from_waker(std::task::Waker::noop()),
                VMBUS_MESSAGE_TYPE,
                message.data()
            ),
            Poll::Ready(())
        )
    }
1660
1661    fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
1662        self.hvsock_requests += 1;
1663        self.hvsock_send.send(*request);
1664    }
1665
1666    fn reset_complete(&mut self) {
1667        if let Some(monitor) = self.synic.monitor_support() {
1668            if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
1669                tracing::warn!(?err, "resetting monitor page failed")
1670            }
1671        }
1672
1673        self.unreserve_channels();
1674        for done in self.reset_done.drain(..) {
1675            done.complete(());
1676        }
1677    }
1678
    /// Handles completion of an unload by releasing the reserved message ports of all channels
    /// that are currently closed.
    fn unload_complete(&mut self) {
        self.unreserve_channels();
    }
1682}
1683
1684impl ServerTaskInner {
1685    fn open_channel(
1686        &mut self,
1687        offer_id: OfferId,
1688        open_params: &OpenParams,
1689    ) -> anyhow::Result<(&mut Channel, Interrupt)> {
1690        let channel = self
1691            .channels
1692            .get_mut(&offer_id)
1693            .expect("channel does not exist");
1694
1695        // Always register with the channel bitmap; if Win7, this may be unnecessary.
1696        if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1697            channel_bitmap.register_channel(
1698                open_params.event_flag,
1699                channel.guest_to_host_event.0.clone(),
1700            );
1701        }
1702        // Always set up an event port; if V1, this will be unused.
1703        // N.B. The event port must be created before the device is notified of the open by the
1704        //      caller. The device may begin communicating with the guest immediately when it is
1705        //      notified, so the event port must exist so that the guest can send interrupts.
1706        let event_port = self
1707            .synic
1708            .add_event_port(
1709                open_params.connection_id,
1710                self.vtl,
1711                channel.guest_to_host_event.clone(),
1712                open_params.monitor_info,
1713            )
1714            .context("failed to create guest-to-host event port")?;
1715
1716        // For pre-Win8 guests, the host-to-guest event always targets vp 0 and the channel
1717        // bitmap is used instead of the event flag.
1718        let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
1719            (0, 0)
1720        } else {
1721            (open_params.open_data.target_vp, open_params.event_flag)
1722        };
1723        let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
1724            (self.redirect_vtl, self.redirect_sint)
1725        } else {
1726            (self.vtl, SINT)
1727        };
1728
1729        let guest_event_port = self.synic.new_guest_event_port(
1730            VmbusServer::get_child_event_port_id(open_params.channel_id, SINT, self.vtl),
1731            target_vtl,
1732            target_vp,
1733            target_sint,
1734            event_flag,
1735            open_params.monitor_info,
1736        )?;
1737
1738        let interrupt = ChannelBitmap::create_interrupt(
1739            &self.channel_bitmap,
1740            guest_event_port.interrupt(),
1741            open_params.event_flag,
1742        );
1743
1744        // Delete any previously reserved state.
1745        channel.reserved_state.message_port = None;
1746
1747        // If the channel is reserved, create a message port for it.
1748        if let Some(target) = open_params.reserved_target {
1749            channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
1750                self.redirect_vtl,
1751                target.vp,
1752                target.sint,
1753            )?);
1754
1755            channel.reserved_state.target = target;
1756        }
1757
1758        channel.state = ChannelState::Open(Box::new(ChannelOpenState {
1759            open_params: *open_params,
1760            _event_port: event_port,
1761            guest_event_port,
1762            host_to_guest_interrupt: interrupt.clone(),
1763        }));
1764        Ok((channel, interrupt))
1765    }
1766
1767    /// If the client specified an interrupt page, map it into host memory and
1768    /// set up the shared event port.
1769    fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
1770        let interrupt_page = match interrupt_page {
1771            Update::Unchanged => return Ok(()),
1772            Update::Reset => {
1773                self.channel_bitmap = None;
1774                self.shared_event_port = None;
1775                return Ok(());
1776            }
1777            Update::Set(interrupt_page) => interrupt_page,
1778        };
1779
1780        assert_ne!(interrupt_page, 0);
1781
1782        if interrupt_page % PAGE_SIZE as u64 != 0 {
1783            anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
1784        }
1785
1786        // Use a subrange to access the interrupt page to give GuestMemory's without a full mapping
1787        // a chance to create one.
1788        let interrupt_page = lock_page_with_subrange(&self.gm, interrupt_page)?;
1789        let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
1790        self.channel_bitmap = Some(channel_bitmap.clone());
1791
1792        // Create the shared event port for pre-Win8 guests.
1793        let interrupt = Interrupt::from_fn(move || {
1794            channel_bitmap.handle_shared_interrupt();
1795        });
1796
1797        self.shared_event_port = Some(self.synic.add_event_port(
1798            SHARED_EVENT_CONNECTION_ID,
1799            self.vtl,
1800            Arc::new(ChannelEvent(interrupt)),
1801            None,
1802        )?);
1803
1804        Ok(())
1805    }
1806
1807    fn set_monitor_page(&mut self, monitor_page: Update<MonitorPageGpas>) -> anyhow::Result<()> {
1808        let monitor_page = match monitor_page {
1809            Update::Unchanged => return Ok(()),
1810            Update::Reset => None,
1811            Update::Set(value) => Some(value),
1812        };
1813
1814        // TODO: can this check be moved into channels.rs?
1815        if self.channels.iter().any(|(_, c)| {
1816            matches!(
1817                &c.state,
1818                ChannelState::Open(state) if state.open_params.monitor_info.is_some()
1819            )
1820        }) {
1821            anyhow::bail!("attempt to change monitor page while open channels using mnf");
1822        }
1823
1824        if self.enable_mnf {
1825            if let Some(monitor) = self.synic.monitor_support() {
1826                if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
1827                    anyhow::bail!(
1828                        "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
1829                    );
1830                }
1831            }
1832        }
1833
1834        Ok(())
1835    }
1836
1837    fn get_reserved_channel_message_port(
1838        &mut self,
1839        offer_id: OfferId,
1840        new_target: ConnectionTarget,
1841    ) -> Option<&mut Box<dyn GuestMessagePort>> {
1842        let channel = self
1843            .channels
1844            .get_mut(&offer_id)
1845            .expect("channel does not exist");
1846
1847        assert!(
1848            channel.reserved_state.message_port.is_some(),
1849            "channel is not reserved"
1850        );
1851
1852        // On close, the guest may have changed the message target it wants to use for the close
1853        // response. If so, update the message port.
1854        if channel.reserved_state.target.sint != new_target.sint {
1855            // Destroy the old port before creating the new one.
1856            channel.reserved_state.message_port = None;
1857            let message_port = self
1858                .synic
1859                .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
1860                .inspect_err(|err| {
1861                    tracing::error!(
1862                        ?err,
1863                        ?self.redirect_vtl,
1864                        ?new_target,
1865                        "could not create reserved channel message port"
1866                    )
1867                })
1868                .ok()?;
1869
1870            channel.reserved_state.message_port = Some(message_port);
1871            channel.reserved_state.target = new_target;
1872        } else if channel.reserved_state.target.vp != new_target.vp {
1873            let message_port = channel.reserved_state.message_port.as_mut().unwrap();
1874
1875            // The vp has changed, but the SINT is the same. Just update the vp. If this fails,
1876            // ignore it and just send to the old vp.
1877            if let Err(err) = message_port.set_target_vp(new_target.vp) {
1878                tracing::error!(
1879                    ?err,
1880                    ?self.redirect_vtl,
1881                    ?new_target,
1882                    "could not update reserved channel message port"
1883                );
1884            }
1885
1886            channel.reserved_state.target = new_target;
1887            return Some(message_port);
1888        }
1889
1890        Some(channel.reserved_state.message_port.as_mut().unwrap())
1891    }
1892
1893    fn unreserve_channels(&mut self) {
1894        // Unreserve all closed channels.
1895        for channel in self.channels.values_mut() {
1896            if let ChannelState::Closed = channel.state {
1897                channel.reserved_state.message_port = None;
1898            }
1899        }
1900    }
1901
1902    fn get_gm_for_channel(&self, version: VersionInfo, channel: &Channel) -> &GuestMemory {
1903        select_gm_for_channel(&self.gm, self.private_gm.as_ref(), version, channel)
1904    }
1905}
1906
1907fn select_gm_for_channel<'a>(
1908    gm: &'a GuestMemory,
1909    private_gm: Option<&'a GuestMemory>,
1910    version: VersionInfo,
1911    channel: &Channel,
1912) -> &'a GuestMemory {
1913    if channel.flags.confidential_ring_buffer() && version.feature_flags.confidential_channels() {
1914        if let Some(private_gm) = private_gm {
1915            return private_gm;
1916        }
1917    }
1918
1919    gm
1920}
1921
/// Control point for [`VmbusServer`], allowing callers to offer channels.
#[derive(Clone)]
pub struct VmbusServerControl {
    /// Shared guest memory, handed to every offered channel's [`OfferResources`].
    mem: GuestMemory,
    /// Private guest memory. It is handed only to channels whose offer flags request a
    /// confidential ring buffer or confidential external memory.
    private_mem: Option<GuestMemory>,
    /// Sends offer and force-reset requests to the server.
    send: mesh::Sender<OfferRequest>,
    /// Value reported by [`ParentBus::use_event`].
    use_event: bool,
    /// When set, every channel offered through [`ParentBus::add_child`] has the confidential
    /// external memory flag forced on, with a warning. Offers made directly through
    /// [`VmbusServerControl::offer_core`] are not affected.
    force_confidential_external_memory: bool,
}
1931
1932impl VmbusServerControl {
1933    /// Offers a channel to the vmbus server, where the flags and user_defined data are already set.
1934    /// This is used by the relay to forward the host's parameters.
1935    pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
1936        let flags = offer_info.params.flags;
1937        self.send
1938            .call_failable(OfferRequest::Offer, offer_info)
1939            .await?;
1940        Ok(OfferResources::new(
1941            self.mem.clone(),
1942            if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
1943                self.private_mem.clone()
1944            } else {
1945                None
1946            },
1947        ))
1948    }
1949
1950    /// Force reset all channels and protocol state, without requiring the
1951    /// server to be paused.
1952    pub async fn force_reset(&self) -> anyhow::Result<()> {
1953        self.send
1954            .call(OfferRequest::ForceReset, ())
1955            .await
1956            .context("vmbus server is gone")
1957    }
1958
1959    async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
1960        let mut offer_info = OfferInfo {
1961            params: request.params.into(),
1962            event: request.event,
1963            request_send: request.request_send,
1964            server_request_recv: request.server_request_recv,
1965        };
1966
1967        if self.force_confidential_external_memory {
1968            tracing::warn!(
1969                key = %offer_info.params.key(),
1970                "forcing confidential external memory for channel"
1971            );
1972
1973            offer_info
1974                .params
1975                .flags
1976                .set_confidential_external_memory(true);
1977        }
1978
1979        self.offer_core(offer_info).await
1980    }
1981}
1982
1983/// Inspects the specified ring buffer state by directly accessing guest memory.
1984fn inspect_rings(
1985    resp: &mut inspect::Response<'_>,
1986    gm: &GuestMemory,
1987    gpadl_map: Arc<GpadlMap>,
1988    open_data: &OpenData,
1989) -> Option<()> {
1990    let gpadl = gpadl_map
1991        .view()
1992        .map(GpadlId(open_data.ring_gpadl_id.0))
1993        .ok()?;
1994
1995    let aligned = AlignedGpadlView::new(gpadl).ok()?;
1996    let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
1997    resp.child("incoming_ring", |req| inspect_ring(req, &in_gpadl, gm));
1998    resp.child("outgoing_ring", |req| inspect_ring(req, &out_gpadl, gm));
1999    Some(())
2000}
2001
2002/// Inspects the incoming or outgoing ring buffer by directly accessing guest memory.
2003fn inspect_ring(req: inspect::Request<'_>, gpadl: &AlignedGpadlView, gm: &GuestMemory) {
2004    let mut resp = req.respond();
2005
2006    resp.hex("ring_size", gpadl_ring_size(gpadl));
2007
2008    // Lock just the control page. Use a subrange to allow a GuestMemory without a full mapping to
2009    // create one.
2010    if let Ok(pages) = lock_gpn_with_subrange(gm, gpadl.gpns()[0]) {
2011        ring::inspect_ring(pages.pages()[0], &mut resp);
2012    }
2013}
2014
2015fn gpadl_ring_size(gpadl: &AlignedGpadlView) -> usize {
2016    // Data size excluding the control page.
2017    (gpadl.gpns().len() - 1) * PAGE_SIZE
2018}
2019
2020/// Helper to create a subrange before locking a single page.
2021///
2022/// This allows us to lock a page in a `GuestMemory` that doesn't have a full mapping, but can
2023/// create one for a subrange.
2024fn lock_page_with_subrange(gm: &GuestMemory, offset: u64) -> anyhow::Result<guestmem::LockedPages> {
2025    Ok(gm
2026        .lockable_subrange(offset, PAGE_SIZE as u64)?
2027        .lock_gpns(false, &[0])?)
2028}
2029
2030/// Helper to create a subrange before locking a single page from a gpn.
2031///
2032/// This allows us to lock a page in a `GuestMemory` that doesn't have a full mapping, but can
2033/// create one for a subrange.
2034fn lock_gpn_with_subrange(gm: &GuestMemory, gpn: u64) -> anyhow::Result<guestmem::LockedPages> {
2035    lock_page_with_subrange(gm, gpn * PAGE_SIZE as u64)
2036}
2037
/// [`MessagePort`] implementation that forwards each received synic message over an mpsc
/// channel, wrapped in a [`SynicMessage`].
pub(crate) struct MessageSender {
    /// Channel the wrapped messages are sent on.
    send: mpsc::Sender<SynicMessage>,
    /// Copied into the `multiclient` field of every forwarded [`SynicMessage`].
    multiclient: bool,
}
2042
impl MessageSender {
    /// Copies `msg` into a [`SynicMessage`] tagged with `multiclient` and `trusted`, and queues
    /// it on the channel.
    ///
    /// Returns `Pending` while the channel is not ready to accept a message. Returns
    /// `Ready(Err(_))` if either `poll_ready` or `start_send` fails.
    fn poll_handle_message(
        &self,
        cx: &mut std::task::Context<'_>,
        msg: &[u8],
        trusted: bool,
    ) -> Poll<Result<(), SendError>> {
        // `self` is only borrowed shared, but `poll_ready`/`start_send` need `&mut`, so both
        // operate on the same fresh clone of the sender.
        // NOTE(review): the futures mpsc channel's capacity is documented as buffer + number of
        // senders. A freshly cloned sender may therefore always be ready, meaning this
        // `poll_ready` may never apply backpressure. Confirm this is intended.
        let mut send = self.send.clone();
        ready!(send.poll_ready(cx))?;
        send.start_send(SynicMessage {
            data: msg.to_vec(),
            multiclient: self.multiclient,
            trusted,
        })?;

        Poll::Ready(Ok(()))
    }
}
2061
2062impl MessagePort for MessageSender {
2063    fn poll_handle_message(
2064        &self,
2065        cx: &mut std::task::Context<'_>,
2066        msg: &[u8],
2067        trusted: bool,
2068    ) -> Poll<()> {
2069        if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
2070            tracelimit::error_ratelimited!(
2071                error = &err as &dyn std::error::Error,
2072                "failed to send message"
2073            );
2074        }
2075
2076        Poll::Ready(())
2077    }
2078}
2079
2080#[async_trait]
2081impl ParentBus for VmbusServerControl {
2082    async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2083        self.offer(request).await
2084    }
2085
2086    fn clone_bus(&self) -> Box<dyn ParentBus> {
2087        Box::new(self.clone())
2088    }
2089
2090    fn use_event(&self) -> bool {
2091        self.use_event
2092    }
2093}
2094
2095#[cfg(test)]
2096mod tests {
2097    use super::*;
2098    use inspect::InspectMut;
2099    use mesh::CancelReason;
2100    use pal_async::DefaultDriver;
2101    use pal_async::async_test;
2102    use pal_async::timer::Instant;
2103    use pal_async::timer::PolledTimer;
2104    use parking_lot::Mutex;
2105    use protocol::UserDefinedData;
2106    use std::time::Duration;
2107    use test_with_tracing::test;
2108    use vmbus_channel::bus::OfferParams;
2109    use vmbus_channel::channel::ChannelOpenError;
2110    use vmbus_channel::channel::DeviceResources;
2111    use vmbus_channel::channel::SaveRestoreVmbusDevice;
2112    use vmbus_channel::channel::VmbusDevice;
2113    use vmbus_channel::channel::offer_channel;
2114    use vmbus_core::protocol::ChannelId;
2115    use vmbus_core::protocol::VmbusMessage;
2116    use vmcore::synic::MonitorInfo;
2117    use vmcore::synic::SynicPortAccess;
2118    use zerocopy::FromBytes;
2119    use zerocopy::Immutable;
2120    use zerocopy::IntoBytes;
2121    use zerocopy::KnownLayout;
2122
2123    struct MockSynicInner {
2124        message_port: Option<Arc<dyn MessagePort>>,
2125    }
2126
2127    struct MockSynic {
2128        inner: Mutex<MockSynicInner>,
2129        message_send: mesh::Sender<Vec<u8>>,
2130        spawner: DefaultDriver,
2131    }
2132
2133    impl MockSynic {
2134        fn new(message_send: mesh::Sender<Vec<u8>>, spawner: DefaultDriver) -> Self {
2135            Self {
2136                inner: Mutex::new(MockSynicInner { message_port: None }),
2137                message_send,
2138                spawner,
2139            }
2140        }
2141
2142        fn send_message(&self, msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout) {
2143            self.send_message_core(OutgoingMessage::new(&msg), false);
2144        }
2145
2146        fn send_message_trusted(
2147            &self,
2148            msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout,
2149        ) {
2150            self.send_message_core(OutgoingMessage::new(&msg), true);
2151        }
2152
2153        fn send_message_core(&self, msg: OutgoingMessage, trusted: bool) {
2154            assert_eq!(
2155                self.inner
2156                    .lock()
2157                    .message_port
2158                    .as_ref()
2159                    .unwrap()
2160                    .poll_handle_message(
2161                        &mut std::task::Context::from_waker(std::task::Waker::noop()),
2162                        msg.data(),
2163                        trusted,
2164                    ),
2165                Poll::Ready(())
2166            );
2167        }
2168    }
2169
2170    #[derive(Debug)]
2171    struct MockGuestPort {}
2172
2173    impl GuestEventPort for MockGuestPort {
2174        fn interrupt(&self) -> Interrupt {
2175            Interrupt::null()
2176        }
2177
2178        fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
2179            Ok(())
2180        }
2181    }
2182
2183    struct MockGuestMessagePort {
2184        send: mesh::Sender<Vec<u8>>,
2185        spawner: DefaultDriver,
2186        timer: Option<(PolledTimer, Instant)>,
2187    }
2188
2189    impl GuestMessagePort for MockGuestMessagePort {
2190        fn poll_post_message(
2191            &mut self,
2192            cx: &mut std::task::Context<'_>,
2193            _typ: u32,
2194            payload: &[u8],
2195        ) -> Poll<()> {
2196            if let Some((timer, deadline)) = self.timer.as_mut() {
2197                ready!(timer.sleep_until(*deadline).poll_unpin(cx));
2198                self.timer = None;
2199            }
2200
2201            // Return pending 25% of the time.
2202            let mut pending_chance = [0; 1];
2203            getrandom::fill(&mut pending_chance).unwrap();
2204            if pending_chance[0] % 4 == 0 {
2205                let mut timer = PolledTimer::new(&self.spawner);
2206                let deadline = Instant::now() + Duration::from_millis(10);
2207                match timer.sleep_until(deadline).poll_unpin(cx) {
2208                    Poll::Ready(_) => {}
2209                    Poll::Pending => {
2210                        self.timer = Some((timer, deadline));
2211                        return Poll::Pending;
2212                    }
2213                }
2214            }
2215
2216            self.send.send(payload.into());
2217            Poll::Ready(())
2218        }
2219
2220        fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
2221            Ok(())
2222        }
2223    }
2224
2225    impl Inspect for MockGuestMessagePort {
2226        fn inspect(&self, _req: inspect::Request<'_>) {}
2227    }
2228
2229    impl SynicPortAccess for MockSynic {
2230        fn add_message_port(
2231            &self,
2232            connection_id: u32,
2233            _minimum_vtl: Vtl,
2234            port: Arc<dyn MessagePort>,
2235        ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
2236            self.inner.lock().message_port = Some(port);
2237            Ok(Box::new(connection_id))
2238        }
2239
2240        fn add_event_port(
2241            &self,
2242            connection_id: u32,
2243            _minimum_vtl: Vtl,
2244            _port: Arc<dyn EventPort>,
2245            _monitor_info: Option<MonitorInfo>,
2246        ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
2247            Ok(Box::new(connection_id))
2248        }
2249
2250        fn new_guest_message_port(
2251            &self,
2252            _vtl: Vtl,
2253            _vp: u32,
2254            _sint: u8,
2255        ) -> Result<Box<(dyn GuestMessagePort)>, vmcore::synic::HypervisorError> {
2256            Ok(Box::new(MockGuestMessagePort {
2257                send: self.message_send.clone(),
2258                spawner: self.spawner.clone(),
2259                timer: None,
2260            }))
2261        }
2262
2263        fn new_guest_event_port(
2264            &self,
2265            _port_id: u32,
2266            _vtl: Vtl,
2267            _vp: u32,
2268            _sint: u8,
2269            _flag: u16,
2270            _monitor_info: Option<MonitorInfo>,
2271        ) -> Result<Box<(dyn GuestEventPort)>, vmcore::synic::HypervisorError> {
2272            Ok(Box::new(MockGuestPort {}))
2273        }
2274
2275        fn prefer_os_events(&self) -> bool {
2276            false
2277        }
2278    }
2279
2280    struct TestChannel {
2281        request_recv: mesh::Receiver<ChannelRequest>,
2282        server_request_send: mesh::Sender<ChannelServerRequest>,
2283        _resources: OfferResources,
2284    }
2285
2286    impl TestChannel {
2287        async fn next_request(&mut self) -> ChannelRequest {
2288            self.request_recv.next().await.unwrap()
2289        }
2290
2291        async fn handle_gpadl(&mut self) {
2292            let ChannelRequest::Gpadl(rpc) = self.next_request().await else {
2293                panic!("Wrong request");
2294            };
2295
2296            rpc.complete(true);
2297        }
2298
2299        async fn handle_open(&mut self, f: fn(&OpenRequest)) {
2300            let ChannelRequest::Open(rpc) = self.next_request().await else {
2301                panic!("Wrong request");
2302            };
2303
2304            f(rpc.input());
2305            rpc.complete(true);
2306        }
2307
2308        async fn handle_gpadl_teardown(&mut self) {
2309            let rpc = self.get_gpadl_teardown().await;
2310            rpc.complete(());
2311        }
2312
2313        async fn get_gpadl_teardown(&mut self) -> Rpc<GpadlId, ()> {
2314            let ChannelRequest::TeardownGpadl(rpc) = self.next_request().await else {
2315                panic!("Wrong request");
2316            };
2317
2318            rpc
2319        }
2320
2321        async fn restore(&self) {
2322            self.server_request_send
2323                .call(ChannelServerRequest::Restore, false)
2324                .await
2325                .unwrap()
2326                .unwrap();
2327        }
2328    }
2329
2330    struct TestEnv {
2331        vmbus: VmbusServer,
2332        synic: Arc<MockSynic>,
2333        message_recv: mesh::Receiver<Vec<u8>>,
2334        trusted: bool,
2335    }
2336
2337    impl TestEnv {
2338        fn new(spawner: DefaultDriver) -> Self {
2339            let (message_send, message_recv) = mesh::channel();
2340            let synic = Arc::new(MockSynic::new(message_send, spawner.clone()));
2341            let gm = GuestMemory::empty();
2342            let vmbus = VmbusServerBuilder::new(spawner, synic.clone(), gm)
2343                .build()
2344                .unwrap();
2345
2346            Self {
2347                vmbus,
2348                synic,
2349                message_recv,
2350                trusted: false,
2351            }
2352        }
2353
2354        async fn offer(&self, id: u32, allow_confidential_external_memory: bool) -> TestChannel {
2355            let guid = Guid {
2356                data1: id,
2357                ..Guid::ZERO
2358            };
2359            let (request_send, request_recv) = mesh::channel();
2360            let (server_request_send, server_request_recv) = mesh::channel();
2361            let offer = OfferInput {
2362                event: Interrupt::from_fn(|| {}),
2363                request_send,
2364                server_request_recv,
2365                params: OfferParams {
2366                    interface_name: "test".into(),
2367                    instance_id: guid,
2368                    interface_id: guid,
2369                    mmio_megabytes: 0,
2370                    mmio_megabytes_optional: 0,
2371                    channel_type: vmbus_channel::bus::ChannelType::Device {
2372                        pipe_packets: false,
2373                    },
2374                    subchannel_index: 0,
2375                    mnf_interrupt_latency: None,
2376                    offer_order: None,
2377                    allow_confidential_external_memory,
2378                },
2379            };
2380
2381            let control = self.vmbus.control();
2382            let _resources = control.add_child(offer).await.unwrap();
2383
2384            TestChannel {
2385                request_recv,
2386                server_request_send,
2387                _resources,
2388            }
2389        }
2390
2391        async fn gpadl(&mut self, channel_id: u32, gpadl_id: u32, channel: &mut TestChannel) {
2392            self.synic.send_message_core(
2393                OutgoingMessage::with_data(
2394                    &protocol::GpadlHeader {
2395                        channel_id: ChannelId(channel_id),
2396                        gpadl_id: GpadlId(gpadl_id),
2397                        count: 1,
2398                        len: 16,
2399                    },
2400                    [1u64, 0u64].as_bytes(),
2401                ),
2402                self.trusted,
2403            );
2404
2405            channel.handle_gpadl().await;
2406            self.expect_response(protocol::MessageType::GPADL_CREATED)
2407                .await;
2408        }
2409
2410        async fn open_channel(
2411            &mut self,
2412            channel_id: u32,
2413            ring_gpadl_id: u32,
2414            channel: &mut TestChannel,
2415            f: fn(&OpenRequest),
2416        ) {
2417            self.gpadl(channel_id, ring_gpadl_id, channel).await;
2418            self.synic.send_message_core(
2419                OutgoingMessage::new(&protocol::OpenChannel {
2420                    channel_id: ChannelId(channel_id),
2421                    open_id: 0,
2422                    ring_buffer_gpadl_id: GpadlId(ring_gpadl_id),
2423                    target_vp: 0,
2424                    downstream_ring_buffer_page_offset: 0,
2425                    user_data: UserDefinedData::default(),
2426                }),
2427                self.trusted,
2428            );
2429
2430            channel.handle_open(f).await;
2431            self.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT)
2432                .await;
2433        }
2434
2435        async fn expect_response(&mut self, expected: protocol::MessageType) {
2436            let data = self.message_recv.next().await.unwrap();
2437            let header = protocol::MessageHeader::read_from_prefix(&data).unwrap().0; // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
2438            assert_eq!(expected, header.message_type())
2439        }
2440
2441        async fn get_response<T: VmbusMessage + FromBytes + Immutable + KnownLayout>(
2442            &mut self,
2443        ) -> T {
2444            let data = self.message_recv.next().await.unwrap();
2445            let (header, message) = protocol::MessageHeader::read_from_prefix(&data).unwrap(); // TODO: zerocopy: unwrap (https://github.com/microsoft/openvmm/issues/759)
2446            assert_eq!(T::MESSAGE_TYPE, header.message_type());
2447            T::read_from_prefix(message).unwrap().0 // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
2448        }
2449
2450        fn initiate_contact(
2451            &mut self,
2452            version: protocol::Version,
2453            feature_flags: protocol::FeatureFlags,
2454            trusted: bool,
2455        ) {
2456            self.synic.send_message_core(
2457                OutgoingMessage::new(&protocol::InitiateContact {
2458                    version_requested: version as u32,
2459                    target_message_vp: 0,
2460                    child_to_parent_monitor_page_gpa: 0,
2461                    parent_to_child_monitor_page_gpa: 0,
2462                    interrupt_page_or_target_info: protocol::TargetInfo::new()
2463                        .with_sint(2)
2464                        .with_vtl(0)
2465                        .with_feature_flags(feature_flags.into())
2466                        .into(),
2467                }),
2468                trusted,
2469            );
2470
2471            self.trusted = trusted;
2472        }
2473
2474        async fn connect(
2475            &mut self,
2476            offer_count: u32,
2477            feature_flags: protocol::FeatureFlags,
2478            trusted: bool,
2479        ) {
2480            self.initiate_contact(protocol::Version::Copper, feature_flags, trusted);
2481
2482            self.expect_response(protocol::MessageType::VERSION_RESPONSE)
2483                .await;
2484
2485            self.synic
2486                .send_message_core(OutgoingMessage::new(&protocol::RequestOffers {}), trusted);
2487
2488            for _ in 0..offer_count {
2489                self.expect_response(protocol::MessageType::OFFER_CHANNEL)
2490                    .await;
2491            }
2492
2493            self.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
2494                .await;
2495        }
2496    }
2497
2498    #[async_test]
2499    async fn test_save_restore(spawner: DefaultDriver) {
2500        // Most save/restore state is tested in mod channels::tests; this test specifically checks
2501        // that ServerTaskInner correctly handles some aspects of the save/restore.
2502        //
2503        // If this test fails, it is more likely to hang than panic.
2504        let mut env = TestEnv::new(spawner);
2505        let mut channel = env.offer(1, false).await;
2506        env.vmbus.start();
2507        env.connect(1, protocol::FeatureFlags::new(), false).await;
2508
2509        // Create a GPADL for the channel.
2510        env.gpadl(1, 10, &mut channel).await;
2511
2512        // Start tearing it down.
2513        env.synic.send_message(protocol::GpadlTeardown {
2514            channel_id: ChannelId(1),
2515            gpadl_id: GpadlId(10),
2516        });
2517
2518        // Wait for the teardown request here to make sure the server has processed the teardown
2519        // message, but do not complete it before saving.
2520        let rpc = channel.get_gpadl_teardown().await;
2521        env.vmbus.stop().await;
2522        let saved_state = env.vmbus.save().await;
2523        env.vmbus.start();
2524
2525        // Finish tearing down the gpadl and release the channel so the server can reset.
2526        rpc.complete(());
2527        env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
2528            .await;
2529
2530        env.synic.send_message(protocol::RelIdReleased {
2531            channel_id: ChannelId(1),
2532        });
2533
2534        env.vmbus.reset().await;
2535        env.vmbus.stop().await;
2536
2537        // When restoring with a gpadl in the TearingDown state, the teardown request for the device
2538        // will be repeated. This must not panic.
2539        env.vmbus.restore(saved_state).await.unwrap();
2540        channel.restore().await;
2541        env.vmbus.start();
2542
2543        // Handle the teardown after restore.
2544        channel.handle_gpadl_teardown().await;
2545        env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
2546            .await;
2547
2548        env.synic.send_message(protocol::RelIdReleased {
2549            channel_id: ChannelId(1),
2550        });
2551    }
2552
2553    struct TestDeviceState {
2554        id: u32,
2555        started: bool,
2556        resources: Option<DeviceResources>,
2557        open_requests: HashMap<u16, OpenRequest>,
2558        target_vps: HashMap<u16, u32>,
2559    }
2560
2561    impl TestDeviceState {
2562        pub fn id(this: &Arc<Mutex<Self>>) -> u32 {
2563            this.lock().id
2564        }
2565
2566        pub fn started(this: &Arc<Mutex<Self>>) -> bool {
2567            this.lock().started
2568        }
2569        pub fn set_started(this: &Arc<Mutex<Self>>, started: bool) {
2570            this.lock().started = started;
2571        }
2572
2573        pub fn open_request(this: &Arc<Mutex<Self>>, channel_idx: u16) -> Option<OpenRequest> {
2574            this.lock().open_requests.get(&channel_idx).cloned()
2575        }
2576        pub fn set_open_request(
2577            this: &Arc<Mutex<Self>>,
2578            channel_idx: u16,
2579            open_request: OpenRequest,
2580        ) {
2581            assert!(
2582                this.lock()
2583                    .open_requests
2584                    .insert(channel_idx, open_request)
2585                    .is_none()
2586            );
2587        }
2588        pub fn remove_open_request(
2589            this: &Arc<Mutex<Self>>,
2590            channel_idx: u16,
2591        ) -> Option<OpenRequest> {
2592            this.lock().open_requests.remove(&channel_idx)
2593        }
2594
2595        pub fn target_vp(this: &Arc<Mutex<Self>>, channel_idx: u16) -> Option<u32> {
2596            this.lock().target_vps.get(&channel_idx).copied()
2597        }
2598        pub fn set_target_vp(this: &Arc<Mutex<Self>>, channel_idx: u16, target_vp: u32) {
2599            let _ = this.lock().target_vps.insert(channel_idx, target_vp);
2600        }
2601    }
2602
2603    #[derive(InspectMut)]
2604    struct TestDevice {
2605        #[inspect(skip)]
2606        pub state: Arc<Mutex<TestDeviceState>>,
2607    }
2608
2609    impl TestDevice {
2610        pub fn new_and_state(id: u32) -> (Self, Arc<Mutex<TestDeviceState>>) {
2611            let state = TestDeviceState {
2612                id,
2613                resources: None,
2614                open_requests: HashMap::new(),
2615                target_vps: HashMap::new(),
2616                started: false,
2617            };
2618            let state = Arc::new(Mutex::new(state));
2619            let this = Self {
2620                state: state.clone(),
2621            };
2622            (this, state)
2623        }
2624    }
2625
2626    #[async_trait]
2627    impl VmbusDevice for TestDevice {
2628        fn offer(&self) -> OfferParams {
2629            let guid = Guid {
2630                data1: TestDeviceState::id(&self.state),
2631                ..Guid::ZERO
2632            };
2633
2634            OfferParams {
2635                interface_name: "test".into(),
2636                instance_id: guid,
2637                interface_id: guid,
2638                channel_type: vmbus_channel::bus::ChannelType::Device {
2639                    pipe_packets: false,
2640                },
2641                ..Default::default()
2642            }
2643        }
2644
2645        fn max_subchannels(&self) -> u16 {
2646            0
2647        }
2648
2649        fn install(&mut self, resources: DeviceResources) {
2650            self.state.lock().resources = Some(resources);
2651        }
2652
2653        async fn open(
2654            &mut self,
2655            channel_idx: u16,
2656            open_request: &OpenRequest,
2657        ) -> Result<(), ChannelOpenError> {
2658            tracing::info!("OPEN");
2659            TestDeviceState::set_open_request(&self.state, channel_idx, open_request.clone());
2660            Ok(())
2661        }
2662
2663        async fn close(&mut self, channel_idx: u16) {
2664            tracing::info!("CLOSE");
2665            assert!(TestDeviceState::remove_open_request(&self.state, channel_idx).is_some());
2666        }
2667
2668        async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
2669            TestDeviceState::set_target_vp(&self.state, channel_idx, target_vp);
2670        }
2671
2672        fn start(&mut self) {
2673            tracing::info!("START");
2674            TestDeviceState::set_started(&self.state, true);
2675        }
2676
2677        async fn stop(&mut self) {
2678            tracing::info!("STOP");
2679            TestDeviceState::set_started(&self.state, false);
2680        }
2681
2682        fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
2683            None
2684        }
2685    }
2686
2687    #[async_test]
2688    async fn test_stopped_child(spawner: DefaultDriver) {
2689        // This is mostly testing vmbus_channel behavior when a channel is
2690        // stopped but vbmus_server is not and continues to receive
2691        // messages.
2692        let mut env = TestEnv::new(spawner.clone());
2693        let (test_device, test_device_state) = TestDevice::new_and_state(1);
2694        let control = env.vmbus.control();
2695        let channel = offer_channel(&spawner, control.as_ref(), test_device)
2696            .await
2697            .expect("test device failed to offer");
2698
2699        env.vmbus.start();
2700        env.connect(1, protocol::FeatureFlags::new(), false).await;
2701
2702        // Stop the channel.
2703        channel.stop().await;
2704
2705        assert_eq!(TestDeviceState::started(&test_device_state), false);
2706
2707        // GPADL processing is currently allowed while the channel is stopped,
2708        // so this should complete.
2709        env.synic.send_message_core(
2710            OutgoingMessage::with_data(
2711                &protocol::GpadlHeader {
2712                    channel_id: ChannelId(1),
2713                    gpadl_id: GpadlId(1),
2714                    count: 1,
2715                    len: 16,
2716                },
2717                [1u64, 0u64].as_bytes(),
2718            ),
2719            false,
2720        );
2721        env.expect_response(protocol::MessageType::GPADL_CREATED)
2722            .await;
2723
2724        // Open will pend while the channel is stopped.
2725        env.synic.send_message_core(
2726            OutgoingMessage::new(&protocol::OpenChannel {
2727                channel_id: ChannelId(1),
2728                open_id: 0,
2729                ring_buffer_gpadl_id: GpadlId(1),
2730                target_vp: 0,
2731                downstream_ring_buffer_page_offset: 0,
2732                user_data: UserDefinedData::default(),
2733            }),
2734            false,
2735        );
2736        let wait_for_response = mesh::CancelContext::new()
2737            .with_timeout(Duration::from_millis(150))
2738            .until_cancelled(env.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT))
2739            .await;
2740        assert!(matches!(
2741            wait_for_response,
2742            Err(CancelReason::DeadlineExceeded)
2743        ));
2744        assert!(TestDeviceState::open_request(&test_device_state, 0).is_none());
2745
2746        // Restart the channel and confirm that open completes.
2747        channel.start();
2748        env.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT)
2749            .await;
2750        assert!(TestDeviceState::open_request(&test_device_state, 0).is_some());
2751
2752        // Stop the channel and send a modify request.
2753        assert!(TestDeviceState::target_vp(&test_device_state, 0).is_none());
2754        channel.stop().await;
2755        env.synic.send_message_core(
2756            OutgoingMessage::new(&protocol::ModifyChannel {
2757                channel_id: ChannelId(1),
2758                target_vp: 2,
2759            }),
2760            false,
2761        );
2762        let wait_for_response = mesh::CancelContext::new()
2763            .with_timeout(Duration::from_millis(150))
2764            .until_cancelled(env.expect_response(protocol::MessageType::MODIFY_CHANNEL_RESPONSE))
2765            .await;
2766        assert!(matches!(
2767            wait_for_response,
2768            Err(CancelReason::DeadlineExceeded)
2769        ));
2770
2771        // Restart the channel and verify the modify request completes.
2772        channel.start();
2773        env.expect_response(protocol::MessageType::MODIFY_CHANNEL_RESPONSE)
2774            .await;
2775        assert_eq!(
2776            TestDeviceState::target_vp(&test_device_state, 0)
2777                .expect("Modify channel request received"),
2778            2
2779        );
2780
2781        // Stop the channel and send a close request. Close is currently
2782        // allowed through in order to support reset of the vmbus
2783        // server, so try that.
2784        channel.stop().await;
2785        env.vmbus.reset().await;
2786        assert!(TestDeviceState::open_request(&test_device_state, 0).is_none());
2787
2788        env.vmbus.stop().await;
2789    }
2790
2791    #[async_test]
2792    async fn test_confidential_connection(spawner: DefaultDriver) {
2793        let mut env = TestEnv::new(spawner);
2794        // Add regular bus child channels, one of which supports confidential external memory.
2795        let mut channel = env.offer(1, false).await;
2796        let mut channel2 = env.offer(2, true).await;
2797
2798        // Add a channel directly, like the relay would do.
2799        let (request_send, request_recv) = mesh::channel();
2800        let (server_request_send, server_request_recv) = mesh::channel();
2801        let id = Guid {
2802            data1: 3,
2803            ..Guid::ZERO
2804        };
2805        let control = env.vmbus.control();
2806        let relay_resources = control
2807            .offer_core(OfferInfo {
2808                params: OfferParamsInternal {
2809                    interface_name: "test".into(),
2810                    instance_id: id,
2811                    interface_id: id,
2812                    mmio_megabytes: 0,
2813                    mmio_megabytes_optional: 0,
2814                    subchannel_index: 0,
2815                    use_mnf: MnfUsage::Disabled,
2816                    offer_order: None,
2817                    flags: protocol::OfferFlags::new().with_enumerate_device_interface(true),
2818                    ..Default::default()
2819                },
2820                event: Interrupt::from_fn(|| {}),
2821                request_send,
2822                server_request_recv,
2823            })
2824            .await
2825            .unwrap();
2826
2827        let mut relay_channel = TestChannel {
2828            request_recv,
2829            server_request_send,
2830            _resources: relay_resources,
2831        };
2832
2833        env.vmbus.start();
2834        env.initiate_contact(
2835            protocol::Version::Copper,
2836            protocol::FeatureFlags::new().with_confidential_channels(true),
2837            true,
2838        );
2839
2840        env.expect_response(protocol::MessageType::VERSION_RESPONSE)
2841            .await;
2842
2843        env.synic.send_message_trusted(protocol::RequestOffers {});
2844
2845        // All offers added with add_child have confidential ring support.
2846        let offer = env.get_response::<protocol::OfferChannel>().await;
2847        assert!(offer.flags.confidential_ring_buffer());
2848        assert!(!offer.flags.confidential_external_memory());
2849        let offer = env.get_response::<protocol::OfferChannel>().await;
2850        assert!(offer.flags.confidential_ring_buffer());
2851        assert!(offer.flags.confidential_external_memory());
2852
2853        // The "relay" channel will not have its flags modified.
2854        let offer = env.get_response::<protocol::OfferChannel>().await;
2855        assert!(!offer.flags.confidential_ring_buffer());
2856        assert!(!offer.flags.confidential_external_memory());
2857
2858        env.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
2859            .await;
2860
2861        // Make sure that the correct confidential flags are set in the open request when opening
2862        // the channels.
2863        env.open_channel(1, 1, &mut channel, |request| {
2864            assert!(request.use_confidential_ring);
2865            assert!(!request.use_confidential_external_memory);
2866        })
2867        .await;
2868
2869        env.open_channel(2, 2, &mut channel2, |request| {
2870            assert!(request.use_confidential_ring);
2871            assert!(request.use_confidential_external_memory);
2872        })
2873        .await;
2874
2875        env.open_channel(3, 3, &mut relay_channel, |request| {
2876            assert!(!request.use_confidential_ring);
2877            assert!(!request.use_confidential_external_memory);
2878        })
2879        .await;
2880    }
2881
2882    #[async_test]
2883    async fn test_confidential_channels_unsupported(spawner: DefaultDriver) {
2884        let mut env = TestEnv::new(spawner);
2885        let mut channel = env.offer(1, false).await;
2886        let mut channel2 = env.offer(2, true).await;
2887
2888        env.vmbus.start();
2889        env.connect(2, protocol::FeatureFlags::new(), true).await;
2890
2891        // Make sure that the correct confidential flags are always false when the client doesn't
2892        // support confidential channels.
2893        env.open_channel(1, 1, &mut channel, |request| {
2894            assert!(!request.use_confidential_ring);
2895            assert!(!request.use_confidential_external_memory);
2896        })
2897        .await;
2898
2899        env.open_channel(2, 2, &mut channel2, |request| {
2900            assert!(!request.use_confidential_ring);
2901            assert!(!request.use_confidential_external_memory);
2902        })
2903        .await;
2904    }
2905
2906    #[async_test]
2907    async fn test_confidential_channels_untrusted(spawner: DefaultDriver) {
2908        let mut env = TestEnv::new(spawner);
2909        let mut channel = env.offer(1, false).await;
2910        let mut channel2 = env.offer(2, true).await;
2911
2912        env.vmbus.start();
2913        // Client claims to support confidential channels, but they can't be used because the
2914        // connection is untrusted.
2915        env.connect(
2916            2,
2917            protocol::FeatureFlags::new().with_confidential_channels(true),
2918            false,
2919        )
2920        .await;
2921
2922        // Make sure that the correct confidential flags are always false when the client doesn't
2923        // support confidential channels.
2924        env.open_channel(1, 1, &mut channel, |request| {
2925            assert!(!request.use_confidential_ring);
2926            assert!(!request.use_confidential_external_memory);
2927        })
2928        .await;
2929
2930        env.open_channel(2, 2, &mut channel2, |request| {
2931            assert!(!request.use_confidential_ring);
2932            assert!(!request.use_confidential_external_memory);
2933        })
2934        .await;
2935    }
2936}