vmbus_server/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8pub mod channels;
9pub mod event;
10pub mod hvsock;
11mod monitor;
12mod proxyintegration;
13#[cfg(test)]
14mod tests;
15
/// The GUID type used for vmbus channel identifiers (re-export of [`guid::Guid`]).
pub type Guid = guid::Guid;
18
19use anyhow::Context;
20use async_trait::async_trait;
21use channel_bitmap::ChannelBitmap;
22use channels::ConnectionTarget;
23pub use channels::InitiateContactRequest;
24use channels::MessageTarget;
25pub use channels::MnfUsage;
26use channels::ModifyConnectionRequest;
27use channels::ModifyConnectionResponse;
28use channels::Notifier;
29use channels::OfferId;
30pub use channels::OfferParamsInternal;
31use channels::OpenParams;
32use channels::RestoreError;
33pub use channels::Update;
34use futures::FutureExt;
35use futures::StreamExt;
36use futures::channel::mpsc;
37use futures::channel::mpsc::SendError;
38use futures::future::OptionFuture;
39use futures::future::poll_fn;
40use futures::stream::SelectAll;
41use guestmem::GuestMemory;
42use hvdef::Vtl;
43use inspect::Inspect;
44use mesh::payload::Protobuf;
45use mesh::rpc::FailableRpc;
46use mesh::rpc::Rpc;
47use mesh::rpc::RpcError;
48use mesh::rpc::RpcSend;
49use pal_async::driver::Driver;
50use pal_async::driver::SpawnDriver;
51use pal_async::task::Task;
52use pal_async::timer::PolledTimer;
53use pal_event::Event;
54#[cfg(windows)]
55pub use proxyintegration::ProxyIntegration;
56#[cfg(windows)]
57pub use proxyintegration::ProxyServerInfo;
58use ring::PAGE_SIZE;
59use std::collections::HashMap;
60use std::future;
61use std::future::Future;
62use std::pin::Pin;
63use std::sync::Arc;
64use std::task::Poll;
65use std::task::ready;
66use std::time::Duration;
67use unicycle::FuturesUnordered;
68use vmbus_channel::bus::ChannelRequest;
69use vmbus_channel::bus::ChannelServerRequest;
70use vmbus_channel::bus::GpadlRequest;
71use vmbus_channel::bus::ModifyRequest;
72use vmbus_channel::bus::OfferInput;
73use vmbus_channel::bus::OfferKey;
74use vmbus_channel::bus::OfferResources;
75use vmbus_channel::bus::OpenData;
76use vmbus_channel::bus::OpenRequest;
77use vmbus_channel::bus::ParentBus;
78use vmbus_channel::bus::RestoreResult;
79use vmbus_channel::gpadl::GpadlMap;
80use vmbus_channel::gpadl_ring::AlignedGpadlView;
81use vmbus_core::HvsockConnectRequest;
82use vmbus_core::HvsockConnectResult;
83use vmbus_core::MaxVersionInfo;
84use vmbus_core::OutgoingMessage;
85use vmbus_core::TaggedStream;
86use vmbus_core::VMBUS_SINT;
87use vmbus_core::VersionInfo;
88use vmbus_core::protocol;
89pub use vmbus_core::protocol::GpadlId;
90#[cfg(windows)]
91use vmbus_proxy::ProxyHandle;
92use vmbus_ring as ring;
93use vmbus_ring::gparange::MultiPagedRangeBuf;
94use vmcore::interrupt::Interrupt;
95use vmcore::save_restore::SavedStateRoot;
96use vmcore::synic::EventPort;
97use vmcore::synic::GuestEventPort;
98use vmcore::synic::GuestMessagePort;
99use vmcore::synic::MessagePort;
100use vmcore::synic::MonitorPageGpas;
101use vmcore::synic::SynicPortAccess;
102
/// The SINT used for the redirected (control-plane) message connection; see
/// `VmbusServerBuilder::build`, which selects this when message redirection is enabled.
pub const REDIRECT_SINT: u8 = 7;
/// The VTL used for the redirected message connection.
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
// Fixed connection/port IDs for the shared event port (used when channels share
// an event flag bitmap rather than dedicated event ports).
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
const EVENT_PORT_ID: u32 = 2;
// NOTE(review): presumably the synic message type value for vmbus protocol
// messages — not referenced within this portion of the file; confirm at use site.
const VMBUS_MESSAGE_TYPE: u32 = 1;

// Cap on outstanding hvsock connect requests (tracked in
// `ServerTaskInner::hvsock_requests`).
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
110
/// A vmbus server instance: owns the server task, its synic message ports, and
/// the control object used to offer channels.
#[derive(Inspect)]
pub struct VmbusServer {
    // Channel to the server task; inspect requests are forwarded to the task as
    // `VmbusRequest::Inspect`.
    #[inspect(flatten, send = "VmbusRequest::Inspect")]
    task_send: mesh::Sender<VmbusRequest>,
    // Shared control object handed out by `control()`, used to offer channels.
    #[inspect(skip)]
    control: Arc<VmbusServerControl>,
    // Keeps the synic message port registration alive for the server's lifetime.
    #[inspect(skip)]
    _message_port: Box<dyn Sync + Send>,
    // Multiclient message port registration (only present for VTL0 without
    // message redirection; see `VmbusServerBuilder::build`).
    #[inspect(skip)]
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    // The running server task; joined by `shutdown()`.
    #[inspect(skip)]
    task: Task<ServerTask>,
}
124
/// Builder for [`VmbusServer`]. Each field corresponds to a setter method on
/// the builder; see those methods for detailed semantics.
pub struct VmbusServerBuilder<T: SpawnDriver> {
    spawner: T,
    synic: Arc<dyn SynicPortAccess>,
    gm: GuestMemory,
    // Separate guest memory for confidential (non-relay) channels, if any.
    private_gm: Option<GuestMemory>,
    // The VTL this server instance serves.
    vtl: Vtl,
    hvsock_notify: Option<HvsockServerChannelHalf>,
    server_relay: Option<VmbusServerChannelHalf>,
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    use_message_redirect: bool,
    // Offset added when generating channel IDs (0 or 1024).
    channel_id_offset: u16,
    max_version: Option<MaxVersionInfo>,
    delay_max_version: bool,
    enable_mnf: bool,
    force_confidential_external_memory: bool,
    send_messages_while_stopped: bool,
    // Delay before "unsticking" a newly opened channel; `None` disables it.
    channel_unstick_delay: Option<Duration>,
    use_absolute_channel_order: bool,
}
146
#[derive(mesh::MeshPayload)]
/// The request to send to the proxy to set or clear its saved state cache.
pub enum SavedStateRequest {
    /// Set the proxy's cached saved state; fails if the proxy rejects it.
    Set(FailableRpc<Box<channels::SavedState>, ()>),
    /// Clear the proxy's cached saved state.
    Clear(Rpc<(), ()>),
}
153
/// The server side of the connection between a vmbus server and a relay:
/// sends requests to the relay and receives its responses.
pub struct ServerChannelHalf<Request, Response> {
    request_send: mesh::Sender<Request>,
    response_receive: mesh::Receiver<Response>,
}
159
/// The relay side of a connection between a vmbus server and a relay:
/// receives the server's requests and sends back responses.
pub struct RelayChannelHalf<Request, Response> {
    pub request_receive: mesh::Receiver<Request>,
    pub response_send: mesh::Sender<Response>,
}
165
/// A connection between a vmbus server and a relay, bundling both halves so
/// they can be created together (see [`RelayChannel::new`]).
pub struct RelayChannel<Request, Response> {
    pub relay_half: RelayChannelHalf<Request, Response>,
    pub server_half: ServerChannelHalf<Request, Response>,
}
171
172impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
173    /// Creates a new channel between the vmbus server and a relay.
174    pub fn new() -> Self {
175        let (request_send, request_receive) = mesh::channel();
176        let (response_send, response_receive) = mesh::channel();
177        Self {
178            relay_half: RelayChannelHalf {
179                request_receive,
180                response_send,
181            },
182            server_half: ServerChannelHalf {
183                request_send,
184                response_receive,
185            },
186        }
187    }
188}
189
/// Server half of the relay connection used for connection-state modification.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Relay half of the connection used for connection-state modification.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Full relay channel for connection-state modification requests.
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyRelayResponse>;
/// Server half of the channel used to forward hvsock connect requests.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay half of the channel used to handle hvsock connect requests.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Full relay channel for hvsock connect requests.
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
196
/// A request from the server to the relay to modify connection state.
///
/// The version and use_interrupt_page fields can only be present if this request was sent for an
/// InitiateContact message from the guest.
///
/// If `version` is `Some`, the relay must respond with either `ModifyRelayResponse::Supported` or
/// `ModifyRelayResponse::Unsupported`. If `version` is `None`, the relay must respond with
/// `ModifyRelayResponse::Modified`.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// Requested protocol version number, if this is a version change.
    pub version: Option<u32>,
    /// Requested change to the monitor page GPAs.
    pub monitor_page: Update<MonitorPageGpas>,
    /// Whether to use the interrupt page; `None` means unchanged.
    pub use_interrupt_page: Option<bool>,
}
211
/// A response from the relay to a ModifyRelayRequest from the server.
#[derive(Debug, Copy, Clone)]
pub enum ModifyRelayResponse {
    /// The requested version change is supported, and the relay completed the connection
    /// modification with the specified status. All of the feature flags supported by the relay host
    /// are included, regardless of what features were requested.
    Supported(protocol::ConnectionState, protocol::FeatureFlags),
    /// A version change was requested but the relay host doesn't support that version. This
    /// response cannot be returned for a request with no version change set.
    Unsupported,
    /// The connection modification completed with the specified status. This response must be sent
    /// if no version change was requested.
    Modified(protocol::ConnectionState),
}
226
227impl From<ModifyConnectionRequest> for ModifyRelayRequest {
228    fn from(value: ModifyConnectionRequest) -> Self {
229        Self {
230            version: value.version.map(|v| v.version as u32),
231            monitor_page: value.monitor_page,
232            use_interrupt_page: match value.interrupt_page {
233                Update::Unchanged => None,
234                Update::Reset => Some(false),
235                Update::Set(_) => Some(true),
236            },
237        }
238    }
239}
240
/// Requests handled by the server task (sent via `VmbusServer::task_send`).
#[derive(Debug)]
enum VmbusRequest {
    /// Reset the vmbus channel state; completes when the reset is done.
    Reset(Rpc<(), ()>),
    /// Deferred inspection request.
    Inspect(inspect::Deferred),
    /// Save the server state.
    Save(Rpc<(), SavedState>),
    /// Restore the server from a saved state.
    Restore(Rpc<Box<SavedState>, Result<(), RestoreError>>),
    /// Start the control plane.
    Start,
    /// Stop the control plane; completes when stopped.
    Stop(Rpc<(), ()>),
}
250
/// Everything needed to offer a channel to the guest.
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    /// The offer's parameters.
    pub params: OfferParamsInternal,
    /// Guest-to-host interrupt delivered when the guest signals the channel.
    pub event: Interrupt,
    /// Sender for requests to the channel's owner (open/close/gpadl/etc.).
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receiver for requests from the channel's owner to the server.
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
258
/// Requests sent from `VmbusServerControl` to the server task.
#[expect(clippy::large_enum_variant)]
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offer a new channel; fails if the offer is rejected.
    Offer(FailableRpc<OfferInfo, ()>),
    /// Force a reset of the server state.
    ForceReset(Rpc<(), ()>),
}
265
266struct ChannelEvent(Interrupt);
267
impl EventPort for ChannelEvent {
    /// Delivers the wrapped interrupt when the guest signals the event port.
    /// The flag value is unused; there is one interrupt per port.
    fn handle_event(&self, _flag: u16) {
        self.0.deliver();
    }

