vmbus_server/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8pub mod channels;
9pub mod hvsock;
10mod monitor;
11mod proxyintegration;
12#[cfg(test)]
13mod tests;
14
/// The GUID type used for vmbus channel identifiers (a re-export of
/// [`guid::Guid`]).
pub type Guid = guid::Guid;
17
18use anyhow::Context;
19use async_trait::async_trait;
20use channel_bitmap::ChannelBitmap;
21use channels::ConnectionTarget;
22pub use channels::InitiateContactRequest;
23use channels::MessageTarget;
24pub use channels::MnfUsage;
25use channels::ModifyConnectionRequest;
26use channels::ModifyConnectionResponse;
27use channels::Notifier;
28use channels::OfferId;
29pub use channels::OfferParamsInternal;
30use channels::OpenParams;
31use channels::RestoreError;
32pub use channels::Update;
33use futures::FutureExt;
34use futures::StreamExt;
35use futures::channel::mpsc;
36use futures::channel::mpsc::SendError;
37use futures::future::OptionFuture;
38use futures::future::poll_fn;
39use futures::stream::SelectAll;
40use guestmem::GuestMemory;
41use hvdef::Vtl;
42use inspect::Inspect;
43use mesh::payload::Protobuf;
44use mesh::rpc::FailableRpc;
45use mesh::rpc::Rpc;
46use mesh::rpc::RpcError;
47use mesh::rpc::RpcSend;
48use pal_async::driver::Driver;
49use pal_async::driver::SpawnDriver;
50use pal_async::task::Task;
51use pal_async::timer::PolledTimer;
52use pal_event::Event;
53#[cfg(windows)]
54pub use proxyintegration::ProxyIntegration;
55#[cfg(windows)]
56pub use proxyintegration::ProxyServerInfo;
57use ring::PAGE_SIZE;
58use std::collections::HashMap;
59use std::future;
60use std::future::Future;
61use std::pin::Pin;
62use std::sync::Arc;
63use std::task::Poll;
64use std::task::ready;
65use std::time::Duration;
66use unicycle::FuturesUnordered;
67use vmbus_channel::bus::ChannelRequest;
68use vmbus_channel::bus::ChannelServerRequest;
69use vmbus_channel::bus::GpadlRequest;
70use vmbus_channel::bus::ModifyRequest;
71use vmbus_channel::bus::OfferInput;
72use vmbus_channel::bus::OfferKey;
73use vmbus_channel::bus::OfferResources;
74use vmbus_channel::bus::OpenData;
75use vmbus_channel::bus::OpenRequest;
76use vmbus_channel::bus::ParentBus;
77use vmbus_channel::bus::RestoreResult;
78use vmbus_channel::gpadl::GpadlMap;
79use vmbus_channel::gpadl_ring::AlignedGpadlView;
80use vmbus_core::HvsockConnectRequest;
81use vmbus_core::HvsockConnectResult;
82use vmbus_core::MaxVersionInfo;
83use vmbus_core::OutgoingMessage;
84use vmbus_core::TaggedStream;
85use vmbus_core::VMBUS_SINT;
86use vmbus_core::VersionInfo;
87use vmbus_core::protocol;
88pub use vmbus_core::protocol::GpadlId;
89#[cfg(windows)]
90use vmbus_proxy::ProxyHandle;
91use vmbus_ring as ring;
92use vmbus_ring::gparange::MultiPagedRangeBuf;
93use vmcore::interrupt::Interrupt;
94use vmcore::save_restore::SavedStateRoot;
95use vmcore::synic::EventPort;
96use vmcore::synic::GuestEventPort;
97use vmcore::synic::GuestMessagePort;
98use vmcore::synic::MessagePort;
99use vmcore::synic::MonitorPageGpas;
100use vmcore::synic::SynicPortAccess;
101
/// The SINT used for vmbus messages when control-plane redirection is enabled.
pub const REDIRECT_SINT: u8 = 7;
/// The VTL that redirected vmbus control-plane messages are delivered to.
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
// Connection ID for the shared event port (presumably used with the channel
// bitmap on older protocol versions -- TODO confirm against ChannelBitmap use).
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
// Base value for per-channel event port IDs; see `get_child_event_port_id`.
const EVENT_PORT_ID: u32 = 2;
const VMBUS_MESSAGE_TYPE: u32 = 1;

// Bound on outstanding hvsock connect requests tracked via
// `ServerTaskInner::hvsock_requests` -- enforcement is outside this view.
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
109
/// A vmbus server instance.
///
/// Created via [`VmbusServerBuilder::build`]. When dropped, all channels are
/// closed and revoked automatically.
#[derive(Inspect)]
pub struct VmbusServer {
    // Inspect requests are forwarded to the server task, which owns the state.
    #[inspect(flatten, send = "VmbusRequest::Inspect")]
    task_send: mesh::Sender<VmbusRequest>,
    #[inspect(skip)]
    control: Arc<VmbusServerControl>,
    // Keeps the synic message port registration alive for the server lifetime.
    #[inspect(skip)]
    _message_port: Box<dyn Sync + Send>,
    // Present only for a VTL0 server without message redirection; see `build`.
    #[inspect(skip)]
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    #[inspect(skip)]
    task: Task<ServerTask>,
}
123
/// A builder for [`VmbusServer`].
///
/// Each field corresponds to a setter method of the same name; see those
/// methods for the semantics of each option.
pub struct VmbusServerBuilder<T: SpawnDriver> {
    spawner: T,
    synic: Arc<dyn SynicPortAccess>,
    gm: GuestMemory,
    // Separate memory for confidential (non-relay) channels, if any.
    private_gm: Option<GuestMemory>,
    vtl: Vtl,
    hvsock_notify: Option<HvsockServerChannelHalf>,
    server_relay: Option<VmbusServerChannelHalf>,
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    use_message_redirect: bool,
    channel_id_offset: u16,
    max_version: Option<MaxVersionInfo>,
    delay_max_version: bool,
    enable_mnf: bool,
    force_confidential_external_memory: bool,
    send_messages_while_stopped: bool,
    channel_unstick_delay: Option<Duration>,
    use_absolute_channel_order: bool,
}
145
#[derive(mesh::MeshPayload)]
/// The request to send to the proxy to set or clear its saved state cache.
pub enum SavedStateRequest {
    /// Set the proxy's cached saved state. Fails if the proxy rejects it.
    Set(FailableRpc<Box<channels::SavedState>, ()>),
    /// Clear the proxy's cached saved state.
    Clear(Rpc<(), ()>),
}
152
/// The server side of the connection between a vmbus server and a relay.
pub struct ServerChannelHalf<Request, Response> {
    // Requests flow server -> relay; responses flow back on the paired half.
    request_send: mesh::Sender<Request>,
    response_receive: mesh::Receiver<Response>,
}
158
/// The relay side of a connection between a vmbus server and a relay.
pub struct RelayChannelHalf<Request, Response> {
    /// Receives requests sent by the server.
    pub request_receive: mesh::Receiver<Request>,
    /// Sends responses back to the server.
    pub response_send: mesh::Sender<Response>,
}
164
/// A connection between a vmbus server and a relay, held as the pair of its
/// two halves. Construct with [`RelayChannel::new`].
pub struct RelayChannel<Request, Response> {
    pub relay_half: RelayChannelHalf<Request, Response>,
    pub server_half: ServerChannelHalf<Request, Response>,
}
170
171impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
172    /// Creates a new channel between the vmbus server and a relay.
173    pub fn new() -> Self {
174        let (request_send, request_receive) = mesh::channel();
175        let (response_send, response_receive) = mesh::channel();
176        Self {
177            relay_half: RelayChannelHalf {
178                request_receive,
179                response_send,
180            },
181            server_half: ServerChannelHalf {
182                request_send,
183                response_receive,
184            },
185        }
186    }
187}
188
/// Server half of the server <-> relay connection-modification channel.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Relay half of the server <-> relay connection-modification channel.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Full channel pair for server <-> relay connection modification.
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyRelayResponse>;
/// Server half of the hvsock connect request/result channel.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay half of the hvsock connect request/result channel.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Full channel pair for hvsock connect requests.
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
195
/// A request from the server to the relay to modify connection state.
///
/// The version and use_interrupt_page fields can only be present if this request was sent for an
/// InitiateContact message from the guest.
///
/// If `version` is `Some`, the relay must respond with either `ModifyRelayResponse::Supported` or
/// `ModifyRelayResponse::Unsupported`. If `version` is `None`, the relay must respond with
/// `ModifyRelayResponse::Modified`.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// The requested protocol version number, if this is for InitiateContact.
    pub version: Option<u32>,
    /// Change to apply to the monitor page GPAs.
    pub monitor_page: Update<MonitorPageGpas>,
    /// Whether the interrupt page is being set (`Some(true)`), reset
    /// (`Some(false)`), or left unchanged (`None`).
    pub use_interrupt_page: Option<bool>,
}
210
/// A response from the relay to a [`ModifyRelayRequest`] from the server.
#[derive(Debug, Copy, Clone)]
pub enum ModifyRelayResponse {
    /// The requested version change is supported, and the relay completed the connection
    /// modification with the specified status. All of the feature flags supported by the relay host
    /// are included, regardless of what features were requested.
    Supported(protocol::ConnectionState, protocol::FeatureFlags),
    /// A version change was requested but the relay host doesn't support that version. This
    /// response cannot be returned for a request with no version change set.
    Unsupported,
    /// The connection modification completed with the specified status. This response must be sent
    /// if no version change was requested.
    Modified(protocol::ConnectionState),
}
225
226impl From<ModifyConnectionRequest> for ModifyRelayRequest {
227    fn from(value: ModifyConnectionRequest) -> Self {
228        Self {
229            version: value.version.map(|v| v.version as u32),
230            monitor_page: value.monitor_page,
231            use_interrupt_page: match value.interrupt_page {
232                Update::Unchanged => None,
233                Update::Reset => Some(false),
234                Update::Set(_) => Some(true),
235            },
236        }
237    }
238}
239
/// Control-plane requests sent from [`VmbusServer`] to its server task.
#[derive(Debug)]
enum VmbusRequest {
    /// Reset the vmbus channel state; completes when the reset is done.
    Reset(Rpc<(), ()>),
    /// Deferred inspect request, flattened into the server's inspect tree.
    Inspect(inspect::Deferred),
    /// Save the server state.
    Save(Rpc<(), SavedState>),
    /// Restore previously saved state.
    Restore(Rpc<Box<SavedState>, Result<(), RestoreError>>),
    /// Start the control plane (fire-and-forget).
    Start,
    /// Stop the control plane; completes when stopped.
    Stop(Rpc<(), ()>),
}
249
/// Everything needed to offer a channel: its parameters plus the control
/// channels connecting the server and the device.
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    pub params: OfferParamsInternal,
    /// The guest-to-host interrupt for the channel.
    pub event: Interrupt,
    /// Sends channel requests (open/close/gpadl/...) to the device.
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receives device-side requests for the server.
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
257
/// Requests sent from `VmbusServerControl` to the server task.
#[expect(clippy::large_enum_variant)]
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offer a new channel; fails if the offer is rejected.
    Offer(FailableRpc<OfferInfo, ()>),
    /// Force a reset of all channel state.
    ForceReset(Rpc<(), ()>),
}
264
/// Wraps a channel's guest-to-host interrupt so it can be registered as a
/// synic event port.
struct ChannelEvent(Interrupt);

