vmbus_server/lib.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#![expect(missing_docs)]
#![forbid(unsafe_code)]

mod channel_bitmap;
mod channels;
pub mod event;
pub mod hvsock;
mod monitor;
mod proxyintegration;

/// The GUID type used for vmbus channel identifiers.
pub type Guid = guid::Guid;

use anyhow::Context;
use async_trait::async_trait;
use channel_bitmap::ChannelBitmap;
use channels::ConnectionTarget;
pub use channels::InitiateContactRequest;
use channels::MessageTarget;
pub use channels::MnfUsage;
use channels::ModifyConnectionRequest;
pub use channels::ModifyConnectionResponse;
use channels::Notifier;
use channels::OfferId;
pub use channels::OfferParamsInternal;
use channels::OpenParams;
use channels::RestoreError;
pub use channels::Update;
use futures::FutureExt;
use futures::StreamExt;
use futures::channel::mpsc;
use futures::channel::mpsc::SendError;
use futures::future::OptionFuture;
use futures::future::poll_fn;
use futures::stream::SelectAll;
use guestmem::GuestMemory;
use hvdef::Vtl;
use inspect::Inspect;
use mesh::payload::Protobuf;
use mesh::rpc::FailableRpc;
use mesh::rpc::Rpc;
use mesh::rpc::RpcError;
use mesh::rpc::RpcSend;
use pal_async::task::Spawn;
use pal_async::task::Task;
use pal_event::Event;
#[cfg(windows)]
pub use proxyintegration::ProxyIntegration;
#[cfg(windows)]
pub use proxyintegration::ProxyServerInfo;
use ring::PAGE_SIZE;
use std::collections::HashMap;
use std::future;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use std::task::ready;
use std::time::Duration;
use unicycle::FuturesUnordered;
use vmbus_channel::bus::ChannelRequest;
use vmbus_channel::bus::ChannelServerRequest;
use vmbus_channel::bus::GpadlRequest;
use vmbus_channel::bus::ModifyRequest;
use vmbus_channel::bus::OfferInput;
use vmbus_channel::bus::OfferKey;
use vmbus_channel::bus::OfferResources;
use vmbus_channel::bus::OpenData;
use vmbus_channel::bus::OpenRequest;
use vmbus_channel::bus::OpenResult;
use vmbus_channel::bus::ParentBus;
use vmbus_channel::bus::RestoreResult;
use vmbus_channel::gpadl::GpadlMap;
use vmbus_channel::gpadl_ring::AlignedGpadlView;
use vmbus_channel::gpadl_ring::GpadlRingMem;
use vmbus_core::HvsockConnectRequest;
use vmbus_core::HvsockConnectResult;
use vmbus_core::MaxVersionInfo;
use vmbus_core::OutgoingMessage;
use vmbus_core::TaggedStream;
use vmbus_core::VersionInfo;
use vmbus_core::protocol;
pub use vmbus_core::protocol::GpadlId;
#[cfg(windows)]
use vmbus_proxy::ProxyHandle;
use vmbus_ring as ring;
use vmbus_ring::gparange::MultiPagedRangeBuf;
use vmcore::interrupt::Interrupt;
use vmcore::save_restore::SavedStateRoot;
use vmcore::synic::EventPort;
use vmcore::synic::GuestEventPort;
use vmcore::synic::GuestMessagePort;
use vmcore::synic::MessagePort;
use vmcore::synic::MonitorPageGpas;
use vmcore::synic::SynicPortAccess;

const SINT: u8 = 2;
pub const REDIRECT_SINT: u8 = 7;
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
const EVENT_PORT_ID: u32 = 2;
const VMBUS_MESSAGE_TYPE: u32 = 1;

const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;

pub struct VmbusServer {
    task_send: mesh::Sender<VmbusRequest>,
    control: Arc<VmbusServerControl>,
    _message_port: Box<dyn Sync + Send>,
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    task: Task<ServerTask>,
}

pub struct VmbusServerBuilder<'a, T: Spawn> {
    spawner: &'a T,
    synic: Arc<dyn SynicPortAccess>,
    gm: GuestMemory,
    private_gm: Option<GuestMemory>,
    vtl: Vtl,
    hvsock_notify: Option<HvsockServerChannelHalf>,
    server_relay: Option<VmbusServerChannelHalf>,
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    use_message_redirect: bool,
    channel_id_offset: u16,
    max_version: Option<MaxVersionInfo>,
    delay_max_version: bool,
    enable_mnf: bool,
    force_confidential_external_memory: bool,
    send_messages_while_stopped: bool,
}

/// The server side of the connection between a vmbus server and a relay.
pub struct ServerChannelHalf<Request, Response> {
    request_send: mesh::Sender<Request>,
    response_receive: mesh::Receiver<Response>,
}

/// The relay side of a connection between a vmbus server and a relay.
pub struct RelayChannelHalf<Request, Response> {
    pub request_receive: mesh::Receiver<Request>,
    pub response_send: mesh::Sender<Response>,
}

/// A connection between a vmbus server and a relay.
pub struct RelayChannel<Request, Response> {
    pub relay_half: RelayChannelHalf<Request, Response>,
    pub server_half: ServerChannelHalf<Request, Response>,
}

impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
    /// Creates a new channel between the vmbus server and a relay.
    pub fn new() -> Self {
        let (request_send, request_receive) = mesh::channel();
        let (response_send, response_receive) = mesh::channel();
        Self {
            relay_half: RelayChannelHalf {
                request_receive,
                response_send,
            },
            server_half: ServerChannelHalf {
                request_send,
                response_receive,
            },
        }
    }
}
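
// A minimal sketch of how the two halves are typically split (illustrative only;
// `spawner`, `synic`, and `gm` are assumed to be supplied by the host environment):
//
//     let VmbusRelayChannel { relay_half, server_half } = VmbusRelayChannel::new();
//     let server = VmbusServer::builder(&spawner, synic, gm)
//         .server_relay(Some(server_half))
//         .build()?;
//     // `relay_half.request_receive` and `relay_half.response_send` are then
//     // handed to the relay task.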

pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyConnectionResponse>;
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;

/// A request from the server to the relay to modify connection state.
///
/// The version, use_interrupt_page and target_message_vp fields can only be present if this
/// request was sent for an InitiateContact message from the guest.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    pub version: Option<u32>,
    pub monitor_page: Update<MonitorPageGpas>,
    pub use_interrupt_page: Option<bool>,
}

impl From<ModifyConnectionRequest> for ModifyRelayRequest {
    fn from(value: ModifyConnectionRequest) -> Self {
        Self {
            version: value.version,
            monitor_page: value.monitor_page,
            use_interrupt_page: match value.interrupt_page {
                Update::Unchanged => None,
                Update::Reset => Some(false),
                Update::Set(_) => Some(true),
            },
        }
    }
}

#[derive(Debug)]
enum VmbusRequest {
    Reset(Rpc<(), ()>),
    Inspect(inspect::Deferred),
    Save(Rpc<(), SavedState>),
    Restore(Rpc<SavedState, Result<(), RestoreError>>),
    PostRestore(Rpc<(), Result<(), RestoreError>>),
    Start,
    Stop(Rpc<(), ()>),
}

#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    pub params: OfferParamsInternal,
    pub request_send: mesh::Sender<ChannelRequest>,
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}

#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    Offer(FailableRpc<OfferInfo, ()>),
    ForceReset(Rpc<(), ()>),
}

impl Inspect for VmbusServer {
    fn inspect(&self, req: inspect::Request<'_>) {
        self.task_send.send(VmbusRequest::Inspect(req.defer()));
    }
}

struct ChannelEvent(Interrupt);

impl EventPort for ChannelEvent {
    fn handle_event(&self, _flag: u16) {
        self.0.deliver();
    }

    fn os_event(&self) -> Option<&Event> {
        self.0.event()
    }
}

#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    #[mesh(1)]
    server: channels::SavedState,
    // Indicates whether the lost-synic bug is fixed. Defaults to false.
    // During restore, if this field is not true, unstick_channels() is called
    // to mitigate the issue.
    #[mesh(2)]
    lost_synic_bug_fixed: bool,
}

const MESSAGE_CONNECTION_ID: u32 = 1;
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;

impl<'a, T: Spawn> VmbusServerBuilder<'a, T> {
    /// Creates a new builder for `VmbusServer` with the default options.
    pub fn new(spawner: &'a T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
        Self {
            spawner,
            synic,
            gm,
            private_gm: None,
            vtl: Vtl::Vtl0,
            hvsock_notify: None,
            server_relay: None,
            external_server: None,
            external_requests: None,
            use_message_redirect: false,
            channel_id_offset: 0,
            max_version: None,
            delay_max_version: false,
            enable_mnf: false,
            force_confidential_external_memory: false,
            send_messages_while_stopped: false,
        }
    }
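
    // A minimal usage sketch of the builder (illustrative only; `spawner`, `synic`,
    // and `gm` are assumed to be provided by the surrounding VMM):
    //
    //     let server = VmbusServerBuilder::new(&spawner, synic, gm)
    //         .vtl(Vtl::Vtl0)
    //         .enable_mnf(true)
    //         .build()?;
    //     server.start();
    //     let control = server.control();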

    /// Sets a separate guest memory instance to use for channels that are confidential (non-relay
    /// channels in Underhill on a hardware-isolated VM). This is not relevant for a non-Underhill
    /// VmBus server.
    pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
        self.private_gm = private_gm;
        self
    }

    /// Sets the VTL that this instance will serve.
    pub fn vtl(mut self, vtl: Vtl) -> Self {
        self.vtl = vtl;
        self
    }

    /// Sets a send/receive pair used to handle hvsocket requests.
    pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
        self.hvsock_notify = hvsock_notify;
        self
    }

    /// Sets a send/receive pair that will be notified of server requests. This is used by the
    /// Underhill relay.
    pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
        self.server_relay = server_relay;
        self
    }

    /// Sets a receiver that receives requests from another server.
    pub fn external_requests(
        mut self,
        external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    ) -> Self {
        self.external_requests = external_requests;
        self
    }

    /// Sets a sender used to forward unhandled connect requests (which used a different VTL)
    /// to another server.
    pub fn external_server(
        mut self,
        external_server: Option<mesh::Sender<InitiateContactRequest>>,
    ) -> Self {
        self.external_server = external_server;
        self
    }

    /// Sets a value which indicates whether the vmbus control plane is redirected to Underhill.
    pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
        self.use_message_redirect = use_message_redirect;
        self
    }

    /// Tells the server to use an offset when generating channel IDs to avoid collisions with
    /// another vmbus server.
    ///
    /// N.B. This should only be used by the Underhill vmbus server.
    pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
        self.channel_id_offset = if enable { 1024 } else { 0 };
        self
    }

    /// Tells the server to limit the protocol version offered to the guest.
    ///
    /// N.B. This is used for testing older protocols without requiring a specific guest OS.
    pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
        self.max_version = max_version;
        self
    }

    /// Delay limiting the maximum version until after the first `Unload` message.
    ///
    /// N.B. This is used to enable the use of versions older than `Version::Win10` with Uefi boot,
    ///      since that's the oldest version the Uefi client supports.
    pub fn delay_max_version(mut self, delay: bool) -> Self {
        self.delay_max_version = delay;
        self
    }

    /// Enable MNF support in the server.
    ///
    /// N.B. Enabling this has no effect if the synic does not support mapping monitor pages.
    pub fn enable_mnf(mut self, enable: bool) -> Self {
        self.enable_mnf = enable;
        self
    }

    /// Force all non-relay channels to use encrypted external memory. Used for testing purposes
    /// only.
    pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
        self.force_confidential_external_memory = force;
        self
    }

    /// Send messages to the partition even while stopped, which can cause
    /// corrupted synic states across VM reset.
    ///
    /// This option is used to prevent messages from getting into the queue, for
    /// saved state compatibility with release/2411. It can be removed once that
    /// release is no longer supported.
    pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
        self.send_messages_while_stopped = send;
        self
    }

    /// Creates a new instance of the server.
    ///
    /// When the object is dropped, all channels will be closed and revoked
    /// automatically.
    pub fn build(self) -> anyhow::Result<VmbusServer> {
        #[expect(clippy::disallowed_methods)] // TODO
        let (message_send, message_recv) = mpsc::channel(64);
        let message_sender = Arc::new(MessageSender {
            send: message_send.clone(),
            multiclient: self.use_message_redirect,
        });

        let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
            (REDIRECT_VTL, REDIRECT_SINT)
        } else {
            (self.vtl, SINT)
        };

        // If this server is not for VTL0, or if message redirection is in use, use a
        // server-specific connection ID rather than the standard one.
        let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            MESSAGE_CONNECTION_ID
        } else {
            // TODO: This ID should be using the correct target VP, but that is not known until
            //       InitiateContact.
            VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
        };

        let _message_port = self
            .synic
            .add_message_port(connection_id, redirect_vtl, message_sender)
            .context("failed to create vmbus synic ports")?;

        // If this server is for VTL0, it is also responsible for the multiclient message port.
        // N.B. If control plane redirection is enabled, the redirected message port is used for
        //      multiclient and no separate multiclient port is created.
        let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            let multiclient_message_sender = Arc::new(MessageSender {
                send: message_send,
                multiclient: true,
            });

            Some(
                self.synic
                    .add_message_port(
                        MULTICLIENT_MESSAGE_CONNECTION_ID,
                        self.vtl,
                        multiclient_message_sender,
                    )
                    .context("failed to create vmbus synic ports")?,
            )
        } else {
            None
        };

        let (offer_send, offer_recv) = mesh::mpsc_channel();
        let control = Arc::new(VmbusServerControl {
            mem: self.gm.clone(),
            private_mem: self.private_gm.clone(),
            send: offer_send,
            use_event: self.synic.prefer_os_events(),
            force_confidential_external_memory: self.force_confidential_external_memory,
        });

        let mut server = channels::Server::new(self.vtl, connection_id, self.channel_id_offset);

        // If requested, limit the maximum protocol version and feature flags.
        if let Some(version) = self.max_version {
            server.set_compatibility_version(version, self.delay_max_version);
        }
        let (relay_request_send, relay_response_recv) =
            if let Some(server_relay) = self.server_relay {
                let r = server_relay.response_receive.boxed().fuse();
                (server_relay.request_send, r)
            } else {
                let (req_send, req_recv) = mesh::channel();
                let resp_recv = req_recv
                    .map(|_| {
                        ModifyConnectionResponse::Supported(
                            protocol::ConnectionState::SUCCESSFUL,
                            protocol::FeatureFlags::from_bits(u32::MAX),
                        )
                    })
                    .boxed()
                    .fuse();
                (req_send, resp_recv)
            };

        // If no hvsock notifier was specified, use a default one that always sends an error response.
        let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
            let r = hvsock_notify.response_receive.boxed().fuse();
            (hvsock_notify.request_send, r)
        } else {
            let (req_send, req_recv) = mesh::channel();
            let resp_recv = req_recv
                .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
                .boxed()
                .fuse();
            (req_send, resp_recv)
        };

        let inner = ServerTaskInner {
            running: false,
            send_messages_while_stopped: self.send_messages_while_stopped,
            gm: self.gm,
            private_gm: self.private_gm,
            vtl: self.vtl,
            redirect_vtl,
            redirect_sint,
            message_port: self
                .synic
                .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
            synic: self.synic,
            hvsock_requests: 0,
            hvsock_send,
            channels: HashMap::new(),
            channel_responses: FuturesUnordered::new(),
            relay_send: relay_request_send,
            external_server_send: self.external_server,
            channel_bitmap: None,
            shared_event_port: None,
            reset_done: Vec::new(),
            enable_mnf: self.enable_mnf,
        };

        let (task_send, task_recv) = mesh::channel();
        let mut server_task = ServerTask {
            server,
            task_recv,
            offer_recv,
            message_recv,
            server_request_recv: SelectAll::new(),
            inner,
            external_requests: self.external_requests,
            next_seq: 0,
            unstick_on_start: false,
        };

        let task = self.spawner.spawn("vmbus server", async move {
            server_task.run(relay_response_recv, hvsock_recv).await;
            server_task
        });

        Ok(VmbusServer {
            task_send,
            control,
            _message_port,
            _multiclient_message_port,
            task,
        })
    }
}

impl VmbusServer {
    /// Creates a new builder for `VmbusServer` with the default options.
    pub fn builder<T: Spawn>(
        spawner: &T,
        synic: Arc<dyn SynicPortAccess>,
        gm: GuestMemory,
    ) -> VmbusServerBuilder<'_, T> {
        VmbusServerBuilder::new(spawner, synic, gm)
    }

    pub async fn save(&self) -> SavedState {
        self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
    }

    pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
        self.task_send
            .call(VmbusRequest::Restore, state)
            .await
            .unwrap()
    }

    pub async fn post_restore(&self) -> Result<(), RestoreError> {
        self.task_send
            .call(VmbusRequest::PostRestore, ())
            .await
            .unwrap()
    }
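
    // A rough sketch of the typical save/restore ordering (illustrative only;
    // error handling is elided, and channel-level restore via
    // `ChannelServerRequest::Restore` is driven by the individual devices):
    //
    //     let state = server.save().await;
    //     // ... later, on a freshly built server ...
    //     server.restore(state).await?;
    //     // ... devices re-offer and restore their channels ...
    //     server.post_restore().await?;
    //     server.start();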

    /// Stops the control plane.
    pub async fn stop(&self) {
        self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
    }

    /// Starts the control plane.
    pub fn start(&self) {
        self.task_send.send(VmbusRequest::Start);
    }

    /// Resets the vmbus channel state.
    pub async fn reset(&self) {
        tracing::debug!("resetting channel state");
        self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
    }

    /// Tears down the vmbus control plane.
    pub async fn shutdown(self) {
        drop(self.task_send);
        let _ = self.task.await;
    }

    /// Returns an object that can be used to offer channels.
    pub fn control(&self) -> Arc<VmbusServerControl> {
        self.control.clone()
    }

    /// Returns the message connection ID to use for communication from the guest on servers
    /// that use a non-standard SINT or VTL.
    fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
        MULTICLIENT_MESSAGE_CONNECTION_ID
            | (vtl as u32) << 22
            | vp_index << 8
            | (sint_index as u32) << 4
    }
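
    // Worked example of the bit layout above (illustrative values only):
    // VP index 3, SINT 7, VTL 2 yields
    // 0x4 | (2 << 22) | (3 << 8) | (7 << 4) = 0x0080_0374.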

    fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
        EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
    }
}

#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    open_data: Option<OpenData>,
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    interrupt: Option<Interrupt>,
}

#[derive(Default)]
pub struct SynicMessage {
    data: Vec<u8>,
    multiclient: bool,
    trusted: bool,
}

struct ServerTask {
    server: channels::Server,
    task_recv: mesh::Receiver<VmbusRequest>,
    offer_recv: mesh::Receiver<OfferRequest>,
    message_recv: mpsc::Receiver<SynicMessage>,
    server_request_recv: SelectAll<TaggedStream<OfferId, mesh::Receiver<ChannelServerRequest>>>,
    inner: ServerTaskInner,
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Next value for [`Channel::seq`].
    next_seq: u64,
    unstick_on_start: bool,
}

struct ServerTaskInner {
    running: bool,
    send_messages_while_stopped: bool,
    gm: GuestMemory,
    private_gm: Option<GuestMemory>,
    synic: Arc<dyn SynicPortAccess>,
    vtl: Vtl,
    redirect_vtl: Vtl,
    redirect_sint: u8,
    message_port: Box<dyn GuestMessagePort>,
    hvsock_requests: usize,
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    channels: HashMap<OfferId, Channel>,
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    relay_send: mesh::Sender<ModifyRelayRequest>,
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    shared_event_port: Option<Box<dyn Send>>,
    reset_done: Vec<Rpc<(), ()>>,
    enable_mnf: bool,
}

#[derive(Debug)]
enum ChannelResponse {
    Open(Option<OpenResult>),
    Close,
    Gpadl(GpadlId, bool),
    TeardownGpadl(GpadlId),
    Modify(i32),
}

struct Channel {
    key: OfferKey,
    send: mesh::Sender<ChannelRequest>,
    seq: u64,
    state: ChannelState,
    gpadls: Arc<GpadlMap>,
    flags: protocol::OfferFlags,
    // A channel can be reserved no matter what state it is in. This allows the message port for a
    // reserved channel to remain available even if the channel is closed, so the guest can read the
    // close reserved channel response. The reserved state is cleared when the channel is revoked,
    // reopened, or the guest sends an unload message.
    reserved_state: ReservedState,
}

struct ReservedState {
    message_port: Option<Box<dyn GuestMessagePort>>,
    target: ConnectionTarget,
}

enum ChannelState {
    Closed,
    Opening {
        open_params: OpenParams,
        guest_event_port: Box<dyn GuestEventPort>,
        host_to_guest_interrupt: Interrupt,
    },
    Open {
        open_params: OpenParams,
        _event_port: Box<dyn Send>,
        guest_event_port: Box<dyn GuestEventPort>,
        host_to_guest_interrupt: Interrupt,
        guest_to_host_event: Arc<ChannelEvent>,
    },
    Closing,
    FailedOpen,
}

impl ServerTask {
    fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
        let key = info.params.key();
        let flags = info.params.flags;

        if self.inner.enable_mnf && self.inner.synic.monitor_support().is_some() {
            // If this server is handling MnF, ignore any relayed monitor IDs but still enable MnF
            // for those channels.
            // N.B. Since this can only happen in OpenHCL, which emulates MnF, the latency is
            //      ignored.
            if info.params.use_mnf.is_relayed() {
                info.params.use_mnf = MnfUsage::Enabled {
                    latency: Duration::ZERO,
                }
            }
        } else if info.params.use_mnf.is_enabled() {
            // If the server is not handling MnF, disable it for the channel. This does not affect
            // channels with a relayed monitor ID.
            info.params.use_mnf = MnfUsage::Disabled;
        }

        let offer_id = self
            .server
            .with_notifier(&mut self.inner)
            .offer_channel(info.params)
            .context("channel offer failed")?;

        tracing::debug!(?offer_id, %key, "offered channel");

        let id = self.next_seq;
        self.next_seq += 1;
        self.inner.channels.insert(
            offer_id,
            Channel {
                key,
                send: info.request_send,
                state: ChannelState::Closed,
                gpadls: GpadlMap::new(),
                seq: id,
                flags,
                reserved_state: ReservedState {
                    message_port: None,
                    target: ConnectionTarget { vp: 0, sint: 0 },
                },
            },
        );

        self.server_request_recv
            .push(TaggedStream::new(offer_id, info.server_request_recv));

        Ok(())
    }

    fn handle_revoke(&mut self, offer_id: OfferId) {
        // The channel may or may not exist in the map depending on whether it's been explicitly
        // revoked before being dropped.
        if self.inner.channels.remove(&offer_id).is_some() {
            tracing::info!(?offer_id, "revoking channel");
            self.server
                .with_notifier(&mut self.inner)
                .revoke_channel(offer_id);
        }
    }

    fn handle_response(
        &mut self,
        offer_id: OfferId,
        seq: u64,
        response: Result<ChannelResponse, RpcError>,
    ) {
        // Validate the sequence to ensure the response is not for a revoked channel.
        let channel = self
            .inner
            .channels
            .get(&offer_id)
            .filter(|channel| channel.seq == seq);

        if let Some(channel) = channel {
            match response {
                Ok(response) => match response {
                    ChannelResponse::Open(result) => self.handle_open(offer_id, result),
                    ChannelResponse::Close => self.handle_close(offer_id),
                    ChannelResponse::Gpadl(gpadl_id, ok) => {
                        self.handle_gpadl_create(offer_id, gpadl_id, ok)
                    }
                    ChannelResponse::TeardownGpadl(gpadl_id) => {
                        self.handle_gpadl_teardown(offer_id, gpadl_id)
                    }
                    ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
                },
                Err(err) => {
                    tracing::error!(
                        key = %channel.key,
                        error = &err as &dyn std::error::Error,
                        "channel response failure, channel is in inconsistent state until revoked"
                    );
                }
            }
        } else {
            tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
        }
    }

    fn handle_open(&mut self, offer_id: OfferId, result: Option<OpenResult>) {
        let status = if result.is_some() {
            0
        } else {
            protocol::STATUS_UNSUCCESSFUL
        };
        if let Err(err) = self.inner.complete_open(offer_id, result) {
            tracelimit::error_ratelimited!(
                error = err.as_ref() as &dyn std::error::Error,
                "failed to complete open"
            );
            // If complete_open failed, the channel is now in the FailedOpen state and the device
            // needs to be notified to close it. Calling open_complete is postponed until the
            // device responds to the close request.
            self.inner.notify(offer_id, channels::Action::Close);
        } else {
            self.server
                .with_notifier(&mut self.inner)
                .open_complete(offer_id, status);
        }
    }
831
832    fn handle_close(&mut self, offer_id: OfferId) {
833        let channel = self
834            .inner
835            .channels
836            .get_mut(&offer_id)
837            .expect("channel still exists");
838
839        match &mut channel.state {
840            ChannelState::Closing => {
841                channel.state = ChannelState::Closed;
842                self.server
843                    .with_notifier(&mut self.inner)
844                    .close_complete(offer_id);
845            }
846            ChannelState::FailedOpen => {
847                // Now that the device has processed the close request after open failed, we can
848                // finish handling the failed open and send an open result to the guest.
849                channel.state = ChannelState::Closed;
850                self.server
851                    .with_notifier(&mut self.inner)
852                    .open_complete(offer_id, protocol::STATUS_UNSUCCESSFUL);
853            }
854            _ => {
855                tracing::error!(?offer_id, "invalid close channel response");
856            }
857        };
858    }
859
860    fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
861        let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
862        self.server
863            .with_notifier(&mut self.inner)
864            .gpadl_create_complete(offer_id, gpadl_id, status);
865    }
866
867    fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
868        self.server
869            .with_notifier(&mut self.inner)
870            .gpadl_teardown_complete(offer_id, gpadl_id);
871    }
872
873    fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
874        self.server
875            .with_notifier(&mut self.inner)
876            .modify_channel_complete(offer_id, status);
877    }
878
879    fn handle_restore_channel(
880        &mut self,
881        offer_id: OfferId,
882        open: Option<OpenResult>,
883    ) -> anyhow::Result<RestoreResult> {
884        let gpadls = self.server.channel_gpadls(offer_id);
885
886        // If the channel is opened, handle that before calling into channels so that failure can
887        // be handled before the channel is marked restored.
888        let open_request = open
889            .map(|result| -> anyhow::Result<_> {
890                let params = self.server.get_restore_open_params(offer_id)?;
891                let (_, interrupt) = self.inner.open_channel(offer_id, &params)?;
892                let channel = self.inner.complete_open(offer_id, Some(result))?;
893                Ok(OpenRequest::new(
894                    params.open_data,
895                    interrupt,
896                    self.server
897                        .get_version()
898                        .expect("must be connected")
899                        .feature_flags,
900                    channel.flags,
901                ))
902            })
903            .transpose()?;
904
905        self.server
906            .with_notifier(&mut self.inner)
907            .restore_channel(offer_id, open_request.is_some())?;
908
909        let channel = self.inner.channels.get_mut(&offer_id).unwrap();
910        for gpadl in &gpadls {
911            if let Ok(buf) =
912                MultiPagedRangeBuf::new(gpadl.request.count.into(), gpadl.request.buf.clone())
913            {
914                channel.gpadls.add(gpadl.request.id, buf);
915            }
916        }
917
918        let result = RestoreResult {
919            open_request,
920            gpadls,
921        };
922        Ok(result)
923    }
924
925    fn handle_request(&mut self, request: VmbusRequest) {
926        tracing::debug!(?request, "handle_request");
927        match request {
928            VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
929            VmbusRequest::Inspect(deferred) => {
930                deferred.respond(|resp| {
931                    resp.field("message_port", &self.inner.message_port)
932                        .field("running", self.inner.running)
933                        .field("hvsock_requests", self.inner.hvsock_requests)
934                        .field_mut_with("unstick_channels", |v| {
935                            let v: inspect::Value = if let Some(v) = v {
936                                if v == "force" {
937                                    self.unstick_channels(true);
938                                    v.into()
939                                } else {
940                                    let v =
941                                        v.parse().ok().context("expected false, true, or force")?;
942                                    if v {
943                                        self.unstick_channels(false);
944                                    }
945                                    v.into()
946                                }
947                            } else {
948                                false.into()
949                            };
950                            anyhow::Ok(v)
951                        })
952                        .merge(&self.server.with_notifier(&mut self.inner));
953                });
954            }
955            VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
956                server: self.server.save(),
957                lost_synic_bug_fixed: true,
958            }),
959            VmbusRequest::Restore(rpc) => rpc.handle_sync(|state| {
960                self.unstick_on_start = !state.lost_synic_bug_fixed;
961                self.server.restore(state.server)
962            }),
963            VmbusRequest::PostRestore(rpc) => {
964                rpc.handle_sync(|()| self.server.with_notifier(&mut self.inner).post_restore())
965            }
966            VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
967                if self.inner.running {
968                    self.inner.running = false;
969                }
970            }),
971            VmbusRequest::Start => {
972                if !self.inner.running {
973                    self.inner.running = true;
974                    if self.unstick_on_start {
975                        tracing::info!(
976                            "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
977                        );
978                        self.unstick_channels(false);
979                        self.unstick_on_start = false;
980                    }
981                }
982            }
983        }
984    }
985
986    fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
987        let needs_reset = self.inner.reset_done.is_empty();
988        self.inner.reset_done.push(rpc);
989        if needs_reset {
990            self.server.with_notifier(&mut self.inner).reset();
991        }
992    }
993
994    fn handle_relay_response(&mut self, response: ModifyConnectionResponse) {
995        self.server
996            .with_notifier(&mut self.inner)
997            .complete_modify_connection(response);
998    }
999
1000    fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
1001        assert_ne!(self.inner.hvsock_requests, 0);
1002        self.inner.hvsock_requests -= 1;
1003
1004        self.server
1005            .with_notifier(&mut self.inner)
1006            .send_tl_connect_result(result);
1007    }
1008
1009    fn handle_synic_message(&mut self, message: SynicMessage) {
1010        match self
1011            .server
1012            .with_notifier(&mut self.inner)
1013            .handle_synic_message(message)
1014        {
1015            Ok(()) => {}
1016            Err(err) => {
1017                tracing::warn!(
1018                    error = &err as &dyn std::error::Error,
1019                    "synic message error"
1020                );
1021            }
1022        }
1023    }
1024
1025    /// Handles a request forwarded by a different vmbus server. This is used to forward requests
1026    /// for different VTLs to different servers.
1027    ///
1028    /// N.B. This uses the same mechanism as the HCL server relay, so all requests, even the ones
1029    ///      meant for the primary server, are forwarded. In that case the primary server depends
1030    ///      on this server to send back a response so it can continue handling it.
1031    fn handle_external_request(&mut self, request: InitiateContactRequest) {
1032        self.server
1033            .with_notifier(&mut self.inner)
1034            .initiate_contact(request);
1035    }
1036
1037    async fn run(
1038        &mut self,
1039        mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyConnectionResponse>
1040        + Unpin,
1041        mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
1042    ) {
1043        loop {
1044            // Create an OptionFuture for each event that should only be handled
1045            // while the VM is running. In other cases, leave the events in
1046            // their respective queues.
1047
1048            let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
1049            let mut external_requests = OptionFuture::from(
1050                running_not_resetting
1051                    .then(|| {
1052                        self.external_requests
1053                            .as_mut()
1054                            .map(|r| r.select_next_some())
1055                    })
1056                    .flatten(),
1057            );
1058
1059            // Try to send any pending messages while the VM is running.
1060            let has_pending_messages = self.server.has_pending_messages();
1061            let message_port = self.inner.message_port.as_mut();
1062            let mut flush_pending_messages =
1063                OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
1064                    poll_fn(|cx| {
1065                        self.server.poll_flush_pending_messages(|msg| {
1066                            message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
1067                        })
1068                    })
1069                    .fuse()
1070                }));
1071
1072            // Only handle new incoming messages if there are no outgoing messages pending, and not
1073            // too many hvsock requests outstanding. This puts a bound on the resources used by the
1074            // guest.
1075            let mut message_recv = OptionFuture::from(
1076                (running_not_resetting
1077                    && !has_pending_messages
1078                    && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
1079                    .then(|| self.message_recv.select_next_some()),
1080            );
1081
1082            // Accept channel responses until stopped or when resetting.
1083            let mut channel_response = OptionFuture::from(
1084                (self.inner.running || !self.inner.reset_done.is_empty())
1085                    .then(|| self.inner.channel_responses.select_next_some()),
1086            );
1087
1088            // Accept hvsock connect responses while the VM is running.
1089            let mut hvsock_response =
1090                OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));
1091
1092            futures::select! { // merge semantics
1093                r = self.task_recv.recv().fuse() => {
1094                    if let Ok(request) = r {
1095                        self.handle_request(request);
1096                    } else {
1097                        break;
1098                    }
1099                }
1100                r = self.offer_recv.select_next_some() => {
1101                    match r {
1102                        OfferRequest::Offer(rpc) => {
1103                            rpc.handle_failable_sync(|request| { self.handle_offer(request) })
1104                        },
1105                        OfferRequest::ForceReset(rpc) => {
1106                            self.handle_reset(rpc);
1107                        }
1108                    }
1109                }
1110                r = self.server_request_recv.select_next_some() => {
1111                    match r {
1112                        (id, Some(request)) => match request {
1113                            ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
1114                                self.handle_restore_channel(id, open)
1115                            }),
1116                            ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
1117                                self.handle_revoke(id);
1118                            })
1119                        },
1120                        (id, None) => self.handle_revoke(id),
1121                    }
1122                }
1123                r = channel_response => {
1124                    let (id, seq, response) = r.unwrap();
1125                    self.handle_response(id, seq, response);
1126                }
1127                r = relay_response_recv.select_next_some() => {
1128                    self.handle_relay_response(r);
1129                },
1130                r = hvsock_response => {
1131                    self.handle_tl_connect_result(r.unwrap());
1132                }
1133                data = message_recv => {
1134                    let data = data.unwrap();
1135                    self.handle_synic_message(data);
1136                }
1137                r = external_requests => {
1138                    let r = r.unwrap();
1139                    self.handle_external_request(r);
1140                }
1141                _r = flush_pending_messages => {}
1142                complete => break,
1143            }
1144        }
1145    }
1146
1147    /// Wakes the host and guest for every open channel. If `force`, always
1148    /// wakes both the host and guest. If `!force`, only wake for rings that are
1149    /// in the state where a notification is expected.
1150    fn unstick_channels(&self, force: bool) {
1151        for channel in self.inner.channels.values() {
1152            if let Err(err) = self.unstick_channel(channel, force) {
1153                tracing::warn!(
1154                    channel = %channel.key,
1155                    error = err.as_ref() as &dyn std::error::Error,
1156                    "could not unstick channel"
1157                );
1158            }
1159        }
1160    }
1161
1162    fn unstick_channel(&self, channel: &Channel, force: bool) -> anyhow::Result<()> {
1163        if let ChannelState::Open {
1164            open_params,
1165            host_to_guest_interrupt,
1166            guest_to_host_event,
1167            ..
1168        } = &channel.state
1169        {
1170            if force {
1171                tracing::info!(channel = %channel.key, "waking host and guest");
1172                guest_to_host_event.0.deliver();
1173                host_to_guest_interrupt.deliver();
1174                return Ok(());
1175            }
1176
1177            let gpadl = channel
1178                .gpadls
1179                .clone()
1180                .view()
1181                .map(open_params.open_data.ring_gpadl_id)
1182                .context("couldn't find ring gpadl")?;
1183
1184            let aligned = AlignedGpadlView::new(gpadl)
1185                .ok()
1186                .context("ring not aligned")?;
1187            let (in_gpadl, out_gpadl) = aligned
1188                .split(open_params.open_data.ring_offset)
1189                .ok()
1190                .context("couldn't split ring")?;
1191
1192            if let Err(err) = self.unstick_incoming_ring(
1193                channel,
1194                in_gpadl,
1195                guest_to_host_event,
1196                host_to_guest_interrupt,
1197            ) {
1198                tracing::warn!(
1199                    channel = %channel.key,
1200                    error = err.as_ref() as &dyn std::error::Error,
1201                    "could not unstick incoming ring"
1202                );
1203            }
1204            if let Err(err) = self.unstick_outgoing_ring(
1205                channel,
1206                out_gpadl,
1207                guest_to_host_event,
1208                host_to_guest_interrupt,
1209            ) {
1210                tracing::warn!(
1211                    channel = %channel.key,
1212                    error = err.as_ref() as &dyn std::error::Error,
1213                    "could not unstick outgoing ring"
1214                );
1215            }
1216        }
1217        Ok(())
1218    }
1219
1220    fn unstick_incoming_ring(
1221        &self,
1222        channel: &Channel,
1223        in_gpadl: AlignedGpadlView,
1224        guest_to_host_event: &ChannelEvent,
1225        host_to_guest_interrupt: &Interrupt,
1226    ) -> Result<(), anyhow::Error> {
1227        let incoming_mem = GpadlRingMem::new(in_gpadl, &self.inner.gm)?;
1228        if ring::reader_needs_signal(&incoming_mem) {
1229            tracing::info!(channel = %channel.key, "waking host for incoming ring");
1230            guest_to_host_event.0.deliver();
1231        }
1232        if ring::writer_needs_signal(&incoming_mem) {
1233            tracing::info!(channel = %channel.key, "waking guest for incoming ring");
1234            host_to_guest_interrupt.deliver();
1235        }
1236        Ok(())
1237    }
1238
1239    fn unstick_outgoing_ring(
1240        &self,
1241        channel: &Channel,
1242        out_gpadl: AlignedGpadlView,
1243        guest_to_host_event: &ChannelEvent,
1244        host_to_guest_interrupt: &Interrupt,
1245    ) -> Result<(), anyhow::Error> {
1246        let outgoing_mem = GpadlRingMem::new(out_gpadl, &self.inner.gm)?;
1247        if ring::reader_needs_signal(&outgoing_mem) {
1248            tracing::info!(channel = %channel.key, "waking guest for outgoing ring");
1249            host_to_guest_interrupt.deliver();
1250        }
1251        if ring::writer_needs_signal(&outgoing_mem) {
1252            tracing::info!(channel = %channel.key, "waking host for outgoing ring");
1253            guest_to_host_event.0.deliver();
1254        }
1255        Ok(())
1256    }
1257}
1258
1259impl Notifier for ServerTaskInner {
1260    fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
1261        let channel = self
1262            .channels
1263            .get_mut(&offer_id)
1264            .expect("channel does not exist");
1265
1266        fn handle<I: 'static + Send, R: 'static + Send>(
1267            offer_id: OfferId,
1268            channel: &Channel,
1269            req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
1270            input: I,
1271            f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
1272        ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
1273        {
1274            let recv = channel.send.call(req, input);
1275            let seq = channel.seq;
1276            Box::pin(async move {
1277                let r = recv.await.map(f);
1278                (offer_id, seq, r)
1279            })
1280        }
1281
1282        let response = match action {
1283            channels::Action::Open(open_params, version) => {
1284                let seq = channel.seq;
1285                match self.open_channel(offer_id, &open_params) {
1286                    Ok((channel, interrupt)) => handle(
1287                        offer_id,
1288                        channel,
1289                        ChannelRequest::Open,
1290                        OpenRequest::new(
1291                            open_params.open_data,
1292                            interrupt,
1293                            version.feature_flags,
1294                            channel.flags,
1295                        ),
1296                        ChannelResponse::Open,
1297                    ),
1298                    Err(err) => {
1299                        tracelimit::error_ratelimited!(
1300                            err = err.as_ref() as &dyn std::error::Error,
1301                            ?offer_id,
1302                            "could not open channel",
1303                        );
1304
1305                        // Return an error response to the channels module if the open_channel call
1306                        // failed.
1307                        Box::pin(future::ready((
1308                            offer_id,
1309                            seq,
1310                            Ok(ChannelResponse::Open(None)),
1311                        )))
1312                    }
1313                }
1314            }
1315            channels::Action::Close => {
1316                if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1317                    if let ChannelState::Open { open_params, .. } = channel.state {
1318                        channel_bitmap.unregister_channel(open_params.event_flag);
1319                    }
1320                }
1321
1322                channel.state = ChannelState::Closing;
1323                handle(offer_id, channel, ChannelRequest::Close, (), |()| {
1324                    ChannelResponse::Close
1325                })
1326            }
1327            channels::Action::Gpadl(gpadl_id, count, buf) => {
1328                channel.gpadls.add(
1329                    gpadl_id,
1330                    MultiPagedRangeBuf::new(count.into(), buf.clone()).unwrap(),
1331                );
1332                handle(
1333                    offer_id,
1334                    channel,
1335                    ChannelRequest::Gpadl,
1336                    GpadlRequest {
1337                        id: gpadl_id,
1338                        count,
1339                        buf,
1340                    },
1341                    move |r| ChannelResponse::Gpadl(gpadl_id, r),
1342                )
1343            }
1344            channels::Action::TeardownGpadl {
1345                gpadl_id,
1346                post_restore,
1347            } => {
1348                if !post_restore {
1349                    channel.gpadls.remove(gpadl_id, Box::new(|| ()));
1350                }
1351
1352                handle(
1353                    offer_id,
1354                    channel,
1355                    ChannelRequest::TeardownGpadl,
1356                    gpadl_id,
1357                    move |()| ChannelResponse::TeardownGpadl(gpadl_id),
1358                )
1359            }
1360            channels::Action::Modify { target_vp } => {
1361                if let ChannelState::Open {
1362                    guest_event_port, ..
1363                } = &mut channel.state
1364                {
1365                    if let Err(err) = guest_event_port.set_target_vp(target_vp) {
1366                        tracelimit::error_ratelimited!(
1367                            error = &err as &dyn std::error::Error,
1368                            channel = %channel.key,
1369                            "could not modify channel",
1370                        );
1371                        let seq = channel.seq;
1372                        Box::pin(async move {
1373                            (
1374                                offer_id,
1375                                seq,
1376                                Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
1377                            )
1378                        })
1379                    } else {
1380                        handle(
1381                            offer_id,
1382                            channel,
1383                            ChannelRequest::Modify,
1384                            ModifyRequest::TargetVp { target_vp },
1385                            ChannelResponse::Modify,
1386                        )
1387                    }
1388                } else {
1389                    unreachable!();
1390                }
1391            }
1392        };
1393        self.channel_responses.push(response);
1394    }
1395
1396    fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
1397        self.map_interrupt_page(request.interrupt_page)
1398            .context("Failed to map interrupt page.")?;
1399
1400        self.set_monitor_page(request.monitor_page, request.force)
1401            .context("Failed to map monitor page.")?;
1402
1403        if let Some(vp) = request.target_message_vp {
1404            self.message_port.set_target_vp(vp)?;
1405        }
1406
1407        if request.notify_relay {
1408            // If this server is handling MNF, the monitor pages should not be relayed.
1409            // N.B. Since the relay is being asked not to update the monitor pages, rather than
1410            //      reset them, this is only safe because the value of enable_mnf won't change after
1411            //      the server has been created.
1412            if self.enable_mnf {
1413                request.monitor_page = Update::Unchanged;
1414            }
1415
1416            self.relay_send.send(request.into());
1417        }
1418
1419        Ok(())
1420    }
1421
1422    fn forward_unhandled(&mut self, request: InitiateContactRequest) {
1423        if let Some(external_server) = &self.external_server_send {
1424            external_server.send(request);
1425        } else {
1426            tracing::warn!(?request, "nowhere to forward unhandled request")
1427        }
1428    }
1429
1430    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
1431        let channel = self.channels.get(&offer_id).expect("should exist");
1432        let mut resp = req.respond();
1433        if let ChannelState::Open { open_params, .. } = &channel.state {
1434            let mem = if self.private_gm.is_some()
1435                && channel.flags.confidential_ring_buffer()
1436                && version
1437                    .expect("must be connected")
1438                    .feature_flags
1439                    .confidential_channels()
1440            {
1441                self.private_gm.as_ref().unwrap()
1442            } else {
1443                &self.gm
1444            };
1445
1446            inspect_rings(
1447                &mut resp,
1448                mem,
1449                channel.gpadls.clone(),
1450                &open_params.open_data,
1451            );
1452        }
1453    }
1454
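    // Attempts to post a message to the guest. Returns false if the message
    // should be queued and retried later; returns true if it was posted or had
    // to be dropped because no message port could be created.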
1455    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
1456        // If the server is paused, queue all messages to avoid affecting synic
1457        // state during or after a save or reset.
1458        //
1459        // Note that messages to reserved channels or custom targets will be
1460        // dropped. However, such messages should only be sent in response to
1461        // guest requests, which should not be processed while the server is
1462        // paused.
1463        //
1464        // FUTURE: it would be better to ensure that no messages are generated
1465        // by operations that run while the server is paused. E.g., defer
1466        // sending offer or revoke messages for new or revoked offers. This
1467        // would prevent the queue from growing without bound.
1468        if !self.running && !self.send_messages_while_stopped {
1469            if !matches!(target, MessageTarget::Default) {
1470                tracelimit::error_ratelimited!(?target, "dropping message while paused");
1471            }
1472            return false;
1473        }
1474
1475        let mut port_storage;
1476        let port = match target {
1477            MessageTarget::Default => self.message_port.as_mut(),
1478            MessageTarget::ReservedChannel(offer_id, target) => {
1479                if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
1480                    port.as_mut()
1481                } else {
1482                    // Updating the port failed, so there is no way to send the message.
1483                    return true;
1484                }
1485            }
1486            MessageTarget::Custom(target) => {
1487                port_storage = match self.synic.new_guest_message_port(
1488                    self.redirect_vtl,
1489                    target.vp,
1490                    target.sint,
1491                ) {
1492                    Ok(port) => port,
1493                    Err(err) => {
1494                        tracing::error!(
1495                            ?err,
1496                            ?self.redirect_vtl,
1497                            ?target,
1498                            "could not create message port"
1499                        );
1500
1501                        // There is no way to send the message.
1502                        return true;
1503                    }
1504                };
1505                port_storage.as_mut()
1506            }
1507        };
1508
1509        // If this returns Pending, the channels module will queue the message and the ServerTask
1510        // main loop will try to send it again later.
1511        matches!(
1512            port.poll_post_message(
1513                &mut std::task::Context::from_waker(std::task::Waker::noop()),
1514                VMBUS_MESSAGE_TYPE,
1515                message.data()
1516            ),
1517            Poll::Ready(())
1518        )
1519    }
1520
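    // Records an outstanding hvsocket connect request and forwards it on the
    // hvsock request channel.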
1521    fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
1522        self.hvsock_requests += 1;
1523        self.hvsock_send.send(*request);
1524    }
1525
1526    fn reset_complete(&mut self) {
1527        if let Some(monitor) = self.synic.monitor_support() {
1528            if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
1529                tracing::warn!(?err, "resetting monitor page failed")
1530            }
1531        }
1532
1533        self.unreserve_channels();
1534        for done in self.reset_done.drain(..) {
1535            done.complete(());
1536        }
1537    }
1538
1539    fn unload_complete(&mut self) {
1540        self.unreserve_channels();
1541    }
1542}
1543
1544impl ServerTaskInner {
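    // Transitions a channel to the Opening state: creates its guest event port
    // (and, for a reserved channel, a dedicated guest message port) and returns
    // the channel together with the host-to-guest interrupt.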
1545    fn open_channel(
1546        &mut self,
1547        offer_id: OfferId,
1548        open_params: &OpenParams,
1549    ) -> anyhow::Result<(&mut Channel, Interrupt)> {
1550        let channel = self
1551            .channels
1552            .get_mut(&offer_id)
1553            .expect("channel does not exist");
1554
1555        // For pre-Win8 guests, the host-to-guest event always targets vp 0 and the channel
1556        // bitmap is used instead of the event flag.
1557        let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
1558            (0, 0)
1559        } else {
1560            (open_params.open_data.target_vp, open_params.event_flag)
1561        };
1562        let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
1563            (self.redirect_vtl, self.redirect_sint)
1564        } else {
1565            (self.vtl, SINT)
1566        };
1567
1568        let guest_event_port = self.synic.new_guest_event_port(
1569            VmbusServer::get_child_event_port_id(open_params.channel_id, SINT, self.vtl),
1570            target_vtl,
1571            target_vp,
1572            target_sint,
1573            event_flag,
1574            open_params.monitor_info,
1575        )?;
1576
1577        let interrupt = ChannelBitmap::create_interrupt(
1578            &self.channel_bitmap,
1579            guest_event_port.interrupt(),
1580            open_params.event_flag,
1581        );
1582
1583        // Drop any message port left over from a previous reservation.
1584        channel.reserved_state.message_port = None;
1585
1586        // If the channel is reserved, create a message port for it.
1587        if let Some(target) = open_params.reserved_target {
1588            channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
1589                self.redirect_vtl,
1590                target.vp,
1591                target.sint,
1592            )?);
1593
1594            channel.reserved_state.target = target;
1595        }
1596
1597        channel.state = ChannelState::Opening {
1598            open_params: *open_params,
1599            guest_event_port,
1600            host_to_guest_interrupt: interrupt.clone(),
1601        };
1602        Ok((channel, interrupt))
1603    }
1604
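    // Completes an open request. On success, registers the guest-to-host event
    // with the channel bitmap (when present), adds a synic event port, and moves
    // the channel to the Open state; a None result returns it to Closed.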
1605    fn complete_open(
1606        &mut self,
1607        offer_id: OfferId,
1608        result: Option<OpenResult>,
1609    ) -> anyhow::Result<&mut Channel> {
1610        let channel = self
1611            .channels
1612            .get_mut(&offer_id)
1613            .expect("channel does not exist");
1614
1615        channel.state = if let Some(result) = result {
1616            // The channel will be left in the FailedOpen state only if an error occurs in the match
1617            // arm.
1618            match std::mem::replace(&mut channel.state, ChannelState::FailedOpen) {
1619                ChannelState::Opening {
1620                    open_params,
1621                    guest_event_port,
1622                    host_to_guest_interrupt,
1623                } => {
1624                    let guest_to_host_event =
1625                        Arc::new(ChannelEvent(result.guest_to_host_interrupt));
1626                    // Always register with the channel bitmap; if Win7, this may be unnecessary.
1627                    if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1628                        channel_bitmap.register_channel(
1629                            open_params.event_flag,
1630                            guest_to_host_event.0.clone(),
1631                        );
1632                    }
1633                    // Always set up an event port; if V1, this will be unused.
1634                    let event_port = self
1635                        .synic
1636                        .add_event_port(
1637                            open_params.connection_id,
1638                            self.vtl,
1639                            guest_to_host_event.clone(),
1640                            open_params.monitor_info,
1641                        )
1642                        .with_context(|| {
1643                            format!(
1644                                "failed to create event port for VTL {:?}, connection ID {:#x}",
1645                                self.vtl, open_params.connection_id
1646                            )
1647                        })?;
1648
1649                    ChannelState::Open {
1650                        open_params,
1651                        _event_port: event_port,
1652                        guest_event_port,
1653                        host_to_guest_interrupt,
1654                        guest_to_host_event,
1655                    }
1656                }
1657                s => {
1658                    tracing::error!("attempting to complete open of a channel that is not in the Opening state");
1659                    // Restore the original state
1660                    s
1661                }
1662            }
1663        } else {
1664            ChannelState::Closed
1665        };
1666        Ok(channel)
1667    }
1668
1669    /// If the client specified an interrupt page, map it into host memory and
1670    /// set up the shared event port.
1671    fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
1672        let interrupt_page = match interrupt_page {
1673            Update::Unchanged => return Ok(()),
1674            Update::Reset => {
1675                self.channel_bitmap = None;
1676                self.shared_event_port = None;
1677                return Ok(());
1678            }
1679            Update::Set(interrupt_page) => interrupt_page,
1680        };
1681
1682        assert_ne!(interrupt_page, 0);
1683
1684        if interrupt_page % PAGE_SIZE as u64 != 0 {
1685            anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
1686        }
1687
1688        // Access the interrupt page through a subrange so that GuestMemory backings
1689        // without a full mapping have a chance to create one.
1690        let interrupt_page = self
1691            .gm
1692            .lockable_subrange(interrupt_page, PAGE_SIZE as u64)?
1693            .lock_gpns(false, &[0])?;
1694
1695        let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
1696        self.channel_bitmap = Some(channel_bitmap.clone());
1697
1698        // Create the shared event port for pre-Win8 guests.
1699        let interrupt = Interrupt::from_fn(move || {
1700            channel_bitmap.handle_shared_interrupt();
1701        });
1702
1703        self.shared_event_port = Some(self.synic.add_event_port(
1704            SHARED_EVENT_CONNECTION_ID,
1705            self.vtl,
1706            Arc::new(ChannelEvent(interrupt)),
1707            None,
1708        )?);
1709
1710        Ok(())
1711    }
1712
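    // Applies a monitor page update. Unless `force` is set (used during restore),
    // the update is rejected while any channel that uses MNF is open or opening.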
1713    fn set_monitor_page(
1714        &mut self,
1715        monitor_page: Update<MonitorPageGpas>,
1716        force: bool,
1717    ) -> anyhow::Result<()> {
1718        let monitor_page = match monitor_page {
1719            Update::Unchanged => return Ok(()),
1720            Update::Reset => None,
1721            Update::Set(value) => Some(value),
1722        };
1723
1724        // Force is used by restore because there may be restored channels in the open state.
1725        // TODO: can this check be moved into channels.rs?
1726        if !force
1727            && self.channels.iter().any(|(_, c)| {
1728                matches!(
1729                    &c.state,
1730                    ChannelState::Open {
1731                        open_params,
1732                        ..
1733                    } | ChannelState::Opening {
1734                        open_params,
1735                        ..
1736                    } if open_params.monitor_info.is_some()
1737                )
1738            })
1739        {
1740            anyhow::bail!("attempt to change monitor page while channels using MNF are open");
1741        }
1742
1743        if self.enable_mnf {
1744            if let Some(monitor) = self.synic.monitor_support() {
1745                if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
1746                    anyhow::bail!(
1747                        "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
1748                    );
1749                }
1750            }
1751        }
1752
1753        Ok(())
1754    }
1755
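    // Returns the message port for a reserved channel, first recreating or
    // retargeting it if the guest requested a different SINT or VP for the
    // response.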
1756    fn get_reserved_channel_message_port(
1757        &mut self,
1758        offer_id: OfferId,
1759        new_target: ConnectionTarget,
1760    ) -> Option<&mut Box<dyn GuestMessagePort>> {
1761        let channel = self
1762            .channels
1763            .get_mut(&offer_id)
1764            .expect("channel does not exist");
1765
1766        assert!(
1767            channel.reserved_state.message_port.is_some(),
1768            "channel is not reserved"
1769        );
1770
1771        // On close, the guest may have changed the message target it wants to use for the close
1772        // response. If so, update the message port.
1773        if channel.reserved_state.target.sint != new_target.sint {
1774            // Destroy the old port before creating the new one.
1775            channel.reserved_state.message_port = None;
1776            let message_port = self
1777                .synic
1778                .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
1779                .inspect_err(|err| {
1780                    tracing::error!(
1781                        ?err,
1782                        ?self.redirect_vtl,
1783                        ?new_target,
1784                        "could not create reserved channel message port"
1785                    )
1786                })
1787                .ok()?;
1788
1789            channel.reserved_state.message_port = Some(message_port);
1790            channel.reserved_state.target = new_target;
1791        } else if channel.reserved_state.target.vp != new_target.vp {
1792            let message_port = channel.reserved_state.message_port.as_mut().unwrap();
1793
1794            // The vp has changed, but the SINT is the same. Just update the vp. If this fails,
1795            // ignore it and just send to the old vp.
1796            if let Err(err) = message_port.set_target_vp(new_target.vp) {
1797                tracing::error!(
1798                    ?err,
1799                    ?self.redirect_vtl,
1800                    ?new_target,
1801                    "could not update reserved channel message port"
1802                );
1803            }
1804
1805            channel.reserved_state.target = new_target;
1806            return Some(message_port);
1807        }
1808
1809        Some(channel.reserved_state.message_port.as_mut().unwrap())
1810    }
1811
1812    fn unreserve_channels(&mut self) {
1813        // Unreserve all closed channels.
1814        for channel in self.channels.values_mut() {
1815            if let ChannelState::Closed = channel.state {
1816                channel.reserved_state.message_port = None;
1817            }
1818        }
1819    }
1820}
1821
1822/// Control point for [`VmbusServer`], allowing callers to offer channels.
1823#[derive(Clone)]
1824pub struct VmbusServerControl {
1825    mem: GuestMemory,
1826    private_mem: Option<GuestMemory>,
1827    send: mesh::Sender<OfferRequest>,
1828    use_event: bool,
1829    force_confidential_external_memory: bool,
1830}
1831
1832impl VmbusServerControl {
1833    /// Offers a channel to the vmbus server with the flags and user-defined data already set.
1834    /// This is used by the relay to forward the host's offer parameters.
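    /// A minimal usage sketch (not compiled; `offer_info` is assumed to have been
    /// built by the relay from the host's offer):
    ///
    /// ```ignore
    /// let resources = control.offer_core(offer_info).await?;
    /// ```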
1835    pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
1836        let flags = offer_info.params.flags;
1837        self.send
1838            .call_failable(OfferRequest::Offer, offer_info)
1839            .await?;
1840        Ok(OfferResources::new(
1841            self.mem.clone(),
1842            if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
1843                self.private_mem.clone()
1844            } else {
1845                None
1846            },
1847        ))
1848    }
1849
1850    /// Forcibly resets all channels and protocol state, without requiring the
1851    /// server to be paused.
1852    pub async fn force_reset(&self) -> anyhow::Result<()> {
1853        self.send
1854            .call(OfferRequest::ForceReset, ())
1855            .await
1856            .context("vmbus server is gone")
1857    }
1858
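    // Converts an OfferInput from a bus child into an OfferInfo, optionally
    // forcing confidential external memory, and submits it via offer_core.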
1859    async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
1860        let mut offer_info = OfferInfo {
1861            params: request.params.into(),
1862            request_send: request.request_send,
1863            server_request_recv: request.server_request_recv,
1864        };
1865
1866        if self.force_confidential_external_memory {
1867            tracing::warn!(
1868                key = %offer_info.params.key(),
1869                "forcing confidential external memory for channel"
1870            );
1871
1872            offer_info
1873                .params
1874                .flags
1875                .set_confidential_external_memory(true);
1876        }
1877
1878        self.offer_core(offer_info).await
1879    }
1880}
1881
1882/// Inspects the ring buffers described by the open data by directly accessing guest memory.
1883fn inspect_rings(
1884    resp: &mut inspect::Response<'_>,
1885    gm: &GuestMemory,
1886    gpadl_map: Arc<GpadlMap>,
1887    open_data: &OpenData,
1888) -> Option<()> {
1889    let gpadl = gpadl_map
1890        .view()
1891        .map(GpadlId(open_data.ring_gpadl_id.0))
1892        .ok()?;
1893    let aligned = AlignedGpadlView::new(gpadl).ok()?;
1894    let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
1895    if let Ok(incoming_mem) = GpadlRingMem::new(in_gpadl, gm) {
1896        resp.child("incoming_ring", |req| ring::inspect_ring(incoming_mem, req));
1897    }
1898    if let Ok(outgoing_mem) = GpadlRingMem::new(out_gpadl, gm) {
1899        resp.child("outgoing_ring", |req| ring::inspect_ring(outgoing_mem, req));
1900    }
1901    Some(())
1902}
1903
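/// Synic message port handler that forwards messages posted by the guest to the
/// server task over an mpsc channel.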
1904pub(crate) struct MessageSender {
1905    send: mpsc::Sender<SynicMessage>,
1906    multiclient: bool,
1907}
1908
1909impl MessageSender {
1910    fn poll_handle_message(
1911        &self,
1912        cx: &mut std::task::Context<'_>,
1913        msg: &[u8],
1914        trusted: bool,
1915    ) -> Poll<Result<(), SendError>> {
1916        let mut send = self.send.clone();
1917        ready!(send.poll_ready(cx))?;
1918        send.start_send(SynicMessage {
1919            data: msg.to_vec(),
1920            multiclient: self.multiclient,
1921            trusted,
1922        })?;
1923
1924        Poll::Ready(Ok(()))
1925    }
1926}
1927
1928impl MessagePort for MessageSender {
1929    fn poll_handle_message(
1930        &self,
1931        cx: &mut std::task::Context<'_>,
1932        msg: &[u8],
1933        trusted: bool,
1934    ) -> Poll<()> {
1935        if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
1936            tracelimit::error_ratelimited!(
1937                error = &err as &dyn std::error::Error,
1938                "failed to send message"
1939            );
1940        }
1941
1942        Poll::Ready(())
1943    }
1944}
1945
1946#[async_trait]
1947impl ParentBus for VmbusServerControl {
1948    async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
1949        self.offer(request).await
1950    }
1951
1952    fn clone_bus(&self) -> Box<dyn ParentBus> {
1953        Box::new(self.clone())
1954    }
1955
1956    fn use_event(&self) -> bool {
1957        self.use_event
1958    }
1959}
1960
1961#[cfg(test)]
1962mod tests {
1963    use super::*;
1964    use pal_async::DefaultDriver;
1965    use pal_async::async_test;
1966    use pal_async::driver::SpawnDriver;
1967    use pal_async::timer::Instant;
1968    use pal_async::timer::PolledTimer;
1969    use parking_lot::Mutex;
1970    use protocol::UserDefinedData;
1971    use std::time::Duration;
1972    use test_with_tracing::test;
1973    use vmbus_channel::bus::OfferParams;
1974    use vmbus_core::protocol::ChannelId;
1975    use vmbus_core::protocol::VmbusMessage;
1976    use vmcore::synic::MonitorInfo;
1977    use vmcore::synic::SynicPortAccess;
1978    use zerocopy::FromBytes;
1979    use zerocopy::Immutable;
1980    use zerocopy::IntoBytes;
1981    use zerocopy::KnownLayout;
1982
1983    struct MockSynicInner {
1984        message_port: Option<Arc<dyn MessagePort>>,
1985    }
1986
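    // Minimal synic mock: records the registered message port so tests can inject
    // guest-to-host messages, and forwards host-to-guest messages over a mesh
    // channel for inspection.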
1987    struct MockSynic {
1988        inner: Mutex<MockSynicInner>,
1989        message_send: mesh::Sender<Vec<u8>>,
1990        spawner: Arc<dyn SpawnDriver>,
1991    }
1992
1993    impl MockSynic {
1994        fn new(message_send: mesh::Sender<Vec<u8>>, spawner: Arc<dyn SpawnDriver>) -> Self {
1995            Self {
1996                inner: Mutex::new(MockSynicInner { message_port: None }),
1997                message_send,
1998                spawner,
1999            }
2000        }
2001
2002        fn send_message(&self, msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout) {
2003            self.send_message_core(OutgoingMessage::new(&msg), false);
2004        }
2005
2006        fn send_message_trusted(
2007            &self,
2008            msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout,
2009        ) {
2010            self.send_message_core(OutgoingMessage::new(&msg), true);
2011        }
2012
2013        fn send_message_core(&self, msg: OutgoingMessage, trusted: bool) {
2014            assert_eq!(
2015                self.inner
2016                    .lock()
2017                    .message_port
2018                    .as_ref()
2019                    .unwrap()
2020                    .poll_handle_message(
2021                        &mut std::task::Context::from_waker(std::task::Waker::noop()),
2022                        msg.data(),
2023                        trusted,
2024                    ),
2025                Poll::Ready(())
2026            );
2027        }
2028    }
2029
2030    #[derive(Debug)]
2031    struct MockGuestPort {}
2032
2033    impl GuestEventPort for MockGuestPort {
2034        fn interrupt(&self) -> Interrupt {
2035            Interrupt::null()
2036        }
2037
2038        fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
2039            Ok(())
2040        }
2041    }
2042
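    // Guest message port mock that randomly returns Pending (with a short timer)
    // to exercise the server's message retry path.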
2043    struct MockGuestMessagePort {
2044        send: mesh::Sender<Vec<u8>>,
2045        spawner: Arc<dyn SpawnDriver>,
2046        timer: Option<(PolledTimer, Instant)>,
2047    }
2048
2049    impl GuestMessagePort for MockGuestMessagePort {
2050        fn poll_post_message(
2051            &mut self,
2052            cx: &mut std::task::Context<'_>,
2053            _typ: u32,
2054            payload: &[u8],
2055        ) -> Poll<()> {
2056            if let Some((timer, deadline)) = self.timer.as_mut() {
2057                ready!(timer.sleep_until(*deadline).poll_unpin(cx));
2058                self.timer = None;
2059            }
2060
2061            // Return pending 25% of the time.
2062            let mut pending_chance = [0; 1];
2063            getrandom::fill(&mut pending_chance).unwrap();
2064            if pending_chance[0] % 4 == 0 {
2065                let mut timer = PolledTimer::new(self.spawner.as_ref());
2066                let deadline = Instant::now() + Duration::from_millis(10);
2067                match timer.sleep_until(deadline).poll_unpin(cx) {
2068                    Poll::Ready(_) => {}
2069                    Poll::Pending => {
2070                        self.timer = Some((timer, deadline));
2071                        return Poll::Pending;
2072                    }
2073                }
2074            }
2075
2076            self.send.send(payload.into());
2077            Poll::Ready(())
2078        }
2079
2080        fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
2081            Ok(())
2082        }
2083    }
2084
2085    impl Inspect for MockGuestMessagePort {
2086        fn inspect(&self, _req: inspect::Request<'_>) {}
2087    }
2088
2089    impl SynicPortAccess for MockSynic {
2090        fn add_message_port(
2091            &self,
2092            connection_id: u32,
2093            _minimum_vtl: Vtl,
2094            port: Arc<dyn MessagePort>,
2095        ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
2096            self.inner.lock().message_port = Some(port);
2097            Ok(Box::new(connection_id))
2098        }
2099
2100        fn add_event_port(
2101            &self,
2102            connection_id: u32,
2103            _minimum_vtl: Vtl,
2104            _port: Arc<dyn EventPort>,
2105            _monitor_info: Option<MonitorInfo>,
2106        ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
2107            Ok(Box::new(connection_id))
2108        }
2109
2110        fn new_guest_message_port(
2111            &self,
2112            _vtl: Vtl,
2113            _vp: u32,
2114            _sint: u8,
2115        ) -> Result<Box<(dyn GuestMessagePort)>, vmcore::synic::HypervisorError> {
2116            Ok(Box::new(MockGuestMessagePort {
2117                send: self.message_send.clone(),
2118                spawner: Arc::clone(&self.spawner),
2119                timer: None,
2120            }))
2121        }
2122
2123        fn new_guest_event_port(
2124            &self,
2125            _port_id: u32,
2126            _vtl: Vtl,
2127            _vp: u32,
2128            _sint: u8,
2129            _flag: u16,
2130            _monitor_info: Option<MonitorInfo>,
2131        ) -> Result<Box<(dyn GuestEventPort)>, vmcore::synic::HypervisorError> {
2132            Ok(Box::new(MockGuestPort {}))
2133        }
2134
2135        fn prefer_os_events(&self) -> bool {
2136            false
2137        }
2138    }
2139
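    // Device-side endpoint of an offered test channel, used to receive and
    // complete ChannelRequests from the server.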
2140    struct TestChannel {
2141        request_recv: mesh::Receiver<ChannelRequest>,
2142        server_request_send: mesh::Sender<ChannelServerRequest>,
2143        _resources: OfferResources,
2144    }
2145
2146    impl TestChannel {
2147        async fn next_request(&mut self) -> ChannelRequest {
2148            self.request_recv.next().await.unwrap()
2149        }
2150
2151        async fn handle_gpadl(&mut self) {
2152            let ChannelRequest::Gpadl(rpc) = self.next_request().await else {
2153                panic!("Wrong request");
2154            };
2155
2156            rpc.complete(true);
2157        }
2158
2159        async fn handle_open(&mut self, f: fn(&OpenRequest)) {
2160            let ChannelRequest::Open(rpc) = self.next_request().await else {
2161                panic!("Wrong request");
2162            };
2163
2164            f(rpc.input());
2165            rpc.complete(Some(OpenResult {
2166                guest_to_host_interrupt: Interrupt::null(),
2167            }));
2168        }
2169
2170        async fn handle_gpadl_teardown(&mut self) {
2171            let rpc = self.get_gpadl_teardown().await;
2172            rpc.complete(());
2173        }
2174
2175        async fn get_gpadl_teardown(&mut self) -> Rpc<GpadlId, ()> {
2176            let ChannelRequest::TeardownGpadl(rpc) = self.next_request().await else {
2177                panic!("Wrong request");
2178            };
2179
2180            rpc
2181        }
2182
2183        async fn restore(&self) {
2184            self.server_request_send
2185                .call(ChannelServerRequest::Restore, None)
2186                .await
2187                .unwrap()
2188                .unwrap();
2189        }
2190    }
2191
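    // Test fixture bundling a VmbusServer wired to the mock synic and the
    // receiver for messages the server sends toward the guest.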
2192    struct TestEnv {
2193        vmbus: VmbusServer,
2194        synic: Arc<MockSynic>,
2195        message_recv: mesh::Receiver<Vec<u8>>,
2196        trusted: bool,
2197    }
2198
2199    impl TestEnv {
2200        fn new(spawner: DefaultDriver) -> Self {
2201            let spawner: Arc<dyn SpawnDriver> = Arc::new(spawner);
2202            let (message_send, message_recv) = mesh::channel();
2203            let synic = Arc::new(MockSynic::new(message_send, Arc::clone(&spawner)));
2204            let gm = GuestMemory::empty();
2205            let vmbus = VmbusServerBuilder::new(&spawner, synic.clone(), gm)
2206                .build()
2207                .unwrap();
2208
2209            Self {
2210                vmbus,
2211                synic,
2212                message_recv,
2213                trusted: false,
2214            }
2215        }
2216
2217        async fn offer(&self, id: u32, allow_confidential_external_memory: bool) -> TestChannel {
2218            let guid = Guid {
2219                data1: id,
2220                ..Guid::ZERO
2221            };
2222            let (request_send, request_recv) = mesh::channel();
2223            let (server_request_send, server_request_recv) = mesh::channel();
2224            let offer = OfferInput {
2225                request_send,
2226                server_request_recv,
2227                params: OfferParams {
2228                    interface_name: "test".into(),
2229                    instance_id: guid,
2230                    interface_id: guid,
2231                    mmio_megabytes: 0,
2232                    mmio_megabytes_optional: 0,
2233                    channel_type: vmbus_channel::bus::ChannelType::Device {
2234                        pipe_packets: false,
2235                    },
2236                    subchannel_index: 0,
2237                    mnf_interrupt_latency: None,
2238                    offer_order: None,
2239                    allow_confidential_external_memory,
2240                },
2241            };
2242
2243            let control = self.vmbus.control();
2244            let _resources = control.add_child(offer).await.unwrap();
2245
2246            TestChannel {
2247                request_recv,
2248                server_request_send,
2249                _resources,
2250            }
2251        }
2252
2253        async fn gpadl(&mut self, channel_id: u32, gpadl_id: u32, channel: &mut TestChannel) {
2254            self.synic.send_message_core(
2255                OutgoingMessage::with_data(
2256                    &protocol::GpadlHeader {
2257                        channel_id: ChannelId(channel_id),
2258                        gpadl_id: GpadlId(gpadl_id),
2259                        count: 1,
2260                        len: 16,
2261                    },
2262                    [1u64, 0u64].as_bytes(),
2263                ),
2264                self.trusted,
2265            );
2266
2267            channel.handle_gpadl().await;
2268            self.expect_response(protocol::MessageType::GPADL_CREATED)
2269                .await;
2270        }
2271
2272        async fn open_channel(
2273            &mut self,
2274            channel_id: u32,
2275            ring_gpadl_id: u32,
2276            channel: &mut TestChannel,
2277            f: fn(&OpenRequest),
2278        ) {
2279            self.gpadl(channel_id, ring_gpadl_id, channel).await;
2280            self.synic.send_message_core(
2281                OutgoingMessage::new(&protocol::OpenChannel {
2282                    channel_id: ChannelId(channel_id),
2283                    open_id: 0,
2284                    ring_buffer_gpadl_id: GpadlId(ring_gpadl_id),
2285                    target_vp: 0,
2286                    downstream_ring_buffer_page_offset: 0,
2287                    user_data: UserDefinedData::default(),
2288                }),
2289                self.trusted,
2290            );
2291
2292            channel.handle_open(f).await;
2293            self.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT)
2294                .await;
2295        }
2296
2297        async fn expect_response(&mut self, expected: protocol::MessageType) {
2298            let data = self.message_recv.next().await.unwrap();
2299            let header = protocol::MessageHeader::read_from_prefix(&data).unwrap().0; // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
2300            assert_eq!(expected, header.message_type())
2301        }
2302
2303        async fn get_response<T: VmbusMessage + FromBytes + Immutable + KnownLayout>(
2304            &mut self,
2305        ) -> T {
2306            let data = self.message_recv.next().await.unwrap();
2307            let (header, message) = protocol::MessageHeader::read_from_prefix(&data).unwrap(); // TODO: zerocopy: unwrap (https://github.com/microsoft/openvmm/issues/759)
2308            assert_eq!(T::MESSAGE_TYPE, header.message_type());
2309            T::read_from_prefix(message).unwrap().0 // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
2310        }
2311
2312        fn initiate_contact(
2313            &mut self,
2314            version: protocol::Version,
2315            feature_flags: protocol::FeatureFlags,
2316            trusted: bool,
2317        ) {
2318            self.synic.send_message_core(
2319                OutgoingMessage::new(&protocol::InitiateContact {
2320                    version_requested: version as u32,
2321                    target_message_vp: 0,
2322                    child_to_parent_monitor_page_gpa: 0,
2323                    parent_to_child_monitor_page_gpa: 0,
2324                    interrupt_page_or_target_info: protocol::TargetInfo::new()
2325                        .with_sint(2)
2326                        .with_vtl(0)
2327                        .with_feature_flags(feature_flags.into())
2328                        .into(),
2329                }),
2330                trusted,
2331            );
2332
2333            self.trusted = trusted;
2334        }
2335
2336        async fn connect(
2337            &mut self,
2338            offer_count: u32,
2339            feature_flags: protocol::FeatureFlags,
2340            trusted: bool,
2341        ) {
2342            self.initiate_contact(protocol::Version::Copper, feature_flags, trusted);
2343
2344            self.expect_response(protocol::MessageType::VERSION_RESPONSE)
2345                .await;
2346
2347            self.synic
2348                .send_message_core(OutgoingMessage::new(&protocol::RequestOffers {}), trusted);
2349
2350            for _ in 0..offer_count {
2351                self.expect_response(protocol::MessageType::OFFER_CHANNEL)
2352                    .await;
2353            }
2354
2355            self.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
2356                .await;
2357        }
2358    }
2359
2360    #[async_test]
2361    async fn test_save_restore(spawner: DefaultDriver) {
2362        // Most save/restore state is tested in mod channels::tests; this test specifically checks
2363        // that ServerTaskInner correctly handles some aspects of the save/restore.
2364        //
2365        // If this test fails, it is more likely to hang than panic.
2366        let mut env = TestEnv::new(spawner);
2367        let mut channel = env.offer(1, false).await;
2368        env.vmbus.start();
2369        env.connect(1, protocol::FeatureFlags::new(), false).await;
2370
2371        // Create a GPADL for the channel.
2372        env.gpadl(1, 10, &mut channel).await;
2373
2374        // Start tearing it down.
2375        env.synic.send_message(protocol::GpadlTeardown {
2376            channel_id: ChannelId(1),
2377            gpadl_id: GpadlId(10),
2378        });
2379
2380        // Wait for the teardown request here to make sure the server has processed the teardown
2381        // message, but do not complete it before saving.
2382        let rpc = channel.get_gpadl_teardown().await;
2383        env.vmbus.stop().await;
2384        let saved_state = env.vmbus.save().await;
2385        env.vmbus.start();
2386
2387        // Finish tearing down the gpadl and release the channel so the server can reset.
2388        rpc.complete(());
2389        env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
2390            .await;
2391
2392        env.synic.send_message(protocol::RelIdReleased {
2393            channel_id: ChannelId(1),
2394        });
2395
2396        env.vmbus.reset().await;
2397        env.vmbus.stop().await;
2398
2399        // When restoring with a gpadl in the TearingDown state, the teardown request for the device
2400        // will be repeated. This must not panic.
2401        env.vmbus.restore(saved_state).await.unwrap();
2402        channel.restore().await;
2403        env.vmbus.post_restore().await.unwrap();
2404        env.vmbus.start();
2405
2406        // Handle the teardown after restore.
2407        channel.handle_gpadl_teardown().await;
2408        env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
2409            .await;
2410
2411        env.synic.send_message(protocol::RelIdReleased {
2412            channel_id: ChannelId(1),
2413        });
2414    }
2415
2416    #[async_test]
2417    async fn test_confidential_connection(spawner: DefaultDriver) {
2418        let mut env = TestEnv::new(spawner);
2419        // Add regular bus child channels, one of which supports confidential external memory.
2420        let mut channel = env.offer(1, false).await;
2421        let mut channel2 = env.offer(2, true).await;
2422
2423        // Add a channel directly, like the relay would do.
2424        let (request_send, request_recv) = mesh::channel();
2425        let (server_request_send, server_request_recv) = mesh::channel();
2426        let id = Guid {
2427            data1: 3,
2428            ..Guid::ZERO
2429        };
2430        let control = env.vmbus.control();
2431        let relay_resources = control
2432            .offer_core(OfferInfo {
2433                params: OfferParamsInternal {
2434                    interface_name: "test".into(),
2435                    instance_id: id,
2436                    interface_id: id,
2437                    mmio_megabytes: 0,
2438                    mmio_megabytes_optional: 0,
2439                    subchannel_index: 0,
2440                    use_mnf: MnfUsage::Disabled,
2441                    offer_order: None,
2442                    flags: protocol::OfferFlags::new().with_enumerate_device_interface(true),
2443                    ..Default::default()
2444                },
2445                request_send,
2446                server_request_recv,
2447            })
2448            .await
2449            .unwrap();
2450
2451        let mut relay_channel = TestChannel {
2452            request_recv,
2453            server_request_send,
2454            _resources: relay_resources,
2455        };
2456
2457        env.vmbus.start();
2458        env.initiate_contact(
2459            protocol::Version::Copper,
2460            protocol::FeatureFlags::new().with_confidential_channels(true),
2461            true,
2462        );
2463
2464        env.expect_response(protocol::MessageType::VERSION_RESPONSE)
2465            .await;
2466
2467        env.synic.send_message_trusted(protocol::RequestOffers {});
2468
2469        // All offers added with add_child have confidential ring support.
2470        let offer = env.get_response::<protocol::OfferChannel>().await;
2471        assert!(offer.flags.confidential_ring_buffer());
2472        assert!(!offer.flags.confidential_external_memory());
2473        let offer = env.get_response::<protocol::OfferChannel>().await;
2474        assert!(offer.flags.confidential_ring_buffer());
2475        assert!(offer.flags.confidential_external_memory());
2476
2477        // The "relay" channel will not have its flags modified.
2478        let offer = env.get_response::<protocol::OfferChannel>().await;
2479        assert!(!offer.flags.confidential_ring_buffer());
2480        assert!(!offer.flags.confidential_external_memory());
2481
2482        env.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
2483            .await;
2484
2485        // Make sure that the correct confidential flags are set in the open request when opening
2486        // the channels.
2487        env.open_channel(1, 1, &mut channel, |request| {
2488            assert!(request.use_confidential_ring);
2489            assert!(!request.use_confidential_external_memory);
2490        })
2491        .await;
2492
2493        env.open_channel(2, 2, &mut channel2, |request| {
2494            assert!(request.use_confidential_ring);
2495            assert!(request.use_confidential_external_memory);
2496        })
2497        .await;
2498
2499        env.open_channel(3, 3, &mut relay_channel, |request| {
2500            assert!(!request.use_confidential_ring);
2501            assert!(!request.use_confidential_external_memory);
2502        })
2503        .await;
2504    }
2505
2506    #[async_test]
2507    async fn test_confidential_channels_unsupported(spawner: DefaultDriver) {
2508        let mut env = TestEnv::new(spawner);
2509        let mut channel = env.offer(1, false).await;
2510        let mut channel2 = env.offer(2, true).await;
2511
2512        env.vmbus.start();
2513        env.connect(2, protocol::FeatureFlags::new(), true).await;
2514
2515        // Make sure the confidential flags are always false when the client doesn't
2516        // support confidential channels.
2517        env.open_channel(1, 1, &mut channel, |request| {
2518            assert!(!request.use_confidential_ring);
2519            assert!(!request.use_confidential_external_memory);
2520        })
2521        .await;
2522
2523        env.open_channel(2, 2, &mut channel2, |request| {
2524            assert!(!request.use_confidential_ring);
2525            assert!(!request.use_confidential_external_memory);
2526        })
2527        .await;
2528    }
2529
2530    #[async_test]
2531    async fn test_confidential_channels_untrusted(spawner: DefaultDriver) {
2532        let mut env = TestEnv::new(spawner);
2533        let mut channel = env.offer(1, false).await;
2534        let mut channel2 = env.offer(2, true).await;
2535
2536        env.vmbus.start();
2537        // Client claims to support confidential channels, but they can't be used because the
2538        // connection is untrusted.
2539        env.connect(
2540            2,
2541            protocol::FeatureFlags::new().with_confidential_channels(true),
2542            false,
2543        )
2544        .await;
2545
2546        // Make sure the confidential flags are always false when the connection is
2547        // untrusted, even though the client claims to support confidential channels.
2548        env.open_channel(1, 1, &mut channel, |request| {
2549            assert!(!request.use_confidential_ring);
2550            assert!(!request.use_confidential_external_memory);
2551        })
2552        .await;
2553
2554        env.open_channel(2, 2, &mut channel2, |request| {
2555            assert!(!request.use_confidential_ring);
2556            assert!(!request.use_confidential_external_memory);
2557        })
2558        .await;
2559    }
2560}