    /// Returns the backing OS event, if the interrupt has one, so the synic
    /// can signal it directly.
    fn os_event(&self) -> Option<&Event> {
        self.0.event()
    }
}
277
/// The saved state of the vmbus server, for save/restore across migrations or
/// servicing operations.
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    /// Saved state of the channels subsystem.
    #[mesh(1)]
    pub server: channels::SavedState,
    // Indicates if the lost synic bug is fixed or not. By default it's false.
    // During the restore process, we check if the field is not true then
    // unstick_channels() function will be called to mitigate the issue.
    #[mesh(2)]
    pub lost_synic_bug_fixed: bool,
}
289
// Standard message connection ID, used for a VTL0 server without message
// redirection; other configurations derive a per-server ID from
// MULTICLIENT_MESSAGE_CONNECTION_ID (see `get_child_message_connection_id`).
const MESSAGE_CONNECTION_ID: u32 = 1;
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
292
impl<T: SpawnDriver + Clone> VmbusServerBuilder<T> {
    /// Creates a new builder for `VmbusServer` with the default options.
    pub fn new(spawner: T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
        Self {
            spawner,
            synic,
            gm,
            private_gm: None,
            vtl: Vtl::Vtl0,
            hvsock_notify: None,
            server_relay: None,
            saved_state_notify: None,
            external_server: None,
            external_requests: None,
            use_message_redirect: false,
            channel_id_offset: 0,
            max_version: None,
            delay_max_version: false,
            enable_mnf: false,
            force_confidential_external_memory: false,
            send_messages_while_stopped: false,
            // Default unstick delay; see `channel_unstick_delay` for details.
            channel_unstick_delay: Some(Duration::from_millis(100)),
            use_absolute_channel_order: false,
        }
    }

    /// Sets a separate guest memory instance to use for channels that are confidential (non-relay
    /// channels in Underhill on a hardware isolated VM). This is not relevant for a non-Underhill
    /// VmBus server.
    pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
        self.private_gm = private_gm;
        self
    }

    /// Sets the VTL that this instance will serve.
    pub fn vtl(mut self, vtl: Vtl) -> Self {
        self.vtl = vtl;
        self
    }

    /// Sets a send/receive pair used to handle hvsocket requests.
    pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
        self.hvsock_notify = hvsock_notify;
        self
    }

    /// Sets a send channel used to enlighten ProxyIntegration about saved channels.
    pub fn saved_state_notify(
        mut self,
        saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    ) -> Self {
        self.saved_state_notify = saved_state_notify;
        self
    }

    /// Sets a send/receive pair that will be notified of server requests. This is used by the
    /// Underhill relay.
    pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
        self.server_relay = server_relay;
        self
    }

    /// Sets a receiver that receives requests from another server.
    pub fn external_requests(
        mut self,
        external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    ) -> Self {
        self.external_requests = external_requests;
        self
    }

    /// Sets a sender used to forward unhandled connect requests (which used a different VTL)
    /// to another server.
    pub fn external_server(
        mut self,
        external_server: Option<mesh::Sender<InitiateContactRequest>>,
    ) -> Self {
        self.external_server = external_server;
        self
    }

    /// Sets a value which indicates whether the vmbus control plane is redirected to Underhill.
    pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
        self.use_message_redirect = use_message_redirect;
        self
    }

    /// Tells the server to use an offset when generating channel IDs to void collisions with
    /// another vmbus server.
    ///
    /// N.B. This should only be used by the Underhill vmbus server.
    pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
        self.channel_id_offset = if enable { 1024 } else { 0 };
        self
    }

    /// Tells the server to limit the protocol version offered to the guest.
    ///
    /// N.B. This is used for testing older protocols without requiring a specific guest OS.
    pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
        self.max_version = max_version;
        self
    }

    /// Delay limiting the maximum version until after the first `Unload` message.
    ///
    /// N.B. This is used to enable the use of versions older than `Version::Win10` with Uefi boot,
    ///      since that's the oldest version the Uefi client supports.
    pub fn delay_max_version(mut self, delay: bool) -> Self {
        self.delay_max_version = delay;
        self
    }

    /// Enable MNF support in the server.
    ///
    /// N.B. Enabling this has no effect if the synic does not support mapping monitor pages.
    pub fn enable_mnf(mut self, enable: bool) -> Self {
        self.enable_mnf = enable;
        self
    }

    /// Force all non-relay channels to use encrypted external memory. Used for testing purposes
    /// only.
    pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
        self.force_confidential_external_memory = force;
        self
    }

    /// Send messages to the partition even while stopped, which can cause
    /// corrupted synic states across VM reset.
    ///
    /// This option is used to prevent messages from getting into the queue, for
    /// saved state compatibility with release/2411. It can be removed once that
    /// release is no longer supported.
    pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
        self.send_messages_while_stopped = send;
        self
    }

    /// Sets the delay before unsticking a vmbus channel after it has been opened.
    ///
    /// This option provides a work around for guests that ignore interrupts before they receive the
    /// OpenResult message, by triggering an interrupt after the channel has been opened.
    ///
    /// If not set, the default is 100ms. If set to `None`, no interrupt will be triggered.
    pub fn channel_unstick_delay(mut self, delay: Option<Duration>) -> Self {
        self.channel_unstick_delay = delay;
        self
    }

    /// Sets whether the channel order value provided in an offer is the primary way of ordering
    /// channels when assigning channel IDs, rather than the default behavior of ordering by
    /// interface ID first.
    pub fn use_absolute_channel_order(mut self, assign: bool) -> Self {
        self.use_absolute_channel_order = assign;
        self
    }

    /// Creates a new instance of the server.
    ///
    /// When the object is dropped, all channels will be closed and revoked
    /// automatically.
    pub fn build(self) -> anyhow::Result<VmbusServer> {
        #[expect(clippy::disallowed_methods)] // TODO
        let (message_send, message_recv) = mpsc::channel(64);
        // Sender that forwards incoming synic messages to the server task.
        let message_sender = Arc::new(MessageSender {
            send: message_send.clone(),
            multiclient: self.use_message_redirect,
        });

        // With redirection, the control plane uses the dedicated redirect
        // VTL/SINT; otherwise the server's own VTL and the standard vmbus SINT.
        let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
            (REDIRECT_VTL, REDIRECT_SINT)
        } else {
            (self.vtl, VMBUS_SINT)
        };

        // A non-redirected VTL0 server uses the standard message connection ID;
        // any other configuration derives a server-specific one.
        let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            MESSAGE_CONNECTION_ID
        } else {
            // TODO: This ID should be using the correct target VP, but that is not known until
            //       InitiateContact.
            VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
        };

        let _message_port = self
            .synic
            .add_message_port(connection_id, redirect_vtl, message_sender)
            .context("failed to create vmbus synic ports")?;

        // If this server is for VTL0, it is also responsible for the multiclient message port.
        // N.B. If control plane redirection is enabled, the redirected message port is used for
        //      multiclient and no separate multiclient port is created.
        let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            let multiclient_message_sender = Arc::new(MessageSender {
                send: message_send,
                multiclient: true,
            });

            Some(
                self.synic
                    .add_message_port(
                        MULTICLIENT_MESSAGE_CONNECTION_ID,
                        self.vtl,
                        multiclient_message_sender,
                    )
                    .context("failed to create vmbus synic ports")?,
            )
        } else {
            None
        };

        let (offer_send, offer_recv) = mesh::mpsc_channel();
        let control = Arc::new(VmbusServerControl {
            mem: self.gm.clone(),
            private_mem: self.private_gm.clone(),
            send: offer_send,
            use_event: self.synic.prefer_os_events(),
            force_confidential_external_memory: self.force_confidential_external_memory,
        });

        let mut server = channels::Server::new(
            self.vtl,
            connection_id,
            self.channel_id_offset,
            self.use_absolute_channel_order,
        );

        // If MNF is handled by this server and this is a paravisor for an isolated VM, the monitor
        // pages must be allocated by the server, not the guest, since the guest will provide shared
        // pages which can't be used in this case. If the guest doesn't support server-specified
        // monitor pages, MNF will be disabled for all channels for that connection.
        server.set_require_server_allocated_mnf(self.enable_mnf && self.private_gm.is_some());

        // If requested, limit the maximum protocol version and feature flags.
        if let Some(version) = self.max_version {
            server.set_compatibility_version(version, self.delay_max_version);
        }
        // Without a real relay, synthesize responses locally: version requests
        // succeed with all feature flags; other modifications just succeed.
        let (relay_request_send, relay_response_recv) =
            if let Some(server_relay) = self.server_relay {
                let r = server_relay.response_receive.boxed().fuse();
                (server_relay.request_send, r)
            } else {
                let (req_send, req_recv) = mesh::channel();
                let resp_recv = req_recv
                    .map(|req: ModifyRelayRequest| {
                        // Map to the correct response type for the request.
                        if req.version.is_some() {
                            ModifyRelayResponse::Supported(
                                protocol::ConnectionState::SUCCESSFUL,
                                protocol::FeatureFlags::from_bits(u32::MAX),
                            )
                        } else {
                            ModifyRelayResponse::Modified(protocol::ConnectionState::SUCCESSFUL)
                        }
                    })
                    .boxed()
                    .fuse();
                (req_send, resp_recv)
            };

        // If no hvsock notifier was specified, use a default one that always sends an error response.
        let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
            let r = hvsock_notify.response_receive.boxed().fuse();
            (hvsock_notify.request_send, r)
        } else {
            let (req_send, req_recv) = mesh::channel();
            let resp_recv = req_recv
                .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
                .boxed()
                .fuse();
            (req_send, resp_recv)
        };

        let inner = ServerTaskInner {
            running: false,
            send_messages_while_stopped: self.send_messages_while_stopped,
            gm: self.gm,
            private_gm: self.private_gm,
            vtl: self.vtl,
            redirect_vtl,
            redirect_sint,
            message_port: self
                .synic
                .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
            synic: self.synic,
            hvsock_requests: 0,
            hvsock_send,
            saved_state_notify: self.saved_state_notify,
            channels: HashMap::new(),
            channel_responses: FuturesUnordered::new(),
            relay_send: relay_request_send,
            external_server_send: self.external_server,
            channel_bitmap: None,
            shared_event_port: None,
            reset_done: Vec::new(),
            mnf_support: self.enable_mnf.then(MnfSupport::default),
        };

        let (task_send, task_recv) = mesh::channel();
        let mut server_task = ServerTask {
            driver: Box::new(self.spawner.clone()),
            server,
            task_recv,
            offer_recv,
            message_recv,
            server_request_recv: SelectAll::new(),
            inner,
            external_requests: self.external_requests,
            next_seq: 0,
            perform_post_restore_on_start: false,
            unstick_on_start: false,
            channel_unstickers: FuturesUnordered::new(),
            channel_unstick_delay: self.channel_unstick_delay,
        };

        // The task runs until `task_send` is dropped, then returns itself so
        // `shutdown()` can join it.
        let task = self.spawner.spawn("vmbus server", async move {
            server_task.run(relay_response_recv, hvsock_recv).await;
            server_task
        });

        Ok(VmbusServer {
            task_send,
            control,
            _message_port,
            _multiclient_message_port,
            task,
        })
    }
}
624
impl VmbusServer {
    /// Creates a new builder for `VmbusServer` with the default options.
    pub fn builder<T: SpawnDriver + Clone>(
        spawner: T,
        synic: Arc<dyn SynicPortAccess>,
        gm: GuestMemory,
    ) -> VmbusServerBuilder<T> {
        VmbusServerBuilder::new(spawner, synic, gm)
    }