impl EventPort for ChannelEvent {
    fn handle_event(&self, _flag: u16) {
        // The flag is ignored; the wrapped interrupt is simply delivered.
        self.0.deliver();
    }

    fn os_event(&self) -> Option<&Event> {
        self.0.event()
    }
}
276
/// The saved state of the vmbus server.
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    #[mesh(1)]
    pub server: channels::SavedState,
    // Indicates whether the "lost synic" bug was already fixed when this state
    // was saved. Defaults to false. During restore, if this is not true,
    // unstick_channels() is called to mitigate the issue.
    #[mesh(2)]
    pub lost_synic_bug_fixed: bool,
}
288
// The standard vmbus message connection ID.
const MESSAGE_CONNECTION_ID: u32 = 1;
// Base connection ID for the multiclient message port; child servers encode
// VTL/VP/SINT on top of it (see `get_child_message_connection_id`).
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
291
impl<T: SpawnDriver + Clone> VmbusServerBuilder<T> {
    /// Creates a new builder for `VmbusServer` with the default options.
    pub fn new(spawner: T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
        Self {
            spawner,
            synic,
            gm,
            private_gm: None,
            vtl: Vtl::Vtl0,
            hvsock_notify: None,
            server_relay: None,
            saved_state_notify: None,
            external_server: None,
            external_requests: None,
            use_message_redirect: false,
            channel_id_offset: 0,
            max_version: None,
            delay_max_version: false,
            enable_mnf: false,
            force_confidential_external_memory: false,
            send_messages_while_stopped: false,
            // See `channel_unstick_delay` for the rationale behind this default.
            channel_unstick_delay: Some(Duration::from_millis(100)),
            use_absolute_channel_order: false,
        }
    }

    /// Sets a separate guest memory instance to use for channels that are confidential (non-relay
    /// channels in Underhill on a hardware isolated VM). This is not relevant for a non-Underhill
    /// VmBus server.
    pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
        self.private_gm = private_gm;
        self
    }

    /// Sets the VTL that this instance will serve.
    pub fn vtl(mut self, vtl: Vtl) -> Self {
        self.vtl = vtl;
        self
    }

    /// Sets a send/receive pair used to handle hvsocket requests.
    pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
        self.hvsock_notify = hvsock_notify;
        self
    }

    /// Sets a send channel used to enlighten ProxyIntegration about saved channels.
    pub fn saved_state_notify(
        mut self,
        saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    ) -> Self {
        self.saved_state_notify = saved_state_notify;
        self
    }

    /// Sets a send/receive pair that will be notified of server requests. This is used by the
    /// Underhill relay.
    pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
        self.server_relay = server_relay;
        self
    }

    /// Sets a receiver that receives requests from another server.
    pub fn external_requests(
        mut self,
        external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    ) -> Self {
        self.external_requests = external_requests;
        self
    }

    /// Sets a sender used to forward unhandled connect requests (which used a different VTL)
    /// to another server.
    pub fn external_server(
        mut self,
        external_server: Option<mesh::Sender<InitiateContactRequest>>,
    ) -> Self {
        self.external_server = external_server;
        self
    }

    /// Sets a value which indicates whether the vmbus control plane is redirected to Underhill.
    pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
        self.use_message_redirect = use_message_redirect;
        self
    }

    /// Tells the server to use an offset when generating channel IDs to avoid collisions with
    /// another vmbus server.
    ///
    /// N.B. This should only be used by the Underhill vmbus server.
    pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
        self.channel_id_offset = if enable { 1024 } else { 0 };
        self
    }

    /// Tells the server to limit the protocol version offered to the guest.
    ///
    /// N.B. This is used for testing older protocols without requiring a specific guest OS.
    pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
        self.max_version = max_version;
        self
    }

    /// Delay limiting the maximum version until after the first `Unload` message.
    ///
    /// N.B. This is used to enable the use of versions older than `Version::Win10` with Uefi boot,
    ///      since that's the oldest version the Uefi client supports.
    pub fn delay_max_version(mut self, delay: bool) -> Self {
        self.delay_max_version = delay;
        self
    }

    /// Enable MNF support in the server.
    ///
    /// N.B. Enabling this has no effect if the synic does not support mapping monitor pages.
    pub fn enable_mnf(mut self, enable: bool) -> Self {
        self.enable_mnf = enable;
        self
    }

    /// Force all non-relay channels to use encrypted external memory. Used for testing purposes
    /// only.
    pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
        self.force_confidential_external_memory = force;
        self
    }

    /// Send messages to the partition even while stopped, which can cause
    /// corrupted synic states across VM reset.
    ///
    /// This option is used to prevent messages from getting into the queue, for
    /// saved state compatibility with release/2411. It can be removed once that
    /// release is no longer supported.
    pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
        self.send_messages_while_stopped = send;
        self
    }

    /// Sets the delay before unsticking a vmbus channel after it has been opened.
    ///
    /// This option provides a work around for guests that ignore interrupts before they receive the
    /// OpenResult message, by triggering an interrupt after the channel has been opened.
    ///
    /// If not set, the default is 100ms. If set to `None`, no interrupt will be triggered.
    pub fn channel_unstick_delay(mut self, delay: Option<Duration>) -> Self {
        self.channel_unstick_delay = delay;
        self
    }

    /// Sets whether the channel order value provided in an offer is the primary way of ordering
    /// channels when assigning channel IDs, rather than the default behavior of ordering by
    /// interface ID first.
    pub fn use_absolute_channel_order(mut self, assign: bool) -> Self {
        self.use_absolute_channel_order = assign;
        self
    }

    /// Creates a new instance of the server.
    ///
    /// When the object is dropped, all channels will be closed and revoked
    /// automatically.
    pub fn build(self) -> anyhow::Result<VmbusServer> {
        #[expect(clippy::disallowed_methods)] // TODO
        let (message_send, message_recv) = mpsc::channel(64);
        // Forwards incoming synic messages to the server task.
        let message_sender = Arc::new(MessageSender {
            send: message_send.clone(),
            multiclient: self.use_message_redirect,
        });

        let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
            (REDIRECT_VTL, REDIRECT_SINT)
        } else {
            (self.vtl, VMBUS_SINT)
        };

        // Only a non-redirected VTL0 server uses the standard connection ID;
        // any other server uses a server-specific one.
        let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            MESSAGE_CONNECTION_ID
        } else {
            // TODO: This ID should be using the correct target VP, but that is not known until
            //       InitiateContact.
            VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
        };

        let _message_port = self
            .synic
            .add_message_port(connection_id, redirect_vtl, message_sender)
            .context("failed to create vmbus synic ports")?;

        // If this server is for VTL0, it is also responsible for the multiclient message port.
        // N.B. If control plane redirection is enabled, the redirected message port is used for
        //      multiclient and no separate multiclient port is created.
        let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            let multiclient_message_sender = Arc::new(MessageSender {
                send: message_send,
                multiclient: true,
            });

            Some(
                self.synic
                    .add_message_port(
                        MULTICLIENT_MESSAGE_CONNECTION_ID,
                        self.vtl,
                        multiclient_message_sender,
                    )
                    .context("failed to create vmbus synic ports")?,
            )
        } else {
            None
        };

        // The control object lets devices offer channels to the server task.
        let (offer_send, offer_recv) = mesh::mpsc_channel();
        let control = Arc::new(VmbusServerControl {
            mem: self.gm.clone(),
            private_mem: self.private_gm.clone(),
            send: offer_send,
            use_event: self.synic.prefer_os_events(),
            force_confidential_external_memory: self.force_confidential_external_memory,
        });

        let mut server = channels::Server::new(
            self.vtl,
            connection_id,
            self.channel_id_offset,
            self.use_absolute_channel_order,
        );

        // If MNF is handled by this server and this is a paravisor for an isolated VM, the monitor
        // pages must be allocated by the server, not the guest, since the guest will provide shared
        // pages which can't be used in this case. If the guest doesn't support server-specified
        // monitor pages, MNF will be disabled for all channels for that connection.
        server.set_require_server_allocated_mnf(self.enable_mnf && self.private_gm.is_some());

        // If requested, limit the maximum protocol version and feature flags.
        if let Some(version) = self.max_version {
            server.set_compatibility_version(version, self.delay_max_version);
        }
        // If no relay was provided, synthesize one that approves every request.
        let (relay_request_send, relay_response_recv) =
            if let Some(server_relay) = self.server_relay {
                let r = server_relay.response_receive.boxed().fuse();
                (server_relay.request_send, r)
            } else {
                let (req_send, req_recv) = mesh::channel();
                let resp_recv = req_recv
                    .map(|req: ModifyRelayRequest| {
                        // Map to the correct response type for the request.
                        if req.version.is_some() {
                            ModifyRelayResponse::Supported(
                                protocol::ConnectionState::SUCCESSFUL,
                                protocol::FeatureFlags::from_bits(u32::MAX),
                            )
                        } else {
                            ModifyRelayResponse::Modified(protocol::ConnectionState::SUCCESSFUL)
                        }
                    })
                    .boxed()
                    .fuse();
                (req_send, resp_recv)
            };

        // If no hvsock notifier was specified, use a default one that always sends an error response.
        let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
            let r = hvsock_notify.response_receive.boxed().fuse();
            (hvsock_notify.request_send, r)
        } else {
            let (req_send, req_recv) = mesh::channel();
            let resp_recv = req_recv
                .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
                .boxed()
                .fuse();
            (req_send, resp_recv)
        };

        let inner = ServerTaskInner {
            running: false,
            send_messages_while_stopped: self.send_messages_while_stopped,
            gm: self.gm,
            private_gm: self.private_gm,
            vtl: self.vtl,
            redirect_vtl,
            redirect_sint,
            message_port: self
                .synic
                .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
            synic: self.synic,
            hvsock_requests: 0,
            hvsock_send,
            saved_state_notify: self.saved_state_notify,
            channels: HashMap::new(),
            channel_responses: FuturesUnordered::new(),
            relay_send: relay_request_send,
            external_server_send: self.external_server,
            channel_bitmap: None,
            shared_event_port: None,
            reset_done: Vec::new(),
            mnf_support: self.enable_mnf.then(MnfSupport::default),
        };

        let (task_send, task_recv) = mesh::channel();
        let mut server_task = ServerTask {
            driver: Box::new(self.spawner.clone()),
            server,
            task_recv,
            offer_recv,
            message_recv,
            server_request_recv: SelectAll::new(),
            inner,
            external_requests: self.external_requests,
            next_seq: 0,
            perform_post_restore_on_start: false,
            unstick_on_start: false,
            channel_unstickers: FuturesUnordered::new(),
            channel_unstick_delay: self.channel_unstick_delay,
        };

        // The task returns itself on completion so `shutdown` can await it.
        let task = self.spawner.spawn("vmbus server", async move {
            server_task.run(relay_response_recv, hvsock_recv).await;
            server_task
        });

        Ok(VmbusServer {
            task_send,
            control,
            _message_port,
            _multiclient_message_port,
            task,
        })
    }
}
623
impl VmbusServer {
    /// Creates a new builder for `VmbusServer` with the default options.
    pub fn builder<T: SpawnDriver + Clone>(
        spawner: T,
        synic: Arc<dyn SynicPortAccess>,
        gm: GuestMemory,
    ) -> VmbusServerBuilder<T> {
        VmbusServerBuilder::new(spawner, synic, gm)
    }

    /// Saves the server's state.
    ///
    /// Panics if the server task has already exited.
    pub async fn save(&self) -> SavedState {
        self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
    }

    /// Restores previously saved server state.
    ///
    /// Panics if the server task has already exited.
    pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
        self.task_send
            .call(VmbusRequest::Restore, Box::new(state))
            .await
            .unwrap()
    }

    /// Stop the control plane.
    pub async fn stop(&self) {
        self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
    }

    /// Starts the control plane.
    pub fn start(&self) {
        self.task_send.send(VmbusRequest::Start);
    }

    /// Resets the vmbus channel state.
    pub async fn reset(&self) {
        tracing::debug!("resetting channel state");
        self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
    }

    /// Tears down the vmbus control plane.
    pub async fn shutdown(self) {
        // Dropping the sender signals the server task to exit; wait for it.
        drop(self.task_send);
        let _ = self.task.await;
    }

    /// Returns an object that can be used to offer channels.
    pub fn control(&self) -> Arc<VmbusServerControl> {
        self.control.clone()
    }

    /// Returns the message connection ID to use for a communication from the guest for servers
    /// that use a non-standard SINT or VTL.
    ///
    /// Bit layout: VTL in bits 22+, VP index in bits 8+, SINT in bits 4-7,
    /// OR'd with the multiclient base ID.
    fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
        MULTICLIENT_MESSAGE_CONNECTION_ID
            | (vtl as u32) << 22
            | vp_index << 8
            | (sint_index as u32) << 4
    }

    /// Returns the event port ID for a channel on a server with a non-standard
    /// SINT or VTL, using the same bit layout as
    /// `get_child_message_connection_id` but with the channel ID in place of
    /// the VP index.
    fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
        EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
    }
}
685
/// Per-channel data used when restoring a channel.
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    /// Open-channel parameters, if the channel was open at save time.
    open_data: Option<OpenData>,
    // GPADLs as (ID, range count, page buffer) tuples -- TODO confirm the u16's
    // exact meaning against the restore path.
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    interrupt: Option<Interrupt>,
}
692
/// A synic message received from the guest.
#[derive(Default)]
pub struct SynicMessage {
    /// Raw message payload.
    data: Vec<u8>,
    /// Whether the message arrived on the multiclient message port.
    multiclient: bool,
    // Whether the message arrived over a trusted path -- presumably set by
    // `MessageSender` (not in view); confirm semantics there.
    trusted: bool,
}
699
/// Information used by a server that supports MNF.
#[derive(Default)]
struct MnfSupport {
    /// The server-allocated monitor page GPAs, if any have been allocated.
    allocated_monitor_page: Option<MonitorPageGpas>,
}
705
/// Disambiguates offer instances that may have reused the same offer ID.
#[derive(Debug, Clone, Copy)]
struct OfferInstanceId {
    offer_id: OfferId,
    /// Monotonic sequence number assigned at offer time (`ServerTask::next_seq`).
    seq: u64,
}
712
/// The asynchronous task that runs the vmbus server.
struct ServerTask {
    driver: Box<dyn Driver>,
    server: channels::Server,
    /// Control requests from the owning [`VmbusServer`].
    task_recv: mesh::Receiver<VmbusRequest>,
    /// Offer/force-reset requests from `VmbusServerControl`.
    offer_recv: mesh::Receiver<OfferRequest>,
    /// Synic messages from the guest, forwarded by `MessageSender`.
    message_recv: mpsc::Receiver<SynicMessage>,
    /// Device-side request streams, one per offered channel, tagged with the
    /// channel's instance ID.
    server_request_recv:
        SelectAll<TaggedStream<OfferInstanceId, mesh::Receiver<ChannelServerRequest>>>,
    inner: ServerTaskInner,
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Next value for [`Channel::seq`].
    next_seq: u64,
    perform_post_restore_on_start: bool,
    unstick_on_start: bool,
    /// Pending delayed channel-unstick operations; each future resolves to the
    /// channel instance to nudge.
    channel_unstickers: FuturesUnordered<Pin<Box<dyn Send + Future<Output = OfferInstanceId>>>>,
    channel_unstick_delay: Option<Duration>,
}
730
/// State owned by the server task and passed to
/// `channels::Server::with_notifier` when driving the channels state machine.
struct ServerTaskInner {
    running: bool,
    /// If set, messages are sent to the guest even while stopped (see
    /// `VmbusServerBuilder::send_messages_while_stopped`).
    send_messages_while_stopped: bool,
    gm: GuestMemory,
    /// Separate memory used for confidential channels, if any.
    private_gm: Option<GuestMemory>,
    synic: Arc<dyn SynicPortAccess>,
    vtl: Vtl,
    redirect_vtl: Vtl,
    redirect_sint: u8,
    message_port: Box<dyn GuestMessagePort>,
    // Count of outstanding hvsock connect requests; bounded by
    // MAX_CONCURRENT_HVSOCK_REQUESTS -- enforcement is outside this view.
    hvsock_requests: usize,
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    channels: HashMap<OfferId, Channel>,
    /// In-flight device responses to channel requests, tagged with the offer
    /// ID and sequence number they belong to.
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    relay_send: mesh::Sender<ModifyRelayRequest>,
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    shared_event_port: Option<Box<dyn Send>>,
    /// RPCs to complete once an in-progress reset finishes.
    reset_done: Vec<Rpc<(), ()>>,
    /// Stores information needed to support MNF. If `None`, this server doesn't support MNF (in
    /// the case of OpenHCL, that means it will be handled by the relay host).
    mnf_support: Option<MnfSupport>,
}
757
/// A device's response to a channel request, carried through
/// `ServerTaskInner::channel_responses`.
#[derive(Debug)]
enum ChannelResponse {
    /// Result of an open request (true on success).
    Open(bool),
    Close,
    /// Result of a GPADL creation request (true on success).
    Gpadl(GpadlId, bool),
    TeardownGpadl(GpadlId),
    /// Status code from a modify request.
    Modify(i32),
}
766
/// Tracks whether a delayed "unstick" (interrupt nudge) is pending for a
/// channel on `ServerTask::channel_unstickers`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ChannelUnstickState {
    /// No unstick queued.
    None,
    /// An unstick is queued.
    Queued,
    // An unstick is queued and another should be queued when it completes --
    // TODO confirm; the state transitions are handled outside this view.
    NeedsRequeue,
}
773
/// Server-side bookkeeping for a single offered channel.
struct Channel {
    key: OfferKey,
    /// Sends channel requests (open/close/gpadl/modify) to the device.
    send: mesh::Sender<ChannelRequest>,
    /// Offer-instance sequence number; see [`OfferInstanceId`].
    seq: u64,
    state: ChannelState,
    gpadls: Arc<GpadlMap>,
    guest_to_host_event: Arc<ChannelEvent>,
    flags: protocol::OfferFlags,
    // A channel can be reserved no matter what state it is in. This allows the message port for a
    // reserved channel to remain available even if the channel is closed, so the guest can read the
    // close reserved channel response. The reserved state is cleared when the channel is revoked,
    // reopened, or the guest sends an unload message.
    reserved_state: ReservedState,
    unstick_state: ChannelUnstickState,
}
789
/// Reserved-channel state: a dedicated guest message port and the connection
/// target (VP/SINT) it is directed at.
struct ReservedState {
    message_port: Option<Box<dyn GuestMessagePort>>,
    target: ConnectionTarget,
}
794
/// State kept for a channel while it is open.
struct ChannelOpenState {
    open_params: OpenParams,
    // Keeps the host-side event port registration alive while the channel is open.
    _event_port: Box<dyn Send>,
    // `None` when the channel interrupt is disabled; see
    // `set_event_port_target_vp`.
    guest_event_port: Option<Box<dyn GuestEventPort>>,
    host_to_guest_interrupt: Interrupt,
}
801
802impl ChannelOpenState {
803    fn set_event_port_target_vp(&mut self, vp: u32) -> anyhow::Result<()> {
804        let Some(guest_event_port) = self.guest_event_port.as_mut() else {
805            anyhow::bail!("cannot set target VP if the channel interrupt is disabled");
806        };
807
808        guest_event_port.set_target_vp(vp)?;
809        Ok(())
810    }
811}
812
/// Lifecycle state of an offered channel.
enum ChannelState {
    Closed,
    /// Open; boxed because the open state is large relative to the other variants.
    Open(Box<ChannelOpenState>),
    // A close is in progress -- TODO confirm exact transition semantics; the
    // state machine handling is outside this view.
    Closing,
}
818
819impl ServerTask {
    /// Registers a newly offered channel with the channels server and starts
    /// listening for device-side requests for it.
    fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
        let key = info.params.key();
        let flags = info.params.flags;

        if self.inner.mnf_support.is_some() && self.inner.synic.monitor_support().is_some() {
            // If this server is handling MnF, ignore any relayed monitor IDs but still enable MnF
            // for those channels.
            // N.B. Since this can only happen in OpenHCL, which emulates MnF, the latency is
            //      ignored.
            if info.params.use_mnf.is_relayed() {
                info.params.use_mnf = MnfUsage::Enabled {
                    latency: Duration::ZERO,
                }
            }
        } else if info.params.use_mnf.is_enabled() {
            // If the server is not handling MnF, disable it for the channel. This does not affect
            // channels with a relayed monitor ID.
            info.params.use_mnf = MnfUsage::Disabled;
        }

        let offer_id = self
            .server
            .with_notifier(&mut self.inner)
            .offer_channel(info.params)
            .context("channel offer failed")?;

        tracing::debug!(?offer_id, %key, "offered channel");

        // `seq` disambiguates channel instances that reuse the same offer ID
        // after a revoke; see `OfferInstanceId`.
        let seq = self.next_seq;
        self.next_seq += 1;
        self.inner.channels.insert(
            offer_id,
            Channel {
                key,
                send: info.request_send,
                state: ChannelState::Closed,
                gpadls: GpadlMap::new(),
                guest_to_host_event: Arc::new(ChannelEvent(info.event)),
                seq,
                flags,
                reserved_state: ReservedState {
                    message_port: None,
                    target: ConnectionTarget { vp: 0, sint: 0 },
                },
                unstick_state: ChannelUnstickState::None,
            },
        );

        self.server_request_recv.push(TaggedStream::new(
            OfferInstanceId { offer_id, seq },
            info.server_request_recv,
        ));

        Ok(())
    }