    /// Saves the server's state.
    ///
    /// Panics if the server task has already exited.
    pub async fn save(&self) -> SavedState {
        self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
    }

    /// Restores the server from a previously saved state.
    pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
        self.task_send
            .call(VmbusRequest::Restore, Box::new(state))
            .await
            .unwrap()
    }

    /// Stop the control plane.
    pub async fn stop(&self) {
        self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
    }

    /// Starts the control plane.
    pub fn start(&self) {
        self.task_send.send(VmbusRequest::Start);
    }

    /// Resets the vmbus channel state.
    pub async fn reset(&self) {
        tracing::debug!("resetting channel state");
        self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
    }

    /// Tears down the vmbus control plane.
    ///
    /// Dropping `task_send` signals the server task to finish; the task is
    /// then awaited so teardown completes before returning.
    pub async fn shutdown(self) {
        drop(self.task_send);
        let _ = self.task.await;
    }

    /// Returns an object that can be used to offer channels.
    pub fn control(&self) -> Arc<VmbusServerControl> {
        self.control.clone()
    }

    /// Returns the message connection ID to use for a communication from the guest for servers
    /// that use a non-standard SINT or VTL.
    fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
        // Encoding: VTL in bits 22+, VP index in bits 8+, SINT in bits 4-7,
        // on top of the multiclient base ID.
        MULTICLIENT_MESSAGE_CONNECTION_ID
            | (vtl as u32) << 22
            | vp_index << 8
            | (sint_index as u32) << 4
    }

    /// Returns the event port ID for a channel on a server with a non-standard SINT or VTL,
    /// using the same bit layout as the message connection ID but keyed by channel ID.
    fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
        EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
    }
}
686
/// Per-channel state needed to restore a channel (open data, established
/// GPADLs, and the host-to-guest interrupt, if any).
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    open_data: Option<OpenData>,
    // Each entry: (GPADL ID, range count, raw range buffer).
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    interrupt: Option<Interrupt>,
}
693
/// A synic message received from the guest, as forwarded to the server task.
#[derive(Default)]
pub struct SynicMessage {
    // Raw message payload.
    data: Vec<u8>,
    // Whether the message arrived via the multiclient message port.
    multiclient: bool,
    // Whether the message arrived over a trusted (non-host-visible) path.
    trusted: bool,
}
700
/// Information used by a server that supports MNF.
#[derive(Default)]
struct MnfSupport {
    // Monitor page GPAs allocated by the server, if server-allocated monitor
    // pages are in use (see `set_require_server_allocated_mnf` in `build`).
    allocated_monitor_page: Option<MonitorPageGpas>,
}
706
/// Disambiguates offer instances that may have reused the same offer ID.
#[derive(Debug, Clone, Copy)]
struct OfferInstanceId {
    offer_id: OfferId,
    /// Monotonic sequence number assigned when the channel was offered
    /// (from [`ServerTask::next_seq`]).
    seq: u64,
}
713
/// The asynchronous task that runs the vmbus server's control plane.
struct ServerTask {
    driver: Box<dyn Driver>,
    server: channels::Server,
    // Requests from `VmbusServer` (save/restore/start/stop/reset/inspect).
    task_recv: mesh::Receiver<VmbusRequest>,
    // Offer/force-reset requests from `VmbusServerControl`.
    offer_recv: mesh::Receiver<OfferRequest>,
    // Synic messages from the guest.
    message_recv: mpsc::Receiver<SynicMessage>,
    // Per-channel request streams from channel owners, tagged with the offer
    // instance so revokes can be matched to the right offer generation.
    server_request_recv:
        SelectAll<TaggedStream<OfferInstanceId, mesh::Receiver<ChannelServerRequest>>>,
    inner: ServerTaskInner,
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Next value for [`Channel::seq`].
    next_seq: u64,
    perform_post_restore_on_start: bool,
    unstick_on_start: bool,
    // Pending delayed "unstick" futures; each resolves to the channel to poke.
    channel_unstickers: FuturesUnordered<Pin<Box<dyn Send + Future<Output = OfferInstanceId>>>>,
    // Delay before unsticking a newly opened channel; `None` disables it.
    channel_unstick_delay: Option<Duration>,
}
731
/// Mutable server state shared with the channels subsystem via the
/// [`Notifier`] implementation (`with_notifier`).
struct ServerTaskInner {
    // Whether the control plane is currently started.
    running: bool,
    // See `VmbusServerBuilder::send_messages_while_stopped`.
    send_messages_while_stopped: bool,
    gm: GuestMemory,
    private_gm: Option<GuestMemory>,
    synic: Arc<dyn SynicPortAccess>,
    vtl: Vtl,
    // VTL/SINT actually used for control-plane messages (differ from `vtl` /
    // VMBUS_SINT when message redirection is enabled).
    redirect_vtl: Vtl,
    redirect_sint: u8,
    message_port: Box<dyn GuestMessagePort>,
    // Number of outstanding hvsock connect requests; bounded by
    // MAX_CONCURRENT_HVSOCK_REQUESTS.
    hvsock_requests: usize,
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    channels: HashMap<OfferId, Channel>,
    // In-flight responses from channel owners, tagged with (offer, seq).
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    relay_send: mesh::Sender<ModifyRelayRequest>,
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    shared_event_port: Option<Box<dyn Send>>,
    // RPCs to complete once a pending reset finishes.
    reset_done: Vec<Rpc<(), ()>>,
    /// Stores information needed to support MNF. If `None`, this server doesn't support MNF (in
    /// the case of OpenHCL, that means it will be handled by the relay host).
    mnf_support: Option<MnfSupport>,
}
758
/// A response from a channel owner to a request issued by the server.
#[derive(Debug)]
enum ChannelResponse {
    /// Open completed; `true` on success.
    Open(bool),
    /// Close completed.
    Close,
    /// GPADL creation completed for the given ID; `true` on success.
    Gpadl(GpadlId, bool),
    /// GPADL teardown completed for the given ID.
    TeardownGpadl(GpadlId),
    /// Modify completed with the given status code.
    Modify(i32),
}
767
/// Tracks whether a delayed "unstick" interrupt is queued for a channel.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ChannelUnstickState {
    /// No unstick pending.
    None,
    /// An unstick future is queued on `ServerTask::channel_unstickers`.
    Queued,
    /// Already queued, but needs to be queued again once the current one fires.
    NeedsRequeue,
}
774
/// Server-side bookkeeping for a single offered channel.
struct Channel {
    // The offer key (interface/instance IDs), used for logging.
    key: OfferKey,
    // Sender for requests to the channel owner.
    send: mesh::Sender<ChannelRequest>,
    // Sequence number disambiguating reuse of the same offer ID; see
    // `OfferInstanceId`.
    seq: u64,
    state: ChannelState,
    gpadls: Arc<GpadlMap>,
    guest_to_host_event: Arc<ChannelEvent>,
    flags: protocol::OfferFlags,
    // A channel can be reserved no matter what state it is in. This allows the message port for a
    // reserved channel to remain available even if the channel is closed, so the guest can read the
    // close reserved channel response. The reserved state is cleared when the channel is revoked,
    // reopened, or the guest sends an unload message.
    reserved_state: ReservedState,
    unstick_state: ChannelUnstickState,
}
790
/// State for a reserved channel: its dedicated message port (if reserved) and
/// the VP/SINT the responses target. See the note on [`Channel::reserved_state`].
struct ReservedState {
    message_port: Option<Box<dyn GuestMessagePort>>,
    target: ConnectionTarget,
}
795
/// State held while a channel is open.
struct ChannelOpenState {
    open_params: OpenParams,
    // Keeps the guest-to-host event port registration alive while open.
    _event_port: Box<dyn Send>,
    // `None` when the channel was opened with its interrupt disabled.
    guest_event_port: Option<Box<dyn GuestEventPort>>,
    host_to_guest_interrupt: Interrupt,
}
802
803impl ChannelOpenState {
804    fn set_event_port_target_vp(&mut self, vp: u32) -> anyhow::Result<()> {
805        let Some(guest_event_port) = self.guest_event_port.as_mut() else {
806            anyhow::bail!("cannot set target VP if the channel interrupt is disabled");
807        };
808
809        guest_event_port.set_target_vp(vp)?;
810        Ok(())
811    }
812}
813
/// The server-side lifecycle state of a channel.
enum ChannelState {
    Closed,
    Open(Box<ChannelOpenState>),
    /// A close has been requested but the owner has not yet confirmed it.
    Closing,
}
819
820impl ServerTask {
    /// Handles a new channel offer: adjusts its MNF usage for this server's
    /// capabilities, registers it with the channels subsystem, and starts
    /// tracking its owner's request stream.
    fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
        let key = info.params.key();
        let flags = info.params.flags;

        if self.inner.mnf_support.is_some() && self.inner.synic.monitor_support().is_some() {
            // If this server is handling MnF, ignore any relayed monitor IDs but still enable MnF
            // for those channels.
            // N.B. Since this can only happen in OpenHCL, which emulates MnF, the latency is
            //      ignored.
            if info.params.use_mnf.is_relayed() {
                info.params.use_mnf = MnfUsage::Enabled {
                    latency: Duration::ZERO,
                }
            }
        } else if info.params.use_mnf.is_enabled() {
            // If the server is not handling MnF, disable it for the channel. This does not affect
            // channels with a relayed monitor ID.
            info.params.use_mnf = MnfUsage::Disabled;
        }

        let offer_id = self
            .server
            .with_notifier(&mut self.inner)
            .offer_channel(info.params)
            .context("channel offer failed")?;

        tracing::debug!(?offer_id, %key, "offered channel");

        // Tag this offer instance with a fresh sequence number so a later
        // revoke can be matched to this particular offer generation.
        let seq = self.next_seq;
        self.next_seq += 1;
        self.inner.channels.insert(
            offer_id,
            Channel {
                key,
                send: info.request_send,
                state: ChannelState::Closed,
                gpadls: GpadlMap::new(),
                guest_to_host_event: Arc::new(ChannelEvent(info.event)),
                seq,
                flags,
                reserved_state: ReservedState {
                    message_port: None,
                    target: ConnectionTarget { vp: 0, sint: 0 },
                },
                unstick_state: ChannelUnstickState::None,
            },
        );

        self.server_request_recv.push(TaggedStream::new(
            OfferInstanceId { offer_id, seq },
            info.server_request_recv,
        ));

        Ok(())
    }