875
876    fn handle_revoke(&mut self, id: OfferInstanceId) {
877        // The channel may or may not exist in the map depending on whether it's been explicitly
878        // revoked before being dropped.
879        if let Some(channel) = self.inner.channels.get(&id.offer_id) {
880            if channel.seq == id.seq {
881                tracing::info!(?id.offer_id, key = %channel.key, "revoking channel");
882                self.inner.channels.remove(&id.offer_id);
883                self.server
884                    .with_notifier(&mut self.inner)
885                    .revoke_channel(id.offer_id);
886            }
887        }
888    }
889
890    fn handle_response(
891        &mut self,
892        offer_id: OfferId,
893        seq: u64,
894        response: Result<ChannelResponse, RpcError>,
895    ) {
896        // Validate the sequence to ensure the response is not for a revoked channel.
897        let channel = self
898            .inner
899            .channels
900            .get(&offer_id)
901            .filter(|channel| channel.seq == seq);
902
903        if let Some(channel) = channel {
904            match response {
905                Ok(response) => match response {
906                    ChannelResponse::Open(result) => self.handle_open(offer_id, result),
907                    ChannelResponse::Close => self.handle_close(offer_id),
908                    ChannelResponse::Gpadl(gpadl_id, ok) => {
909                        self.handle_gpadl_create(offer_id, gpadl_id, ok)
910                    }
911                    ChannelResponse::TeardownGpadl(gpadl_id) => {
912                        self.handle_gpadl_teardown(offer_id, gpadl_id)
913                    }
914                    ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
915                },
916                Err(err) => {
917                    tracing::error!(
918                        key = %channel.key,
919                        error = &err as &dyn std::error::Error,
920                        "channel response failure, channel is in inconsistent state until revoked"
921                    );
922                }
923            }
924        } else {
925            tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
926        }
927    }
928
    /// Completes a channel open using the result (`success`) reported by the
    /// channel's device, translating it to a status code for the guest.
    fn handle_open(&mut self, offer_id: OfferId, success: bool) {
        let status = if success {
            let channel = self
                .inner
                .channels
                .get_mut(&offer_id)
                .expect("channel exists");

            // Some guests ignore interrupts before they receive the OpenResult message. To avoid
            // a potential hang, signal the channel after a delay if needed.
            if let Some(delay) = self.channel_unstick_delay {
                if channel.unstick_state == ChannelUnstickState::None {
                    // No delayed wake queued yet: schedule one. When the timer
                    // fires, the main loop routes the id to
                    // unstick_channel_by_id.
                    channel.unstick_state = ChannelUnstickState::Queued;
                    let seq = channel.seq;
                    let mut timer = PolledTimer::new(&self.driver);
                    self.channel_unstickers.push(Box::pin(async move {
                        timer.sleep(delay).await;
                        OfferInstanceId { offer_id, seq }
                    }));
                } else {
                    // A wake is already queued from a previous open of this
                    // channel; ask it to requeue itself when it fires so the
                    // new open gets the full delay (see unstick_channel_by_id).
                    channel.unstick_state = ChannelUnstickState::NeedsRequeue;
                }
            }

            0
        } else {
            protocol::STATUS_UNSUCCESSFUL
        };

        self.server
            .with_notifier(&mut self.inner)
            .open_complete(offer_id, status);
    }
962
963    fn handle_close(&mut self, offer_id: OfferId) {
964        let channel = self
965            .inner
966            .channels
967            .get_mut(&offer_id)
968            .expect("channel still exists");
969
970        match &mut channel.state {
971            ChannelState::Closing => {
972                tracing::debug!(?offer_id, key = %channel.key, "closing channel");
973                channel.state = ChannelState::Closed;
974                self.server
975                    .with_notifier(&mut self.inner)
976                    .close_complete(offer_id);
977            }
978            _ => {
979                tracing::error!(?offer_id, key = %channel.key, "invalid close channel response");
980            }
981        };
982    }
983
984    fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
985        let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
986        self.server
987            .with_notifier(&mut self.inner)
988            .gpadl_create_complete(offer_id, gpadl_id, status);
989    }
990
    /// Completes a GPADL teardown request once the device's
    /// `ChannelResponse::TeardownGpadl` has arrived.
    fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
        self.server
            .with_notifier(&mut self.inner)
            .gpadl_teardown_complete(offer_id, gpadl_id);
    }
996
    /// Completes a modify-channel request with the status returned by the
    /// device.
    fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
        self.server
            .with_notifier(&mut self.inner)
            .modify_channel_complete(offer_id, status);
    }
1002
    /// Restores a channel on behalf of its device after a save/restore cycle,
    /// returning what the device needs to resume: the open request (if the
    /// channel was saved open and `open` is set) and the channel's GPADLs.
    fn handle_restore_channel(
        &mut self,
        offer_id: OfferId,
        open: bool,
    ) -> anyhow::Result<RestoreResult> {
        let gpadls = self.server.channel_gpadls(offer_id);

        // If the channel is opened, handle that before calling into channels so that failure can
        // be handled before the channel is marked restored.
        let open_request = open
            .then(|| -> anyhow::Result<_> {
                let params = self.server.get_restore_open_params(offer_id)?;
                let (channel, interrupt) = self.inner.open_channel(offer_id, &params)?;
                Ok(OpenRequest::new(
                    params.open_data,
                    interrupt,
                    self.server
                        .get_version()
                        .expect("must be connected")
                        .feature_flags,
                    channel.flags,
                ))
            })
            .transpose()?;

        self.server
            .with_notifier(&mut self.inner)
            .restore_channel(offer_id, open_request.is_some())?;

        // Mirror the restored GPADLs into this task's per-channel map. Entries
        // whose buffers fail conversion are silently skipped here.
        let channel = self.inner.channels.get_mut(&offer_id).unwrap();
        for gpadl in &gpadls {
            if let Ok(buf) = MultiPagedRangeBuf::from_range_buffer(
                gpadl.request.count.into(),
                gpadl.request.buf.clone(),
            ) {
                channel.gpadls.add(gpadl.request.id, buf);
            }
        }

        let result = RestoreResult {
            open_request,
            gpadls,
        };
        Ok(result)
    }