876
877    fn handle_revoke(&mut self, id: OfferInstanceId) {
878        // The channel may or may not exist in the map depending on whether it's been explicitly
879        // revoked before being dropped.
880        if let Some(channel) = self.inner.channels.get(&id.offer_id) {
881            if channel.seq == id.seq {
882                tracing::info!(?id.offer_id, key = %channel.key, "revoking channel");
883                self.inner.channels.remove(&id.offer_id);
884                self.server
885                    .with_notifier(&mut self.inner)
886                    .revoke_channel(id.offer_id);
887            }
888        }
889    }
890
891    fn handle_response(
892        &mut self,
893        offer_id: OfferId,
894        seq: u64,
895        response: Result<ChannelResponse, RpcError>,
896    ) {
897        // Validate the sequence to ensure the response is not for a revoked channel.
898        let channel = self
899            .inner
900            .channels
901            .get(&offer_id)
902            .filter(|channel| channel.seq == seq);
903
904        if let Some(channel) = channel {
905            match response {
906                Ok(response) => match response {
907                    ChannelResponse::Open(result) => self.handle_open(offer_id, result),
908                    ChannelResponse::Close => self.handle_close(offer_id),
909                    ChannelResponse::Gpadl(gpadl_id, ok) => {
910                        self.handle_gpadl_create(offer_id, gpadl_id, ok)
911                    }
912                    ChannelResponse::TeardownGpadl(gpadl_id) => {
913                        self.handle_gpadl_teardown(offer_id, gpadl_id)
914                    }
915                    ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
916                },
917                Err(err) => {
918                    tracing::error!(
919                        key = %channel.key,
920                        error = &err as &dyn std::error::Error,
921                        "channel response failure, channel is in inconsistent state until revoked"
922                    );
923                }
924            }
925        } else {
926            tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
927        }
928    }
929
930    fn handle_open(&mut self, offer_id: OfferId, success: bool) {
931        let status = if success {
932            let channel = self
933                .inner
934                .channels
935                .get_mut(&offer_id)
936                .expect("channel exists");
937
938            // Some guests ignore interrupts before they receive the OpenResult message. To avoid
939            // a potential hang, signal the channel after a delay if needed.
940            if let Some(delay) = self.channel_unstick_delay {
941                if channel.unstick_state == ChannelUnstickState::None {
942                    channel.unstick_state = ChannelUnstickState::Queued;
943                    let seq = channel.seq;
944                    let mut timer = PolledTimer::new(&self.driver);
945                    self.channel_unstickers.push(Box::pin(async move {
946                        timer.sleep(delay).await;
947                        OfferInstanceId { offer_id, seq }
948                    }));
949                } else {
950                    channel.unstick_state = ChannelUnstickState::NeedsRequeue;
951                }
952            }
953
954            0
955        } else {
956            protocol::STATUS_UNSUCCESSFUL
957        };
958
959        self.server
960            .with_notifier(&mut self.inner)
961            .open_complete(offer_id, status);
962    }
963
964    fn handle_close(&mut self, offer_id: OfferId) {
965        let channel = self
966            .inner
967            .channels
968            .get_mut(&offer_id)
969            .expect("channel still exists");
970
971        match &mut channel.state {
972            ChannelState::Closing => {
973                tracing::debug!(?offer_id, key = %channel.key, "closing channel");
974                channel.state = ChannelState::Closed;
975                self.server
976                    .with_notifier(&mut self.inner)
977                    .close_complete(offer_id);
978            }
979            _ => {
980                tracing::error!(?offer_id, key = %channel.key, "invalid close channel response");
981            }
982        };
983    }
984
985    fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
986        let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
987        self.server
988            .with_notifier(&mut self.inner)
989            .gpadl_create_complete(offer_id, gpadl_id, status);
990    }
991
992    fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
993        self.server
994            .with_notifier(&mut self.inner)
995            .gpadl_teardown_complete(offer_id, gpadl_id);
996    }
997
998    fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
999        self.server
1000            .with_notifier(&mut self.inner)
1001            .modify_channel_complete(offer_id, status);
1002    }
1003
1004    fn handle_restore_channel(
1005        &mut self,
1006        offer_id: OfferId,
1007        open: bool,
1008    ) -> anyhow::Result<RestoreResult> {
1009        let gpadls = self.server.channel_gpadls(offer_id);
1010
1011        // If the channel is opened, handle that before calling into channels so that failure can
1012        // be handled before the channel is marked restored.
1013        let open_request = open
1014            .then(|| -> anyhow::Result<_> {
1015                let params = self.server.get_restore_open_params(offer_id)?;
1016                let (channel, interrupt) = self.inner.open_channel(offer_id, &params)?;
1017                Ok(OpenRequest::new(
1018                    params.open_data,
1019                    interrupt,
1020                    self.server
1021                        .get_version()
1022                        .expect("must be connected")
1023                        .feature_flags,
1024                    channel.flags,
1025                ))
1026            })
1027            .transpose()?;
1028
1029        self.server
1030            .with_notifier(&mut self.inner)
1031            .restore_channel(offer_id, open_request.is_some())?;
1032
1033        let channel = self.inner.channels.get_mut(&offer_id).unwrap();
1034        for gpadl in &gpadls {
1035            if let Ok(buf) = MultiPagedRangeBuf::from_range_buffer(
1036                gpadl.request.count.into(),
1037                gpadl.request.buf.clone(),
1038            ) {
1039                channel.gpadls.add(gpadl.request.id, buf);
1040            }
1041        }
1042
1043        let result = RestoreResult {
1044            open_request,
1045            gpadls,
1046        };
1047        Ok(result)
1048    }
1049
1050    async fn handle_request(&mut self, request: VmbusRequest) {
1051        tracing::debug!(?request, "handle_request");
1052        match request {
1053            VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
1054            VmbusRequest::Inspect(deferred) => {
1055                deferred.respond(|resp| {
1056                    resp.field("message_port", &self.inner.message_port)
1057                        .field("running", self.inner.running)
1058                        .field("hvsock_requests", self.inner.hvsock_requests)
1059                        .field("channel_unstick_delay", self.channel_unstick_delay)
1060                        .field_mut_with("unstick_channels", |v| {
1061                            let v: inspect::ValueKind = if let Some(v) = v {
1062                                if v == "force" {
1063                                    self.unstick_channels(true);
1064                                    v.into()
1065                                } else {
1066                                    let v =
1067                                        v.parse().ok().context("expected false, true, or force")?;
1068                                    if v {
1069                                        self.unstick_channels(false);
1070                                    }
1071                                    v.into()
1072                                }
1073                            } else {
1074                                false.into()
1075                            };
1076                            anyhow::Ok(v)
1077                        })
1078                        .merge(&self.server.with_notifier(&mut self.inner));
1079                });
1080            }
1081            VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
1082                server: self.server.save(),
1083                lost_synic_bug_fixed: true,
1084            }),
1085            VmbusRequest::Restore(rpc) => {
1086                rpc.handle(async |state| {
1087                    self.unstick_on_start = !state.lost_synic_bug_fixed;
1088                    self.perform_post_restore_on_start = true;
1089                    if let Some(sender) = &self.inner.saved_state_notify {
1090                        tracing::trace!("sending saved state to proxy");
1091                        if let Err(err) = sender
1092                            .call_failable(SavedStateRequest::Set, Box::new(state.server.clone()))
1093                            .await
1094                        {
1095                            tracing::error!(
1096                                err = &err as &dyn std::error::Error,
1097                                "failed to restore proxy saved state"
1098                            );
1099                            return Err(RestoreError::ServerError(err.into()));
1100                        }
1101                    }
1102
1103                    self.server
1104                        .with_notifier(&mut self.inner)
1105                        .restore(state.server)
1106                })
1107                .await
1108            }
1109            VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
1110                if self.inner.running {
1111                    self.inner.running = false;
1112                }
1113            }),
1114            VmbusRequest::Start => {
1115                if !self.inner.running {
1116                    self.inner.running = true;
1117                    if self.perform_post_restore_on_start {
1118                        if let Some(sender) = self.inner.saved_state_notify.as_ref() {
1119                            // Indicate to the proxy that the server is starting and that it should
1120                            // clear its saved state cache.
1121                            tracing::trace!("sending clear saved state message to proxy");
1122                            sender
1123                                .call(SavedStateRequest::Clear, ())
1124                                .await
1125                                .expect("failed to clear proxy saved state");
1126                        }
1127
1128                        self.server
1129                            .with_notifier(&mut self.inner)
1130                            .revoke_unclaimed_channels();
1131
1132                        self.perform_post_restore_on_start = false;
1133                    }
1134
1135                    if self.unstick_on_start {
1136                        tracing::info!(
1137                            "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
1138                        );
1139                        self.unstick_channels(false);
1140                        self.unstick_on_start = false;
1141                    }
1142                }
1143            }
1144        }
1145    }
1146
1147    fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
1148        let needs_reset = self.inner.reset_done.is_empty();
1149        self.inner.reset_done.push(rpc);
1150        if needs_reset {
1151            self.server.with_notifier(&mut self.inner).reset();
1152        }
1153    }
1154
1155    fn handle_relay_response(&mut self, response: ModifyRelayResponse) {
1156        // Convert to a matching ModifyConnectionResponse.
1157        let response = match response {
1158            ModifyRelayResponse::Supported(state, features) => {
1159                // Provide the server-allocated monitor page to the server only if they're actually being
1160                // used (if not, they may still be allocated from a previous connection).
1161                let allocated_monitor_gpas = self
1162                    .inner
1163                    .mnf_support
1164                    .as_ref()
1165                    .and_then(|mnf| mnf.allocated_monitor_page);
1166
1167                ModifyConnectionResponse::Supported(state, features, allocated_monitor_gpas)
1168            }
1169            ModifyRelayResponse::Unsupported => ModifyConnectionResponse::Unsupported,
1170            ModifyRelayResponse::Modified(state) => ModifyConnectionResponse::Modified(state),
1171        };
1172
1173        self.server
1174            .with_notifier(&mut self.inner)
1175            .complete_modify_connection(response);
1176    }
1177
1178    fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
1179        tracing::debug!(?result, "hvsock connect result");
1180        assert_ne!(self.inner.hvsock_requests, 0);
1181        self.inner.hvsock_requests -= 1;
1182
1183        self.server
1184            .with_notifier(&mut self.inner)
1185            .send_tl_connect_result(result);
1186    }
1187
1188    fn handle_synic_message(&mut self, message: SynicMessage) {
1189        match self
1190            .server
1191            .with_notifier(&mut self.inner)
1192            .handle_synic_message(message)
1193        {
1194            Ok(()) => {}
1195            Err(err) => {
1196                tracing::warn!(
1197                    error = &err as &dyn std::error::Error,
1198                    "synic message error"
1199                );
1200            }
1201        }
1202    }
1203
1204    /// Handles a request forwarded by a different vmbus server. This is used to forward requests
1205    /// for different VTLs to different servers.
1206    ///
1207    /// N.B. This uses the same mechanism as the HCL server relay, so all requests, even the ones
1208    ///      meant for the primary server, are forwarded. In that case the primary server depends
1209    ///      on this server to send back a response so it can continue handling it.
1210    fn handle_external_request(&mut self, request: InitiateContactRequest) {
1211        self.server
1212            .with_notifier(&mut self.inner)
1213            .initiate_contact(request);
1214    }
1215
1216    async fn run(
1217        &mut self,
1218        mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyRelayResponse> + Unpin,
1219        mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
1220    ) {
1221        loop {
1222            // Create an OptionFuture for each event that should only be handled
1223            // while the VM is running. In other cases, leave the events in
1224            // their respective queues.
1225
1226            let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
1227            let mut external_requests = OptionFuture::from(
1228                running_not_resetting
1229                    .then(|| {
1230                        self.external_requests
1231                            .as_mut()
1232                            .map(|r| r.select_next_some())
1233                    })
1234                    .flatten(),
1235            );
1236
1237            // Try to send any pending messages while the VM is running.
1238            let has_pending_messages = self.server.has_pending_messages();
1239            let message_port = self.inner.message_port.as_mut();
1240            let mut flush_pending_messages =
1241                OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
1242                    poll_fn(|cx| {
1243                        self.server.poll_flush_pending_messages(|msg| {
1244                            message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
1245                        })
1246                    })
1247                    .fuse()
1248                }));
1249
1250            // Only handle new incoming messages if there are no outgoing messages pending, and not
1251            // too many hvsock requests outstanding. This puts a bound on the resources used by the
1252            // guest.
1253            let mut message_recv = OptionFuture::from(
1254                (running_not_resetting
1255                    && !has_pending_messages
1256                    && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
1257                    .then(|| self.message_recv.select_next_some()),
1258            );
1259
1260            // Accept channel responses until stopped or when resetting.
1261            let mut channel_response = OptionFuture::from(
1262                (self.inner.running || !self.inner.reset_done.is_empty())
1263                    .then(|| self.inner.channel_responses.select_next_some()),
1264            );
1265
1266            // Accept hvsock connect responses while the VM is running.
1267            let mut hvsock_response =
1268                OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));
1269
1270            let mut channel_unstickers = OptionFuture::from(
1271                running_not_resetting.then(|| self.channel_unstickers.select_next_some()),
1272            );
1273
1274            futures::select! { // merge semantics
1275                r = self.task_recv.recv().fuse() => {
1276                    if let Ok(request) = r {
1277                        self.handle_request(request).await;
1278                    } else {
1279                        break;
1280                    }
1281                }
1282                r = self.offer_recv.select_next_some() => {
1283                    match r {
1284                        OfferRequest::Offer(rpc) => {
1285                            rpc.handle_failable_sync(|request| { self.handle_offer(request) })
1286                        },
1287                        OfferRequest::ForceReset(rpc) => {
1288                            self.handle_reset(rpc);
1289                        }
1290                    }
1291                }
1292                r = self.server_request_recv.select_next_some() => {
1293                    match r {
1294                        (id, Some(request)) => match request {
1295                            ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
1296                                self.handle_restore_channel(id.offer_id, open)
1297                            }),
1298                            ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
1299                                self.handle_revoke(id);
1300                            })
1301                        },
1302                        (id, None) => self.handle_revoke(id),
1303                    }
1304                }
1305                r = channel_response => {
1306                    let (id, seq, response) = r.unwrap();
1307                    self.handle_response(id, seq, response);
1308                }
1309                r = relay_response_recv.select_next_some() => {
1310                    self.handle_relay_response(r);
1311                },
1312                r = hvsock_response => {
1313                    self.handle_tl_connect_result(r.unwrap());
1314                }
1315                data = message_recv => {
1316                    let data = data.unwrap();
1317                    self.handle_synic_message(data);
1318                }
1319                r = external_requests => {
1320                    let r = r.unwrap();
1321                    self.handle_external_request(r);
1322                }
1323                r = channel_unstickers => {
1324                    self.unstick_channel_by_id(r.unwrap());
1325                }
1326                _r = flush_pending_messages => {}
1327                complete => break,
1328            }
1329        }
1330    }
1331
1332    /// Wakes the guest and optionally the host for every open channel. If `force`, always wakes
1333    /// them. If `!force`, only wake for rings that are in the state where a notification is
1334    /// expected.
1335    fn unstick_channels(&self, force: bool) {
1336        let Some(version) = self.server.get_version() else {
1337            tracing::warn!("cannot unstick when not connected");
1338            return;
1339        };
1340
1341        for channel in self.inner.channels.values() {
1342            let gm = self.inner.get_gm_for_channel(version, channel);
1343            if let Err(err) = Self::unstick_channel(gm, channel, force, true) {
1344                tracing::warn!(
1345                    channel = %channel.key,
1346                    error = err.as_ref() as &dyn std::error::Error,
1347                    "could not unstick channel"
1348                );
1349            }
1350        }
1351    }
1352
1353    /// Wakes the guest for the specified channel if it's open and the rings are in a state where
1354    /// notification is expected.
1355    fn unstick_channel_by_id(&mut self, id: OfferInstanceId) {
1356        let Some(version) = self.server.get_version() else {
1357            tracelimit::warn_ratelimited!("cannot unstick when not connected");
1358            return;
1359        };
1360
1361        if let Some(channel) = self.inner.channels.get_mut(&id.offer_id) {
1362            if channel.seq != id.seq {
1363                // The channel was revoked.
1364                return;
1365            }
1366
1367            // The channel was closed and reopened before the delay expired, so wait again to ensure
1368            // we don't signal too early.
1369            if channel.unstick_state == ChannelUnstickState::NeedsRequeue {
1370                channel.unstick_state = ChannelUnstickState::Queued;
1371                let mut timer = PolledTimer::new(&self.driver);
1372                let delay = self.channel_unstick_delay.unwrap();
1373                self.channel_unstickers.push(Box::pin(async move {
1374                    timer.sleep(delay).await;
1375                    id
1376                }));
1377
1378                return;
1379            }
1380
1381            channel.unstick_state = ChannelUnstickState::None;
1382            let gm = select_gm_for_channel(
1383                &self.inner.gm,
1384                self.inner.private_gm.as_ref(),
1385                version,
1386                channel,
1387            );
1388            if let Err(err) = Self::unstick_channel(gm, channel, false, false) {
1389                tracelimit::warn_ratelimited!(
1390                    channel = %channel.key,
1391                    error = err.as_ref() as &dyn std::error::Error,
1392                    "could not unstick channel"
1393                );
1394            }
1395        }
1396    }
1397
1398    fn unstick_channel(
1399        gm: &GuestMemory,
1400        channel: &Channel,
1401        force: bool,
1402        unstick_host: bool,
1403    ) -> anyhow::Result<()> {
1404        if let ChannelState::Open(state) = &channel.state {
1405            if force {
1406                tracing::info!(channel = %channel.key, "waking host and guest");
1407                if unstick_host {
1408                    channel.guest_to_host_event.0.deliver();
1409                }
1410                state.host_to_guest_interrupt.deliver();
1411                return Ok(());
1412            }
1413
1414            let gpadl = channel
1415                .gpadls
1416                .clone()
1417                .view()
1418                .map(state.open_params.open_data.ring_gpadl_id)
1419                .context("couldn't find ring gpadl")?;
1420
1421            let aligned = AlignedGpadlView::new(gpadl)
1422                .ok()
1423                .context("ring not aligned")?;
1424            let (in_gpadl, out_gpadl) = aligned
1425                .split(state.open_params.open_data.ring_offset)
1426                .ok()
1427                .context("couldn't split ring")?;
1428
1429            if let Err(err) = Self::unstick_incoming_ring(
1430                gm,
1431                channel,
1432                in_gpadl,
1433                unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1434                &state.host_to_guest_interrupt,
1435            ) {
1436                tracelimit::warn_ratelimited!(
1437                    channel = %channel.key,
1438                    error = err.as_ref() as &dyn std::error::Error,
1439                    "could not unstick incoming ring"
1440                );
1441            }
1442            if let Err(err) = Self::unstick_outgoing_ring(
1443                gm,
1444                channel,
1445                out_gpadl,
1446                unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1447                &state.host_to_guest_interrupt,
1448            ) {
1449                tracelimit::warn_ratelimited!(
1450                    channel = %channel.key,
1451                    error = err.as_ref() as &dyn std::error::Error,
1452                    "could not unstick outgoing ring"
1453                );
1454            }
1455        }
1456        Ok(())
1457    }
1458
1459    fn unstick_incoming_ring(
1460        gm: &GuestMemory,
1461        channel: &Channel,
1462        in_gpadl: AlignedGpadlView,
1463        guest_to_host_event: Option<&ChannelEvent>,
1464        host_to_guest_interrupt: &Interrupt,
1465    ) -> anyhow::Result<()> {
1466        let control_page = lock_gpn_with_subrange(gm, in_gpadl.gpns()[0])?;
1467        if let Some(guest_to_host_event) = guest_to_host_event {
1468            if ring::reader_needs_signal(control_page.pages()[0]) {
1469                tracelimit::info_ratelimited!(channel = %channel.key, "waking host for incoming ring");
1470                guest_to_host_event.0.deliver();
1471            }
1472        }
1473
1474        let ring_size = gpadl_ring_size(&in_gpadl).try_into()?;
1475        if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1476            tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for incoming ring");
1477            host_to_guest_interrupt.deliver();
1478        }
1479        Ok(())
1480    }
1481
1482    fn unstick_outgoing_ring(
1483        gm: &GuestMemory,
1484        channel: &Channel,
1485        out_gpadl: AlignedGpadlView,
1486        guest_to_host_event: Option<&ChannelEvent>,
1487        host_to_guest_interrupt: &Interrupt,
1488    ) -> anyhow::Result<()> {
1489        let control_page = lock_gpn_with_subrange(gm, out_gpadl.gpns()[0])?;
1490        if ring::reader_needs_signal(control_page.pages()[0]) {
1491            tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for outgoing ring");
1492            host_to_guest_interrupt.deliver();
1493        }
1494
1495        if let Some(guest_to_host_event) = guest_to_host_event {
1496            let ring_size = gpadl_ring_size(&out_gpadl).try_into()?;
1497            if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1498                tracelimit::info_ratelimited!(channel = %channel.key, "waking host for outgoing ring");
1499                guest_to_host_event.0.deliver();
1500            }
1501        }
1502        Ok(())
1503    }
1504}
1505
1506impl Notifier for ServerTaskInner {
1507    fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
1508        let channel = self
1509            .channels
1510            .get_mut(&offer_id)
1511            .expect("channel does not exist");
1512
1513        fn handle<I: 'static + Send, R: 'static + Send>(
1514            offer_id: OfferId,
1515            channel: &Channel,
1516            req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
1517            input: I,
1518            f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
1519        ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
1520        {
1521            let recv = channel.send.call(req, input);
1522            let seq = channel.seq;
1523            Box::pin(async move {
1524                let r = recv.await.map(f);
1525                (offer_id, seq, r)
1526            })
1527        }
1528
1529        let response = match action {
1530            channels::Action::Open(open_params, version) => {
1531                let seq = channel.seq;
1532                let key = channel.key;
1533                match self.open_channel(offer_id, &open_params) {
1534                    Ok((channel, interrupt)) => handle(
1535                        offer_id,
1536                        channel,
1537                        ChannelRequest::Open,
1538                        OpenRequest::new(
1539                            open_params.open_data,
1540                            interrupt,
1541                            version.feature_flags,
1542                            channel.flags,
1543                        ),
1544                        ChannelResponse::Open,
1545                    ),
1546                    Err(err) => {
1547                        tracelimit::error_ratelimited!(
1548                            err = err.as_ref() as &dyn std::error::Error,
1549                            ?offer_id,
1550                            %key,
1551                            "could not open channel",
1552                        );
1553
1554                        // Return an error response to the channels module if the open_channel call
1555                        // failed.
1556                        Box::pin(future::ready((
1557                            offer_id,
1558                            seq,
1559                            Ok(ChannelResponse::Open(false)),
1560                        )))
1561                    }
1562                }
1563            }
1564            channels::Action::Close => {
1565                if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1566                    if let ChannelState::Open(ref state) = channel.state {
1567                        channel_bitmap.unregister_channel(state.open_params.event_flag);
1568                    }
1569                }
1570
1571                channel.state = ChannelState::Closing;
1572                handle(offer_id, channel, ChannelRequest::Close, (), |()| {
1573                    ChannelResponse::Close
1574                })
1575            }
1576            channels::Action::Gpadl(gpadl_id, count, buf) => {
1577                channel.gpadls.add(
1578                    gpadl_id,
1579                    MultiPagedRangeBuf::from_range_buffer(count.into(), buf.clone()).unwrap(),
1580                );
1581                handle(
1582                    offer_id,
1583                    channel,
1584                    ChannelRequest::Gpadl,
1585                    GpadlRequest {
1586                        id: gpadl_id,
1587                        count,
1588                        buf,
1589                    },
1590                    move |r| ChannelResponse::Gpadl(gpadl_id, r),
1591                )
1592            }
1593            channels::Action::TeardownGpadl {
1594                gpadl_id,
1595                post_restore,
1596            } => {
1597                if !post_restore {
1598                    channel.gpadls.remove(gpadl_id, Box::new(|| ()));
1599                }
1600
1601                handle(
1602                    offer_id,
1603                    channel,
1604                    ChannelRequest::TeardownGpadl,
1605                    gpadl_id,
1606                    move |()| ChannelResponse::TeardownGpadl(gpadl_id),
1607                )
1608            }
1609            channels::Action::Modify { target_vp } => {
1610                let ChannelState::Open(state) = &mut channel.state else {
1611                    unreachable!();
1612                };
1613
1614                if let Err(err) = state.set_event_port_target_vp(target_vp) {
1615                    tracelimit::error_ratelimited!(
1616                        error = err.as_ref() as &dyn std::error::Error,
1617                        channel = %channel.key,
1618                        "could not modify channel",
1619                    );
1620
1621                    // Send an immediate error response.
1622                    let seq = channel.seq;
1623                    Box::pin(async move {
1624                        (
1625                            offer_id,
1626                            seq,
1627                            Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
1628                        )
1629                    })
1630                } else {
1631                    handle(
1632                        offer_id,
1633                        channel,
1634                        ChannelRequest::Modify,
1635                        ModifyRequest::TargetVp { target_vp },
1636                        ChannelResponse::Modify,
1637                    )
1638                }
1639            }
1640        };
1641        self.channel_responses.push(response);
1642    }
1643
1644    fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
1645        self.map_interrupt_page(request.interrupt_page)
1646            .context("Failed to map interrupt page.")?;
1647
1648        self.set_monitor_page(&mut request)
1649            .context("Failed to map monitor page.")?;
1650
1651        if let Some(vp) = request.target_message_vp {
1652            self.message_port.set_target_vp(vp)?;
1653        }
1654
1655        if request.notify_relay {
1656            self.relay_send.send(request.into());
1657        }
1658
1659        Ok(())
1660    }
1661
1662    fn forward_unhandled(&mut self, request: InitiateContactRequest) {
1663        if let Some(external_server) = &self.external_server_send {
1664            external_server.send(request);
1665        } else {
1666            tracing::warn!(?request, "nowhere to forward unhandled request")
1667        }
1668    }
1669
1670    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
1671        let channel = self.channels.get(&offer_id).expect("should exist");
1672        let mut resp = req.respond();
1673        if let ChannelState::Open(state) = &channel.state {
1674            let mem = self.get_gm_for_channel(version.expect("must be connected"), channel);
1675            inspect_rings(
1676                &mut resp,
1677                mem,
1678                channel.gpadls.clone(),
1679                &state.open_params.open_data,
1680            );
1681        }
1682    }
1683
    /// Posts `message` to the guest via the port selected by `target`.
    ///
    /// Returns `true` when the message was posted or can never be posted (and
    /// is therefore dropped). Returns `false` when it could not be sent right
    /// now; the channels module then queues it and the ServerTask main loop
    /// retries later.
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
        // If the server is paused, queue all messages, to avoid affecting synic
        // state during/after it has been saved or reset.
        //
        // Note that messages to reserved channels or custom targets will be
        // dropped. However, such messages should only be sent in response to
        // guest requests, which should not be processed while the server is
        // paused.
        //
        // FUTURE: it would be better to ensure that no messages are generated
        // by operations that run while the server is paused. E.g., defer
        // sending offer or revoke messages for new or revoked offers. This
        // would prevent the queue from growing without bound.
        if !self.running && !self.send_messages_while_stopped {
            if !matches!(target, MessageTarget::Default) {
                tracelimit::error_ratelimited!(?target, "dropping message while paused");
            }
            return false;
        }

        // `port_storage` keeps a freshly created custom message port alive for
        // the duration of the post below; it is only initialized on the
        // `Custom` path.
        let mut port_storage;
        let port = match target {
            MessageTarget::Default => self.message_port.as_mut(),
            MessageTarget::ReservedChannel(offer_id, target) => {
                if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
                    port.as_mut()
                } else {
                    // Updating the port failed, so there is no way to send the message.
                    return true;
                }
            }
            MessageTarget::Custom(target) => {
                port_storage = match self.synic.new_guest_message_port(
                    self.redirect_vtl,
                    target.vp,
                    target.sint,
                ) {
                    Ok(port) => port,
                    Err(err) => {
                        tracing::error!(
                            ?err,
                            ?self.redirect_vtl,
                            ?target,
                            "could not create message port"
                        );

                        // There is no way to send the message.
                        return true;
                    }
                };
                port_storage.as_mut()
            }
        };

        // If this returns Pending, the channels module will queue the message and the ServerTask
        // main loop will try to send it again later.
        matches!(
            port.poll_post_message(
                &mut std::task::Context::from_waker(std::task::Waker::noop()),
                VMBUS_MESSAGE_TYPE,
                message.data()
            ),
            Poll::Ready(())
        )
    }