1048
    /// Dispatches a single control request (reset, inspect, save, restore,
    /// stop, start) sent to the server task.
    async fn handle_request(&mut self, request: VmbusRequest) {
        tracing::debug!(?request, "handle_request");
        match request {
            VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
            // Inspect exposes read-only state plus a writable
            // "unstick_channels" knob accepting "false", "true", or "force".
            VmbusRequest::Inspect(deferred) => {
                deferred.respond(|resp| {
                    resp.field("message_port", &self.inner.message_port)
                        .field("running", self.inner.running)
                        .field("hvsock_requests", self.inner.hvsock_requests)
                        .field("channel_unstick_delay", self.channel_unstick_delay)
                        .field_mut_with("unstick_channels", |v| {
                            let v: inspect::ValueKind = if let Some(v) = v {
                                if v == "force" {
                                    self.unstick_channels(true);
                                    v.into()
                                } else {
                                    let v =
                                        v.parse().ok().context("expected false, true, or force")?;
                                    if v {
                                        self.unstick_channels(false);
                                    }
                                    v.into()
                                }
                            } else {
                                false.into()
                            };
                            anyhow::Ok(v)
                        })
                        .merge(&self.server.with_notifier(&mut self.inner));
                });
            }
            // States saved by this build are always marked as having the lost
            // synic bug fix; restore uses this flag to decide on mitigation.
            VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
                server: self.server.save(),
                lost_synic_bug_fixed: true,
            }),
            VmbusRequest::Restore(rpc) => {
                rpc.handle(async |state| {
                    self.unstick_on_start = !state.lost_synic_bug_fixed;
                    self.perform_post_restore_on_start = true;
                    // Hand the saved state to the proxy (if any) before
                    // restoring locally; a proxy failure aborts the restore.
                    if let Some(sender) = &self.inner.saved_state_notify {
                        tracing::trace!("sending saved state to proxy");
                        if let Err(err) = sender
                            .call_failable(SavedStateRequest::Set, Box::new(state.server.clone()))
                            .await
                        {
                            tracing::error!(
                                err = &err as &dyn std::error::Error,
                                "failed to restore proxy saved state"
                            );
                            return Err(RestoreError::ServerError(err.into()));
                        }
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .restore(state.server)
                })
                .await
            }
            VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
                if self.inner.running {
                    self.inner.running = false;
                }
            }),
            VmbusRequest::Start => {
                if !self.inner.running {
                    self.inner.running = true;
                    // Post-restore work is deferred until the first start after
                    // a restore.
                    if self.perform_post_restore_on_start {
                        if let Some(sender) = self.inner.saved_state_notify.as_ref() {
                            // Indicate to the proxy that the server is starting and that it should
                            // clear its saved state cache.
                            tracing::trace!("sending clear saved state message to proxy");
                            sender
                                .call(SavedStateRequest::Clear, ())
                                .await
                                .expect("failed to clear proxy saved state");
                        }

                        self.server
                            .with_notifier(&mut self.inner)
                            .revoke_unclaimed_channels();

                        self.perform_post_restore_on_start = false;
                    }

                    if self.unstick_on_start {
                        tracing::info!(
                            "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
                        );
                        self.unstick_channels(false);
                        self.unstick_on_start = false;
                    }
                }
            }
        }
    }
1145
1146    fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
1147        let needs_reset = self.inner.reset_done.is_empty();
1148        self.inner.reset_done.push(rpc);
1149        if needs_reset {
1150            self.server.with_notifier(&mut self.inner).reset();
1151        }
1152    }
1153
1154    fn handle_relay_response(&mut self, response: ModifyRelayResponse) {
1155        // Convert to a matching ModifyConnectionResponse.
1156        let response = match response {
1157            ModifyRelayResponse::Supported(state, features) => {
1158                // Provide the server-allocated monitor page to the server only if they're actually being
1159                // used (if not, they may still be allocated from a previous connection).
1160                let allocated_monitor_gpas = self
1161                    .inner
1162                    .mnf_support
1163                    .as_ref()
1164                    .and_then(|mnf| mnf.allocated_monitor_page);
1165
1166                ModifyConnectionResponse::Supported(state, features, allocated_monitor_gpas)
1167            }
1168            ModifyRelayResponse::Unsupported => ModifyConnectionResponse::Unsupported,
1169            ModifyRelayResponse::Modified(state) => ModifyConnectionResponse::Modified(state),
1170        };
1171
1172        self.server
1173            .with_notifier(&mut self.inner)
1174            .complete_modify_connection(response);
1175    }
1176
1177    fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
1178        tracing::debug!(?result, "hvsock connect result");
1179        assert_ne!(self.inner.hvsock_requests, 0);
1180        self.inner.hvsock_requests -= 1;
1181
1182        self.server
1183            .with_notifier(&mut self.inner)
1184            .send_tl_connect_result(result);
1185    }
1186
1187    fn handle_synic_message(&mut self, message: SynicMessage) {
1188        match self
1189            .server
1190            .with_notifier(&mut self.inner)
1191            .handle_synic_message(message)
1192        {
1193            Ok(()) => {}
1194            Err(err) => {
1195                tracing::warn!(
1196                    error = &err as &dyn std::error::Error,
1197                    "synic message error"
1198                );
1199            }
1200        }
1201    }
1202
    /// Handles a request forwarded by a different vmbus server. This is used to forward requests
    /// for different VTLs to different servers.
    ///
    /// N.B. This uses the same mechanism as the HCL server relay, so all requests, even the ones
    ///      meant for the primary server, are forwarded. In that case the primary server depends
    ///      on this server to send back a response so it can continue handling it.
    fn handle_external_request(&mut self, request: InitiateContactRequest) {
        // An external request is an initiate-contact; run it through the
        // normal state machine.
        self.server
            .with_notifier(&mut self.inner)
            .initiate_contact(request);
    }
1214
    /// The server task's main loop.
    ///
    /// Multiplexes control requests, channel offers/revokes, device responses,
    /// relay responses, hvsock results, synic messages, delayed channel wakes,
    /// and pending outgoing messages. Runs until the control channel
    /// (`task_recv`) closes or every stream completes.
    async fn run(
        &mut self,
        mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyRelayResponse> + Unpin,
        mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
    ) {
        loop {
            // Create an OptionFuture for each event that should only be handled
            // while the VM is running. In other cases, leave the events in
            // their respective queues.

            let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
            let mut external_requests = OptionFuture::from(
                running_not_resetting
                    .then(|| {
                        self.external_requests
                            .as_mut()
                            .map(|r| r.select_next_some())
                    })
                    .flatten(),
            );

            // Try to send any pending messages while the VM is running.
            let has_pending_messages = self.server.has_pending_messages();
            let message_port = self.inner.message_port.as_mut();
            let mut flush_pending_messages =
                OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
                    poll_fn(|cx| {
                        self.server.poll_flush_pending_messages(|msg| {
                            message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
                        })
                    })
                    .fuse()
                }));

            // Only handle new incoming messages if there are no outgoing messages pending, and not
            // too many hvsock requests outstanding. This puts a bound on the resources used by the
            // guest.
            let mut message_recv = OptionFuture::from(
                (running_not_resetting
                    && !has_pending_messages
                    && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
                    .then(|| self.message_recv.select_next_some()),
            );

            // Accept channel responses until stopped or when resetting.
            let mut channel_response = OptionFuture::from(
                (self.inner.running || !self.inner.reset_done.is_empty())
                    .then(|| self.inner.channel_responses.select_next_some()),
            );

            // Accept hvsock connect responses while the VM is running.
            let mut hvsock_response =
                OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));

            let mut channel_unstickers = OptionFuture::from(
                running_not_resetting.then(|| self.channel_unstickers.select_next_some()),
            );

            futures::select! { // merge semantics
                // Control requests are accepted regardless of running state
                // (e.g. Inspect/Save/Start must work while stopped).
                r = self.task_recv.recv().fuse() => {
                    if let Ok(request) = r {
                        self.handle_request(request).await;
                    } else {
                        break;
                    }
                }
                r = self.offer_recv.select_next_some() => {
                    match r {
                        OfferRequest::Offer(rpc) => {
                            rpc.handle_failable_sync(|request| { self.handle_offer(request) })
                        },
                        OfferRequest::ForceReset(rpc) => {
                            self.handle_reset(rpc);
                        }
                    }
                }
                r = self.server_request_recv.select_next_some() => {
                    match r {
                        (id, Some(request)) => match request {
                            ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
                                self.handle_restore_channel(id.offer_id, open)
                            }),
                            ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
                                self.handle_revoke(id);
                            })
                        },
                        // Stream end (None) means the request channel was
                        // dropped: treat as an implicit revoke.
                        (id, None) => self.handle_revoke(id),
                    }
                }
                r = channel_response => {
                    let (id, seq, response) = r.unwrap();
                    self.handle_response(id, seq, response);
                }
                r = relay_response_recv.select_next_some() => {
                    self.handle_relay_response(r);
                },
                r = hvsock_response => {
                    self.handle_tl_connect_result(r.unwrap());
                }
                data = message_recv => {
                    let data = data.unwrap();
                    self.handle_synic_message(data);
                }
                r = external_requests => {
                    let r = r.unwrap();
                    self.handle_external_request(r);
                }
                r = channel_unstickers => {
                    self.unstick_channel_by_id(r.unwrap());
                }
                _r = flush_pending_messages => {}
                complete => break,
            }
        }
    }
1330
1331    /// Wakes the guest and optionally the host for every open channel. If `force`, always wakes
1332    /// them. If `!force`, only wake for rings that are in the state where a notification is
1333    /// expected.
1334    fn unstick_channels(&self, force: bool) {
1335        let Some(version) = self.server.get_version() else {
1336            tracing::warn!("cannot unstick when not connected");
1337            return;
1338        };
1339
1340        for channel in self.inner.channels.values() {
1341            let gm = self.inner.get_gm_for_channel(version, channel);
1342            if let Err(err) = Self::unstick_channel(gm, channel, force, true) {
1343                tracing::warn!(
1344                    channel = %channel.key,
1345                    error = err.as_ref() as &dyn std::error::Error,
1346                    "could not unstick channel"
1347                );
1348            }
1349        }
1350    }
1351
    /// Wakes the guest for the specified channel if it's open and the rings are in a state where
    /// notification is expected.
    ///
    /// This is the target of the delayed wakes queued by `handle_open`.
    fn unstick_channel_by_id(&mut self, id: OfferInstanceId) {
        let Some(version) = self.server.get_version() else {
            tracelimit::warn_ratelimited!("cannot unstick when not connected");
            return;
        };

        if let Some(channel) = self.inner.channels.get_mut(&id.offer_id) {
            if channel.seq != id.seq {
                // The channel was revoked.
                return;
            }

            // The channel was closed and reopened before the delay expired, so wait again to ensure
            // we don't signal too early.
            if channel.unstick_state == ChannelUnstickState::NeedsRequeue {
                channel.unstick_state = ChannelUnstickState::Queued;
                let mut timer = PolledTimer::new(&self.driver);
                // NeedsRequeue is only set when a delay is configured, so this
                // unwrap cannot fail.
                let delay = self.channel_unstick_delay.unwrap();
                self.channel_unstickers.push(Box::pin(async move {
                    timer.sleep(delay).await;
                    id
                }));

                return;
            }

            // The queued wake is being consumed now; reset the state so a
            // future open can queue a new one.
            channel.unstick_state = ChannelUnstickState::None;
            let gm = select_gm_for_channel(
                &self.inner.gm,
                self.inner.private_gm.as_ref(),
                version,
                channel,
            );
            // Not forced, and only the guest side is signaled (unstick_host =
            // false).
            if let Err(err) = Self::unstick_channel(gm, channel, false, false) {
                tracelimit::warn_ratelimited!(
                    channel = %channel.key,
                    error = err.as_ref() as &dyn std::error::Error,
                    "could not unstick channel"
                );
            }
        }
    }
1396
    /// Wakes the endpoints of a single channel, doing nothing unless it is
    /// open.
    ///
    /// If `force`, both sides are signaled unconditionally. Otherwise the ring
    /// control pages are examined and a signal is delivered only where the
    /// ring state indicates one is expected. The host side is only ever
    /// signaled when `unstick_host` is set.
    fn unstick_channel(
        gm: &GuestMemory,
        channel: &Channel,
        force: bool,
        unstick_host: bool,
    ) -> anyhow::Result<()> {
        if let ChannelState::Open(state) = &channel.state {
            if force {
                tracing::info!(channel = %channel.key, "waking host and guest");
                if unstick_host {
                    channel.guest_to_host_event.0.deliver();
                }
                state.host_to_guest_interrupt.deliver();
                return Ok(());
            }

            // Locate the ring buffer GPADL so the ring control pages can be
            // examined.
            let gpadl = channel
                .gpadls
                .clone()
                .view()
                .map(state.open_params.open_data.ring_gpadl_id)
                .context("couldn't find ring gpadl")?;

            let aligned = AlignedGpadlView::new(gpadl)
                .ok()
                .context("ring not aligned")?;
            // The single GPADL covers both rings; ring_offset marks the split.
            let (in_gpadl, out_gpadl) = aligned
                .split(state.open_params.open_data.ring_offset)
                .ok()
                .context("couldn't split ring")?;

            // Failures on one ring are logged but don't prevent checking the
            // other.
            if let Err(err) = Self::unstick_incoming_ring(
                gm,
                channel,
                in_gpadl,
                unstick_host.then_some(channel.guest_to_host_event.as_ref()),
                &state.host_to_guest_interrupt,
            ) {
                tracelimit::warn_ratelimited!(
                    channel = %channel.key,
                    error = err.as_ref() as &dyn std::error::Error,
                    "could not unstick incoming ring"
                );
            }
            if let Err(err) = Self::unstick_outgoing_ring(
                gm,
                channel,
                out_gpadl,
                unstick_host.then_some(channel.guest_to_host_event.as_ref()),
                &state.host_to_guest_interrupt,
            ) {
                tracelimit::warn_ratelimited!(
                    channel = %channel.key,
                    error = err.as_ref() as &dyn std::error::Error,
                    "could not unstick outgoing ring"
                );
            }
        }
        Ok(())
    }
1457
    /// Signals whichever side(s) of the channel's incoming ring the control
    /// page says are waiting: the host if the reader needs a signal (and
    /// `guest_to_host_event` is provided), the guest if the writer does.
    fn unstick_incoming_ring(
        gm: &GuestMemory,
        channel: &Channel,
        in_gpadl: AlignedGpadlView,
        guest_to_host_event: Option<&ChannelEvent>,
        host_to_guest_interrupt: &Interrupt,
    ) -> anyhow::Result<()> {
        // The ring's control page is the first GPN of its GPADL.
        let control_page = lock_gpn_with_subrange(gm, in_gpadl.gpns()[0])?;
        if let Some(guest_to_host_event) = guest_to_host_event {
            if ring::reader_needs_signal(control_page.pages()[0]) {
                tracelimit::info_ratelimited!(channel = %channel.key, "waking host for incoming ring");
                guest_to_host_event.0.deliver();
            }
        }

        // N.B. The host-side check above deliberately runs before this fallible
        //      conversion so a bad ring size can't suppress the host wake.
        let ring_size = gpadl_ring_size(&in_gpadl).try_into()?;
        if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
            tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for incoming ring");
            host_to_guest_interrupt.deliver();
        }
        Ok(())
    }
1480
    /// Signals whichever side(s) of the channel's outgoing ring the control
    /// page says are waiting: the guest if the reader needs a signal, the host
    /// if the writer does (and `guest_to_host_event` is provided).
    fn unstick_outgoing_ring(
        gm: &GuestMemory,
        channel: &Channel,
        out_gpadl: AlignedGpadlView,
        guest_to_host_event: Option<&ChannelEvent>,
        host_to_guest_interrupt: &Interrupt,
    ) -> anyhow::Result<()> {
        // The ring's control page is the first GPN of its GPADL.
        let control_page = lock_gpn_with_subrange(gm, out_gpadl.gpns()[0])?;
        if ring::reader_needs_signal(control_page.pages()[0]) {
            tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for outgoing ring");
            host_to_guest_interrupt.deliver();
        }

        if let Some(guest_to_host_event) = guest_to_host_event {
            // N.B. The guest-side check above deliberately runs before this
            //      fallible conversion so a bad ring size can't suppress the
            //      guest wake.
            let ring_size = gpadl_ring_size(&out_gpadl).try_into()?;
            if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
                tracelimit::info_ratelimited!(channel = %channel.key, "waking host for outgoing ring");
                guest_to_host_event.0.deliver();
            }
        }
        Ok(())
    }
1503}
1504
1505impl Notifier for ServerTaskInner {
1506    fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
1507        let channel = self
1508            .channels
1509            .get_mut(&offer_id)
1510            .expect("channel does not exist");
1511
1512        fn handle<I: 'static + Send, R: 'static + Send>(
1513            offer_id: OfferId,
1514            channel: &Channel,
1515            req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
1516            input: I,
1517            f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
1518        ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
1519        {
1520            let recv = channel.send.call(req, input);
1521            let seq = channel.seq;
1522            Box::pin(async move {
1523                let r = recv.await.map(f);
1524                (offer_id, seq, r)
1525            })
1526        }
1527
1528        let response = match action {
1529            channels::Action::Open(open_params, version) => {
1530                let seq = channel.seq;
1531                let key = channel.key;
1532                match self.open_channel(offer_id, &open_params) {
1533                    Ok((channel, interrupt)) => handle(
1534                        offer_id,
1535                        channel,
1536                        ChannelRequest::Open,
1537                        OpenRequest::new(
1538                            open_params.open_data,
1539                            interrupt,
1540                            version.feature_flags,
1541                            channel.flags,
1542                        ),
1543                        ChannelResponse::Open,
1544                    ),
1545                    Err(err) => {
1546                        tracelimit::error_ratelimited!(
1547                            err = err.as_ref() as &dyn std::error::Error,
1548                            ?offer_id,
1549                            %key,
1550                            "could not open channel",
1551                        );
1552
1553                        // Return an error response to the channels module if the open_channel call
1554                        // failed.
1555                        Box::pin(future::ready((
1556                            offer_id,
1557                            seq,
1558                            Ok(ChannelResponse::Open(false)),
1559                        )))
1560                    }
1561                }
1562            }
1563            channels::Action::Close => {
1564                if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1565                    if let ChannelState::Open(ref state) = channel.state {
1566                        channel_bitmap.unregister_channel(state.open_params.event_flag);
1567                    }
1568                }
1569
1570                channel.state = ChannelState::Closing;
1571                handle(offer_id, channel, ChannelRequest::Close, (), |()| {
1572                    ChannelResponse::Close
1573                })
1574            }
1575            channels::Action::Gpadl(gpadl_id, count, buf) => {
1576                channel.gpadls.add(
1577                    gpadl_id,
1578                    MultiPagedRangeBuf::from_range_buffer(count.into(), buf.clone()).unwrap(),
1579                );
1580                handle(
1581                    offer_id,
1582                    channel,
1583                    ChannelRequest::Gpadl,
1584                    GpadlRequest {
1585                        id: gpadl_id,
1586                        count,
1587                        buf,
1588                    },
1589                    move |r| ChannelResponse::Gpadl(gpadl_id, r),
1590                )
1591            }
1592            channels::Action::TeardownGpadl {
1593                gpadl_id,
1594                post_restore,
1595            } => {
1596                if !post_restore {
1597                    channel.gpadls.remove(gpadl_id, Box::new(|| ()));
1598                }
1599
1600                handle(
1601                    offer_id,
1602                    channel,
1603                    ChannelRequest::TeardownGpadl,
1604                    gpadl_id,
1605                    move |()| ChannelResponse::TeardownGpadl(gpadl_id),
1606                )
1607            }
1608            channels::Action::Modify { target_vp } => {
1609                let ChannelState::Open(state) = &mut channel.state else {
1610                    unreachable!();
1611                };
1612
1613                if let Err(err) = state.set_event_port_target_vp(target_vp) {
1614                    tracelimit::error_ratelimited!(
1615                        error = err.as_ref() as &dyn std::error::Error,
1616                        channel = %channel.key,
1617                        "could not modify channel",
1618                    );
1619
1620                    // Send an immediate error response.
1621                    let seq = channel.seq;
1622                    Box::pin(async move {
1623                        (
1624                            offer_id,
1625                            seq,
1626                            Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
1627                        )
1628                    })
1629                } else {
1630                    handle(
1631                        offer_id,
1632                        channel,
1633                        ChannelRequest::Modify,
1634                        ModifyRequest::TargetVp { target_vp },
1635                        ChannelResponse::Modify,
1636                    )
1637                }
1638            }
1639        };
1640        self.channel_responses.push(response);
1641    }
1642
    /// Applies a connection modification request: maps/unmaps the interrupt
    /// page, updates the monitor pages, retargets the message port, and
    /// optionally forwards the request to the relay.
    ///
    /// Returns an error if the interrupt page cannot be mapped, the monitor
    /// page cannot be set, or the message port cannot be retargeted.
    fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
        self.map_interrupt_page(request.interrupt_page)
            .context("Failed to map interrupt page.")?;

        // N.B. This may clear `request.monitor_page` so that the monitor pages
        //      are not forwarded to the relay when this server handles MNF
        //      itself, so it must run before the relay send below.
        self.set_monitor_page(&mut request)
            .context("Failed to map monitor page.")?;

        if let Some(vp) = request.target_message_vp {
            self.message_port.set_target_vp(vp)?;
        }

        if request.notify_relay {
            self.relay_send.send(request.into());
        }

        Ok(())
    }