1749
    /// Forwards a guest hvsock connect request over `hvsock_send` and counts
    /// it as outstanding.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
        tracing::debug!(?request, "received hvsock connect request");
        // NOTE(review): presumably decremented when the hvsock response
        // arrives — the receiver side is outside this chunk; confirm there.
        self.hvsock_requests += 1;
        self.hvsock_send.send(*request);
    }
1755
1756    fn reset_complete(&mut self) {
1757        if let Some(monitor) = self.synic.monitor_support() {
1758            if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
1759                tracing::warn!(?err, "resetting monitor page failed")
1760            }
1761        }
1762
1763        self.unreserve_channels();
1764        for done in self.reset_done.drain(..) {
1765            done.complete(());
1766        }
1767    }
1768
    /// Releases reserved message ports for closed channels once an unload has
    /// completed.
    fn unload_complete(&mut self) {
        self.unreserve_channels();
    }
1772}
1773
impl ServerTaskInner {
    /// Creates the synic resources needed for an opening channel — the
    /// guest-to-host event port, the host-to-guest interrupt, and a reserved
    /// message port when requested — then transitions the channel to
    /// [`ChannelState::Open`].
    ///
    /// Returns the channel and the interrupt the device should use to signal
    /// the guest.
    fn open_channel(
        &mut self,
        offer_id: OfferId,
        open_params: &OpenParams,
    ) -> anyhow::Result<(&mut Channel, Interrupt)> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        // Always register with the channel bitmap; if Win7, this may be unnecessary.
        if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
            channel_bitmap.register_channel(
                open_params.event_flag,
                channel.guest_to_host_event.0.clone(),
            );
        }
        // Always set up an event port; if V1, this will be unused.
        // N.B. The event port must be created before the device is notified of the open by the
        //      caller. The device may begin communicating with the guest immediately when it is
        //      notified, so the event port must exist so that the guest can send interrupts.
        let event_port = self
            .synic
            .add_event_port(
                open_params.connection_id,
                self.vtl,
                channel.guest_to_host_event.clone(),
                open_params.monitor_info,
            )
            .context("failed to create guest-to-host event port")?;

        // For pre-Win8 guests, the host-to-guest event always targets vp 0 and the channel
        // bitmap is used instead of the event flag.
        let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
            (Some(0), 0)
        } else {
            (open_params.open_data.target_vp, open_params.event_flag)
        };

        let (guest_event_port, interrupt) = if let Some(target_vp) = target_vp {
            // Interrupts may be redirected to a different VTL/SINT than the
            // default vmbus one.
            let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
                (self.redirect_vtl, self.redirect_sint)
            } else {
                (self.vtl, VMBUS_SINT)
            };

            let guest_event_port = self.synic.new_guest_event_port(
                VmbusServer::get_child_event_port_id(open_params.channel_id, VMBUS_SINT, self.vtl),
                target_vtl,
                target_vp,
                target_sint,
                event_flag,
                open_params.monitor_info,
            )?;

            // Route the interrupt through the channel bitmap when one exists
            // (pre-Win8 guests); otherwise signal the event port directly.
            let interrupt = ChannelBitmap::create_interrupt(
                &self.channel_bitmap,
                guest_event_port.interrupt(),
                open_params.event_flag,
            );

            (Some(guest_event_port), interrupt)
        } else {
            // Use a dummy interrupt which does nothing, but make sure it has an event to avoid
            // proxy_integration from trying to wrap it.
            (None, Interrupt::null_event())
        };

        // Delete any previously reserved state.
        channel.reserved_state.message_port = None;

        // If the channel is reserved, create a message port for it.
        if let Some(target) = open_params.reserved_target {
            channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
                self.redirect_vtl,
                target.vp,
                target.sint,
            )?);

            channel.reserved_state.target = target;
        }

        channel.state = ChannelState::Open(Box::new(ChannelOpenState {
            open_params: *open_params,
            _event_port: event_port,
            guest_event_port,
            host_to_guest_interrupt: interrupt.clone(),
        }));
        Ok((channel, interrupt))
    }

    /// If the client specified an interrupt page, map it into host memory and
    /// set up the shared event port.
    fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
        let interrupt_page = match interrupt_page {
            Update::Unchanged => return Ok(()),
            Update::Reset => {
                // Tear down both the bitmap and the shared event port.
                self.channel_bitmap = None;
                self.shared_event_port = None;
                return Ok(());
            }
            Update::Set(interrupt_page) => interrupt_page,
        };

        assert_ne!(interrupt_page, 0);

        if interrupt_page % PAGE_SIZE as u64 != 0 {
            anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
        }

        // Use a subrange to access the interrupt page to give GuestMemory's without a full mapping
        // a chance to create one.
        let interrupt_page = lock_page_with_subrange(&self.gm, interrupt_page)?;
        let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
        self.channel_bitmap = Some(channel_bitmap.clone());

        // Create the shared event port for pre-Win8 guests.
        let interrupt = Interrupt::from_fn(move || {
            channel_bitmap.handle_shared_interrupt();
        });

        self.shared_event_port = Some(self.synic.add_event_port(
            SHARED_EVENT_CONNECTION_ID,
            self.vtl,
            Arc::new(ChannelEvent(interrupt)),
            None,
        )?);

        Ok(())
    }

    /// Applies the monitor-page portion of a connection change, preferring
    /// server-allocated monitor pages when the protocol version supports them.
    ///
    /// When MNF is handled by this server, `request.monitor_page` is cleared
    /// so the change is not also forwarded to the relay.
    fn set_monitor_page(&mut self, request: &mut ModifyConnectionRequest) -> anyhow::Result<()> {
        let monitor_page = match request.monitor_page {
            Update::Unchanged => return Ok(()),
            Update::Reset => None,
            Update::Set(value) => Some(value),
        };

        // TODO: can this check be moved into channels.rs?
        if self.channels.iter().any(|(_, c)| {
            matches!(
                &c.state,
                ChannelState::Open(state) if state.open_params.monitor_info.is_some()
            )
        }) {
            anyhow::bail!("attempt to change monitor page while open channels using mnf");
        }

        // Check if the server is handling MNF.
        // N.B. If the server is not handling MNF, there is currently no way to request
        //      server-allocated monitor pages from the relay host.
        if let Some(mnf_support) = self.mnf_support.as_mut() {
            if let Some(monitor) = self.synic.monitor_support() {
                mnf_support.allocated_monitor_page = None;

                if let Some(version) = request.version {
                    if version.feature_flags.server_specified_monitor_pages() {
                        if let Some(monitor_page) = monitor.allocate_monitor_page(self.vtl)? {
                            tracelimit::info_ratelimited!(
                                ?monitor_page,
                                "using server-allocated monitor pages"
                            );
                            mnf_support.allocated_monitor_page = Some(monitor_page);
                        }
                    }
                }

                // If no monitor page was allocated above, use the one provided by the client.
                if mnf_support.allocated_monitor_page.is_none() {
                    if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
                        anyhow::bail!(
                            "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
                        );
                    }
                }
            }

            // If MNF is configured to be handled by this server (even if it's not actually
            // supported by the synic), don't forward the pages to the relay.
            request.monitor_page = Update::Unchanged;
        }

        Ok(())
    }

    /// Returns the message port of a reserved channel, first recreating or
    /// retargeting it if the guest asked for a different SINT or vp via
    /// `new_target`.
    ///
    /// Returns `None` when a new port was required but could not be created.
    fn get_reserved_channel_message_port(
        &mut self,
        offer_id: OfferId,
        new_target: ConnectionTarget,
    ) -> Option<&mut Box<dyn GuestMessagePort>> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        assert!(
            channel.reserved_state.message_port.is_some(),
            "channel is not reserved"
        );

        // On close, the guest may have changed the message target it wants to use for the close
        // response. If so, update the message port.
        if channel.reserved_state.target.sint != new_target.sint {
            // Destroy the old port before creating the new one.
            channel.reserved_state.message_port = None;
            let message_port = self
                .synic
                .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
                .inspect_err(|err| {
                    tracing::error!(
                        key = %channel.key,
                        ?err,
                        ?self.redirect_vtl,
                        ?new_target,
                        "could not create reserved channel message port"
                    )
                })
                .ok()?;

            channel.reserved_state.message_port = Some(message_port);
            channel.reserved_state.target = new_target;
        } else if channel.reserved_state.target.vp != new_target.vp {
            let message_port = channel.reserved_state.message_port.as_mut().unwrap();

            // The vp has changed, but the SINT is the same. Just update the vp. If this fails,
            // ignore it and just send to the old vp.
            if let Err(err) = message_port.set_target_vp(new_target.vp) {
                tracing::error!(
                    key = %channel.key,
                    ?err,
                    ?self.redirect_vtl,
                    ?new_target,
                    "could not update reserved channel message port"
                );
            }

            channel.reserved_state.target = new_target;
            return Some(message_port);
        }

        Some(channel.reserved_state.message_port.as_mut().unwrap())
    }

    /// Drops the reserved message port of every channel that is closed.
    fn unreserve_channels(&mut self) {
        // Unreserve all closed channels.
        for channel in self.channels.values_mut() {
            if let ChannelState::Closed = channel.state {
                channel.reserved_state.message_port = None;
            }
        }
    }

    /// Returns the guest memory that `channel`'s rings should be accessed
    /// through; see [`select_gm_for_channel`].
    fn get_gm_for_channel(&self, version: VersionInfo, channel: &Channel) -> &GuestMemory {
        select_gm_for_channel(&self.gm, self.private_gm.as_ref(), version, channel)
    }
}
2031
2032fn select_gm_for_channel<'a>(
2033    gm: &'a GuestMemory,
2034    private_gm: Option<&'a GuestMemory>,
2035    version: VersionInfo,
2036    channel: &Channel,
2037) -> &'a GuestMemory {
2038    if channel.flags.confidential_ring_buffer() && version.feature_flags.confidential_channels() {
2039        if let Some(private_gm) = private_gm {
2040            return private_gm;
2041        }
2042    }
2043
2044    gm
2045}
2046
/// Control point for [`VmbusServer`], allowing callers to offer channels.
#[derive(Clone)]
pub struct VmbusServerControl {
    // Guest memory handed to every offered channel via `OfferResources`.
    mem: GuestMemory,
    // Memory handed out additionally when an offer sets one of the
    // confidential flags.
    private_mem: Option<GuestMemory>,
    // Channel for sending offer/reset requests to the server task.
    send: mesh::Sender<OfferRequest>,
    // Reported to devices via `ParentBus::use_event`.
    use_event: bool,
    // When set, every offer has `confidential_external_memory` forced on.
    force_confidential_external_memory: bool,
}
2056
2057impl VmbusServerControl {
2058    /// Offers a channel to the vmbus server, where the flags and user_defined data are already set.
2059    /// This is used by the relay to forward the host's parameters.
2060    pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
2061        let flags = offer_info.params.flags;
2062        self.send
2063            .call_failable(OfferRequest::Offer, offer_info)
2064            .await?;
2065        Ok(OfferResources::new(
2066            self.mem.clone(),
2067            if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
2068                self.private_mem.clone()
2069            } else {
2070                None
2071            },
2072        ))
2073    }
2074
2075    /// Force reset all channels and protocol state, without requiring the
2076    /// server to be paused.
2077    pub async fn force_reset(&self) -> anyhow::Result<()> {
2078        self.send
2079            .call(OfferRequest::ForceReset, ())
2080            .await
2081            .context("vmbus server is gone")
2082    }
2083
2084    async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2085        let mut offer_info = OfferInfo {
2086            params: request.params.into(),
2087            event: request.event,
2088            request_send: request.request_send,
2089            server_request_recv: request.server_request_recv,
2090        };
2091
2092        if self.force_confidential_external_memory {
2093            tracing::warn!(
2094                key = %offer_info.params.key(),
2095                "forcing confidential external memory for channel"
2096            );
2097
2098            offer_info
2099                .params
2100                .flags
2101                .set_confidential_external_memory(true);
2102        }
2103
2104        self.offer_core(offer_info).await
2105    }
2106}
2107
2108/// Inspects the specified ring buffer state by directly accessing guest memory.
2109fn inspect_rings(
2110    resp: &mut inspect::Response<'_>,
2111    gm: &GuestMemory,
2112    gpadl_map: Arc<GpadlMap>,
2113    open_data: &OpenData,
2114) -> Option<()> {
2115    let gpadl = gpadl_map
2116        .view()
2117        .map(GpadlId(open_data.ring_gpadl_id.0))
2118        .ok()?;
2119
2120    let aligned = AlignedGpadlView::new(gpadl).ok()?;
2121    let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
2122    resp.child("incoming_ring", |req| inspect_ring(req, &in_gpadl, gm));
2123    resp.child("outgoing_ring", |req| inspect_ring(req, &out_gpadl, gm));
2124    Some(())
2125}
2126
2127/// Inspects the incoming or outgoing ring buffer by directly accessing guest memory.
2128fn inspect_ring(req: inspect::Request<'_>, gpadl: &AlignedGpadlView, gm: &GuestMemory) {
2129    let mut resp = req.respond();
2130
2131    resp.hex("ring_size", gpadl_ring_size(gpadl));
2132
2133    // Lock just the control page. Use a subrange to allow a GuestMemory without a full mapping to
2134    // create one.
2135    if let Ok(pages) = lock_gpn_with_subrange(gm, gpadl.gpns()[0]) {
2136        ring::inspect_ring(pages.pages()[0], &mut resp);
2137    }
2138}
2139
2140fn gpadl_ring_size(gpadl: &AlignedGpadlView) -> usize {
2141    // Data size excluding the control page.
2142    (gpadl.gpns().len() - 1) * PAGE_SIZE
2143}
2144
2145/// Helper to create a subrange before locking a single page.
2146///
2147/// This allows us to lock a page in a `GuestMemory` that doesn't have a full mapping, but can
2148/// create one for a subrange.
2149fn lock_page_with_subrange(gm: &GuestMemory, offset: u64) -> anyhow::Result<guestmem::LockedPages> {
2150    Ok(gm
2151        .lockable_subrange(offset, PAGE_SIZE as u64)?
2152        .lock_gpns(false, &[0])?)
2153}
2154
2155/// Helper to create a subrange before locking a single page from a gpn.
2156///
2157/// This allows us to lock a page in a `GuestMemory` that doesn't have a full mapping, but can
2158/// create one for a subrange.
2159fn lock_gpn_with_subrange(gm: &GuestMemory, gpn: u64) -> anyhow::Result<guestmem::LockedPages> {
2160    lock_page_with_subrange(gm, gpn * PAGE_SIZE as u64)
2161}
2162
/// Forwards messages received on a synic message port into the server's
/// message-processing channel.
pub(crate) struct MessageSender {
    // Channel the received data is forwarded on as a `SynicMessage`.
    send: mpsc::Sender<SynicMessage>,
    // Copied into each forwarded `SynicMessage`.
    multiclient: bool,
}
2167
2168impl MessageSender {
2169    fn poll_handle_message(
2170        &self,
2171        cx: &mut std::task::Context<'_>,
2172        msg: &[u8],
2173        trusted: bool,
2174    ) -> Poll<Result<(), SendError>> {
2175        let mut send = self.send.clone();
2176        ready!(send.poll_ready(cx))?;
2177        send.start_send(SynicMessage {
2178            data: msg.to_vec(),
2179            multiclient: self.multiclient,
2180            trusted,
2181        })?;
2182
2183        Poll::Ready(Ok(()))
2184    }
2185}
2186
2187impl MessagePort for MessageSender {
2188    fn poll_handle_message(
2189        &self,
2190        cx: &mut std::task::Context<'_>,
2191        msg: &[u8],
2192        trusted: bool,
2193    ) -> Poll<()> {
2194        if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
2195            tracelimit::error_ratelimited!(
2196                error = &err as &dyn std::error::Error,
2197                "failed to send message"
2198            );
2199        }
2200
2201        Poll::Ready(())
2202    }
2203}
2204
#[async_trait]
impl ParentBus for VmbusServerControl {
    // Offers a new channel by forwarding the request to the server task.
    async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
        self.offer(request).await
    }

    // Hands out another handle to the same control channel.
    fn clone_bus(&self) -> Box<dyn ParentBus> {
        Box::new(self.clone())
    }

    fn use_event(&self) -> bool {
        self.use_event
    }
}