1660
1661    fn forward_unhandled(&mut self, request: InitiateContactRequest) {
1662        if let Some(external_server) = &self.external_server_send {
1663            external_server.send(request);
1664        } else {
1665            tracing::warn!(?request, "nowhere to forward unhandled request")
1666        }
1667    }
1668
1669    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
1670        let channel = self.channels.get(&offer_id).expect("should exist");
1671        let mut resp = req.respond();
1672        if let ChannelState::Open(state) = &channel.state {
1673            let mem = self.get_gm_for_channel(version.expect("must be connected"), channel);
1674            inspect_rings(
1675                &mut resp,
1676                mem,
1677                channel.gpadls.clone(),
1678                &state.open_params.open_data,
1679            );
1680        }
1681    }
1682
    /// Attempts to post `message` to the guest via the synic message port
    /// selected by `target`.
    ///
    /// Returns `true` if the message was consumed (posted, or permanently
    /// dropped because no port could be created), and `false` if the caller
    /// should queue the message and retry later.
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
        // If the server is paused, queue all messages, to avoid affecting synic
        // state during/after it has been saved or reset.
        //
        // Note that messages to reserved channels or custom targets will be
        // dropped. However, such messages should only be sent in response to
        // guest requests, which should not be processed while the server is
        // paused.
        //
        // FUTURE: it would be better to ensure that no messages are generated
        // by operations that run while the server is paused. E.g., defer
        // sending offer or revoke messages for new or revoked offers. This
        // would prevent the queue from growing without bound.
        if !self.running && !self.send_messages_while_stopped {
            if !matches!(target, MessageTarget::Default) {
                tracelimit::error_ratelimited!(?target, "dropping message while paused");
            }
            return false;
        }

        // Storage for a temporary port created for custom targets; declared
        // here so that the borrow in `port` can outlive the match.
        let mut port_storage;
        let port = match target {
            MessageTarget::Default => self.message_port.as_mut(),
            MessageTarget::ReservedChannel(offer_id, target) => {
                if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
                    port.as_mut()
                } else {
                    // Updating the port failed, so there is no way to send the message.
                    return true;
                }
            }
            MessageTarget::Custom(target) => {
                // Create a one-shot message port for the requested vp/sint.
                port_storage = match self.synic.new_guest_message_port(
                    self.redirect_vtl,
                    target.vp,
                    target.sint,
                ) {
                    Ok(port) => port,
                    Err(err) => {
                        tracing::error!(
                            ?err,
                            ?self.redirect_vtl,
                            ?target,
                            "could not create message port"
                        );

                        // There is no way to send the message.
                        return true;
                    }
                };
                port_storage.as_mut()
            }
        };

        // If this returns Pending, the channels module will queue the message and the ServerTask
        // main loop will try to send it again later.
        //
        // A noop waker is used since no wakeup is needed: the main loop
        // retries queued messages on its own schedule.
        matches!(
            port.poll_post_message(
                &mut std::task::Context::from_waker(std::task::Waker::noop()),
                VMBUS_MESSAGE_TYPE,
                message.data()
            ),
            Poll::Ready(())
        )
    }
1748
    /// Forwards an hvsocket connect request to the hvsock handler.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
        tracing::debug!(?request, "received hvsock connect request");
        // Track outstanding requests; presumably decremented when the
        // response arrives — confirm against the response path.
        self.hvsock_requests += 1;
        self.hvsock_send.send(*request);
    }
1754
1755    fn reset_complete(&mut self) {
1756        if let Some(monitor) = self.synic.monitor_support() {
1757            if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
1758                tracing::warn!(?err, "resetting monitor page failed")
1759            }
1760        }
1761
1762        self.unreserve_channels();
1763        for done in self.reset_done.drain(..) {
1764            done.complete(());
1765        }
1766    }
1767
    /// Completes an unload by releasing reserved-channel message ports for
    /// closed channels.
    fn unload_complete(&mut self) {
        self.unreserve_channels();
    }
1771}
1772
impl ServerTaskInner {
    /// Transitions a channel to the open state, creating the synic event ports
    /// (and, for reserved channels, a message port) needed for signaling.
    ///
    /// Returns the channel and the host-to-guest interrupt the device should
    /// use to signal the guest.
    fn open_channel(
        &mut self,
        offer_id: OfferId,
        open_params: &OpenParams,
    ) -> anyhow::Result<(&mut Channel, Interrupt)> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        // Always register with the channel bitmap; if Win7, this may be unnecessary.
        if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
            channel_bitmap.register_channel(
                open_params.event_flag,
                channel.guest_to_host_event.0.clone(),
            );
        }
        // Always set up an event port; if V1, this will be unused.
        // N.B. The event port must be created before the device is notified of the open by the
        //      caller. The device may begin communicating with the guest immediately when it is
        //      notified, so the event port must exist so that the guest can send interrupts.
        let event_port = self
            .synic
            .add_event_port(
                open_params.connection_id,
                self.vtl,
                channel.guest_to_host_event.clone(),
                open_params.monitor_info,
            )
            .context("failed to create guest-to-host event port")?;

        // For pre-Win8 guests, the host-to-guest event always targets vp 0 and the channel
        // bitmap is used instead of the event flag.
        let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
            (Some(0), 0)
        } else {
            (open_params.open_data.target_vp, open_params.event_flag)
        };

        let (guest_event_port, interrupt) = if let Some(target_vp) = target_vp {
            // Interrupts may be redirected to a different VTL/SINT than the
            // one used for messages.
            let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
                (self.redirect_vtl, self.redirect_sint)
            } else {
                (self.vtl, VMBUS_SINT)
            };

            // NOTE(review): the port ID is computed with VMBUS_SINT even when
            // the interrupt is redirected to a different SINT — presumably the
            // ID namespace is fixed regardless of redirection; confirm against
            // get_child_event_port_id.
            let guest_event_port = self.synic.new_guest_event_port(
                VmbusServer::get_child_event_port_id(open_params.channel_id, VMBUS_SINT, self.vtl),
                target_vtl,
                target_vp,
                target_sint,
                event_flag,
                open_params.monitor_info,
            )?;

            // Wrap the raw interrupt so that, for pre-Win8 guests, the
            // channel bitmap is also updated on each interrupt.
            let interrupt = ChannelBitmap::create_interrupt(
                &self.channel_bitmap,
                guest_event_port.interrupt(),
                open_params.event_flag,
            );

            (Some(guest_event_port), interrupt)
        } else {
            // Use a dummy interrupt which does nothing, but make sure it has an event to avoid
            // proxy_integration from trying to wrap it.
            (None, Interrupt::null_event())
        };

        // Delete any previously reserved state.
        channel.reserved_state.message_port = None;

        // If the channel is reserved, create a message port for it.
        if let Some(target) = open_params.reserved_target {
            channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
                self.redirect_vtl,
                target.vp,
                target.sint,
            )?);

            channel.reserved_state.target = target;
        }

        channel.state = ChannelState::Open(Box::new(ChannelOpenState {
            open_params: *open_params,
            _event_port: event_port,
            guest_event_port,
            host_to_guest_interrupt: interrupt.clone(),
        }));
        Ok((channel, interrupt))
    }

    /// If the client specified an interrupt page, map it into host memory and
    /// set up the shared event port.
    ///
    /// An `Update::Reset` tears down the channel bitmap and shared event port;
    /// `Update::Unchanged` is a no-op. Fails if the page is not page aligned
    /// or cannot be locked.
    fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
        let interrupt_page = match interrupt_page {
            Update::Unchanged => return Ok(()),
            Update::Reset => {
                self.channel_bitmap = None;
                self.shared_event_port = None;
                return Ok(());
            }
            Update::Set(interrupt_page) => interrupt_page,
        };

        assert_ne!(interrupt_page, 0);

        if interrupt_page % PAGE_SIZE as u64 != 0 {
            anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
        }

        // Use a subrange to access the interrupt page to give GuestMemory's without a full mapping
        // a chance to create one.
        let interrupt_page = lock_page_with_subrange(&self.gm, interrupt_page)?;
        let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
        self.channel_bitmap = Some(channel_bitmap.clone());

        // Create the shared event port for pre-Win8 guests.
        let interrupt = Interrupt::from_fn(move || {
            channel_bitmap.handle_shared_interrupt();
        });

        self.shared_event_port = Some(self.synic.add_event_port(
            SHARED_EVENT_CONNECTION_ID,
            self.vtl,
            Arc::new(ChannelEvent(interrupt)),
            None,
        )?);

        Ok(())
    }

    /// Applies the monitor page portion of a modify-connection request.
    ///
    /// When this server handles MNF itself, the pages are applied (or
    /// server-allocated) locally and the update is removed from `request` so
    /// it is not forwarded to the relay.
    fn set_monitor_page(&mut self, request: &mut ModifyConnectionRequest) -> anyhow::Result<()> {
        let monitor_page = match request.monitor_page {
            Update::Unchanged => return Ok(()),
            Update::Reset => None,
            Update::Set(value) => Some(value),
        };

        // Changing the monitor page while in use is not allowed.
        // TODO: can this check be moved into channels.rs?
        if self.channels.iter().any(|(_, c)| {
            matches!(
                &c.state,
                ChannelState::Open(state) if state.open_params.monitor_info.is_some()
            )
        }) {
            anyhow::bail!("attempt to change monitor page while open channels using mnf");
        }

        // Check if the server is handling MNF.
        // N.B. If the server is not handling MNF, there is currently no way to request
        //      server-allocated monitor pages from the relay host.
        if let Some(mnf_support) = self.mnf_support.as_mut() {
            if let Some(monitor) = self.synic.monitor_support() {
                mnf_support.allocated_monitor_page = None;

                if let Some(version) = request.version {
                    if version.feature_flags.server_specified_monitor_pages() {
                        if let Some(monitor_page) = monitor.allocate_monitor_page(self.vtl)? {
                            tracelimit::info_ratelimited!(
                                ?monitor_page,
                                "using server-allocated monitor pages"
                            );
                            mnf_support.allocated_monitor_page = Some(monitor_page);
                        }
                    }
                }

                // If no monitor page was allocated above, use the one provided by the client.
                if mnf_support.allocated_monitor_page.is_none() {
                    if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
                        anyhow::bail!(
                            "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
                        );
                    }
                }
            }

            // If MNF is configured to be handled by this server (even if it's not actually
            // supported by the synic), don't forward the pages to the relay.
            request.monitor_page = Update::Unchanged;
        }

        Ok(())
    }

    /// Returns the message port for a reserved channel, recreating or
    /// retargeting it if the guest requested a different vp/SINT.
    ///
    /// Returns `None` if a new port was needed but could not be created.
    /// Panics if the channel does not exist or is not reserved.
    fn get_reserved_channel_message_port(
        &mut self,
        offer_id: OfferId,
        new_target: ConnectionTarget,
    ) -> Option<&mut Box<dyn GuestMessagePort>> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        assert!(
            channel.reserved_state.message_port.is_some(),
            "channel is not reserved"
        );

        // On close, the guest may have changed the message target it wants to use for the close
        // response. If so, update the message port.
        if channel.reserved_state.target.sint != new_target.sint {
            // Destroy the old port before creating the new one.
            channel.reserved_state.message_port = None;
            let message_port = self
                .synic
                .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
                .inspect_err(|err| {
                    tracing::error!(
                        key = %channel.key,
                        ?err,
                        ?self.redirect_vtl,
                        ?new_target,
                        "could not create reserved channel message port"
                    )
                })
                .ok()?;

            channel.reserved_state.message_port = Some(message_port);
            channel.reserved_state.target = new_target;
        } else if channel.reserved_state.target.vp != new_target.vp {
            let message_port = channel.reserved_state.message_port.as_mut().unwrap();

            // The vp has changed, but the SINT is the same. Just update the vp. If this fails,
            // ignore it and just send to the old vp.
            if let Err(err) = message_port.set_target_vp(new_target.vp) {
                tracing::error!(
                    key = %channel.key,
                    ?err,
                    ?self.redirect_vtl,
                    ?new_target,
                    "could not update reserved channel message port"
                );
            }

            channel.reserved_state.target = new_target;
            return Some(message_port);
        }

        Some(channel.reserved_state.message_port.as_mut().unwrap())
    }

    /// Releases the reserved message port of every closed channel.
    fn unreserve_channels(&mut self) {
        // Unreserve all closed channels.
        for channel in self.channels.values_mut() {
            if let ChannelState::Closed = channel.state {
                channel.reserved_state.message_port = None;
            }
        }
    }

    /// Returns the guest memory to use for the channel's ring buffers:
    /// private memory for confidential channels (when available), shared
    /// memory otherwise.
    fn get_gm_for_channel(&self, version: VersionInfo, channel: &Channel) -> &GuestMemory {
        select_gm_for_channel(&self.gm, self.private_gm.as_ref(), version, channel)
    }
}
2030
2031fn select_gm_for_channel<'a>(
2032    gm: &'a GuestMemory,
2033    private_gm: Option<&'a GuestMemory>,
2034    version: VersionInfo,
2035    channel: &Channel,
2036) -> &'a GuestMemory {
2037    if channel.flags.confidential_ring_buffer() && version.feature_flags.confidential_channels() {
2038        if let Some(private_gm) = private_gm {
2039            return private_gm;
2040        }
2041    }
2042
2043    gm
2044}
2045
/// Control point for [`VmbusServer`], allowing callers to offer channels.
#[derive(Clone)]
pub struct VmbusServerControl {
    /// Guest memory handed out with each offer's [`OfferResources`].
    mem: GuestMemory,
    /// Memory handed out for confidential channels, if available.
    private_mem: Option<GuestMemory>,
    /// Channel for sending offer/reset requests to the server task.
    send: mesh::Sender<OfferRequest>,
    /// Returned by [`ParentBus::use_event`]; presumably controls whether
    /// channels signal via OS events — confirm against the trait's docs.
    use_event: bool,
    /// When set, every offer has its confidential external memory flag forced
    /// on before being sent to the server.
    force_confidential_external_memory: bool,
}
2055
2056impl VmbusServerControl {
2057    /// Offers a channel to the vmbus server, where the flags and user_defined data are already set.
2058    /// This is used by the relay to forward the host's parameters.
2059    pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
2060        let flags = offer_info.params.flags;
2061        self.send
2062            .call_failable(OfferRequest::Offer, offer_info)
2063            .await?;
2064        Ok(OfferResources::new(
2065            self.mem.clone(),
2066            if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
2067                self.private_mem.clone()
2068            } else {
2069                None
2070            },
2071        ))
2072    }
2073
2074    /// Force reset all channels and protocol state, without requiring the
2075    /// server to be paused.
2076    pub async fn force_reset(&self) -> anyhow::Result<()> {
2077        self.send
2078            .call(OfferRequest::ForceReset, ())
2079            .await
2080            .context("vmbus server is gone")
2081    }
2082
2083    async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2084        let mut offer_info = OfferInfo {
2085            params: request.params.into(),
2086            event: request.event,
2087            request_send: request.request_send,
2088            server_request_recv: request.server_request_recv,
2089        };
2090
2091        if self.force_confidential_external_memory {
2092            tracing::warn!(
2093                key = %offer_info.params.key(),
2094                "forcing confidential external memory for channel"
2095            );
2096
2097            offer_info
2098                .params
2099                .flags
2100                .set_confidential_external_memory(true);
2101        }
2102
2103        self.offer_core(offer_info).await
2104    }
2105}
2106
2107/// Inspects the specified ring buffer state by directly accessing guest memory.
2108fn inspect_rings(
2109    resp: &mut inspect::Response<'_>,
2110    gm: &GuestMemory,
2111    gpadl_map: Arc<GpadlMap>,
2112    open_data: &OpenData,
2113) -> Option<()> {
2114    let gpadl = gpadl_map
2115        .view()
2116        .map(GpadlId(open_data.ring_gpadl_id.0))
2117        .ok()?;
2118
2119    let aligned = AlignedGpadlView::new(gpadl).ok()?;
2120    let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
2121    resp.child("incoming_ring", |req| inspect_ring(req, &in_gpadl, gm));
2122    resp.child("outgoing_ring", |req| inspect_ring(req, &out_gpadl, gm));
2123    Some(())
2124}
2125
2126/// Inspects the incoming or outgoing ring buffer by directly accessing guest memory.
2127fn inspect_ring(req: inspect::Request<'_>, gpadl: &AlignedGpadlView, gm: &GuestMemory) {
2128    let mut resp = req.respond();
2129
2130    resp.hex("ring_size", gpadl_ring_size(gpadl));
2131
2132    // Lock just the control page. Use a subrange to allow a GuestMemory without a full mapping to
2133    // create one.
2134    if let Ok(pages) = lock_gpn_with_subrange(gm, gpadl.gpns()[0]) {
2135        ring::inspect_ring(pages.pages()[0], &mut resp);
2136    }
2137}
2138
2139fn gpadl_ring_size(gpadl: &AlignedGpadlView) -> usize {
2140    // Data size excluding the control page.
2141    (gpadl.gpns().len() - 1) * PAGE_SIZE
2142}
2143
2144/// Helper to create a subrange before locking a single page.
2145///
2146/// This allows us to lock a page in a `GuestMemory` that doesn't have a full mapping, but can
2147/// create one for a subrange.
2148fn lock_page_with_subrange(gm: &GuestMemory, offset: u64) -> anyhow::Result<guestmem::LockedPages> {
2149    Ok(gm
2150        .lockable_subrange(offset, PAGE_SIZE as u64)?
2151        .lock_gpns(false, &[0])?)
2152}
2153
2154/// Helper to create a subrange before locking a single page from a gpn.
2155///
2156/// This allows us to lock a page in a `GuestMemory` that doesn't have a full mapping, but can
2157/// create one for a subrange.
2158fn lock_gpn_with_subrange(gm: &GuestMemory, gpn: u64) -> anyhow::Result<guestmem::LockedPages> {
2159    lock_page_with_subrange(gm, gpn * PAGE_SIZE as u64)
2160}
2161
pub(crate) struct MessageSender {
    /// Channel used to forward incoming synic messages to the server.
    send: mpsc::Sender<SynicMessage>,
    /// Whether this sender serves the multiclient message port; copied into
    /// every forwarded [`SynicMessage`].
    multiclient: bool,
}
2166
2167impl MessageSender {
2168    fn poll_handle_message(
2169        &self,
2170        cx: &mut std::task::Context<'_>,
2171        msg: &[u8],
2172        trusted: bool,
2173    ) -> Poll<Result<(), SendError>> {
2174        let mut send = self.send.clone();
2175        ready!(send.poll_ready(cx))?;
2176        send.start_send(SynicMessage {
2177            data: msg.to_vec(),
2178            multiclient: self.multiclient,
2179            trusted,
2180        })?;
2181
2182        Poll::Ready(Ok(()))
2183    }
2184}
2185
2186impl MessagePort for MessageSender {
2187    fn poll_handle_message(
2188        &self,
2189        cx: &mut std::task::Context<'_>,
2190        msg: &[u8],
2191        trusted: bool,
2192    ) -> Poll<()> {
2193        if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
2194            tracelimit::error_ratelimited!(
2195                error = &err as &dyn std::error::Error,
2196                "failed to send message"
2197            );
2198        }
2199
2200        Poll::Ready(())
2201    }
2202}
2203
#[async_trait]
impl ParentBus for VmbusServerControl {
    /// Offers a new channel by delegating to [`VmbusServerControl::offer`].
    async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
        self.offer(request).await
    }

    fn clone_bus(&self) -> Box<dyn ParentBus> {
        Box::new(self.clone())
    }

    fn use_event(&self) -> bool {
        self.use_event
    }
}