Skip to main content

vmbus_server/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8pub mod channels;
9pub mod hvsock;
10mod monitor;
11mod proxyintegration;
12#[cfg(test)]
13mod tests;
14
15/// The GUID type used for vmbus channel identifiers.
16pub type Guid = guid::Guid;
17
18use anyhow::Context;
19use async_trait::async_trait;
20use channel_bitmap::ChannelBitmap;
21use channels::ConnectionTarget;
22pub use channels::InitiateContactRequest;
23use channels::MessageTarget;
24pub use channels::MnfUsage;
25use channels::ModifyConnectionRequest;
26use channels::ModifyConnectionResponse;
27use channels::Notifier;
28use channels::OfferId;
29pub use channels::OfferParamsInternal;
30use channels::OpenParams;
31use channels::RestoreError;
32pub use channels::Update;
33use futures::FutureExt;
34use futures::StreamExt;
35use futures::channel::mpsc;
36use futures::channel::mpsc::SendError;
37use futures::future::OptionFuture;
38use futures::future::poll_fn;
39use futures::stream::SelectAll;
40use guestmem::GuestMemory;
41use hvdef::Vtl;
42use inspect::Inspect;
43use mesh::payload::Protobuf;
44use mesh::rpc::FailableRpc;
45use mesh::rpc::Rpc;
46use mesh::rpc::RpcError;
47use mesh::rpc::RpcSend;
48use pal_async::driver::Driver;
49use pal_async::driver::SpawnDriver;
50use pal_async::task::Task;
51use pal_async::timer::PolledTimer;
52use pal_event::Event;
53#[cfg(windows)]
54pub use proxyintegration::ProxyIntegration;
55#[cfg(windows)]
56pub use proxyintegration::ProxyServerInfo;
57use ring::PAGE_SIZE;
58use std::collections::HashMap;
59use std::future;
60use std::future::Future;
61use std::pin::Pin;
62use std::sync::Arc;
63use std::task::Poll;
64use std::task::ready;
65use std::time::Duration;
66use unicycle::FuturesUnordered;
67use vmbus_channel::bus::ChannelRequest;
68use vmbus_channel::bus::ChannelServerRequest;
69use vmbus_channel::bus::GpadlRequest;
70use vmbus_channel::bus::ModifyRequest;
71use vmbus_channel::bus::OfferInput;
72use vmbus_channel::bus::OfferKey;
73use vmbus_channel::bus::OfferResources;
74use vmbus_channel::bus::OpenData;
75use vmbus_channel::bus::OpenRequest;
76use vmbus_channel::bus::ParentBus;
77use vmbus_channel::bus::RestoreResult;
78use vmbus_channel::gpadl::GpadlMap;
79use vmbus_channel::gpadl_ring::AlignedGpadlView;
80use vmbus_core::HvsockConnectRequest;
81use vmbus_core::HvsockConnectResult;
82use vmbus_core::MaxVersionInfo;
83use vmbus_core::OutgoingMessage;
84use vmbus_core::TaggedStream;
85use vmbus_core::VMBUS_SINT;
86use vmbus_core::VersionInfo;
87use vmbus_core::protocol;
88pub use vmbus_core::protocol::GpadlId;
89#[cfg(windows)]
90use vmbus_proxy::ProxyHandle;
91use vmbus_ring as ring;
92use vmbus_ring::gparange::MultiPagedRangeBuf;
93use vmcore::interrupt::Interrupt;
94use vmcore::save_restore::SavedStateRoot;
95use vmcore::synic::EventPort;
96use vmcore::synic::GuestEventPort;
97use vmcore::synic::GuestMessagePort;
98use vmcore::synic::MessagePort;
99use vmcore::synic::MonitorPageGpas;
100use vmcore::synic::SynicPortAccess;
101
/// The SINT used for the message port when the control plane is redirected (see
/// `VmbusServerBuilder::use_message_redirect`).
pub const REDIRECT_SINT: u8 = 7;
/// The VTL used for the message port when the control plane is redirected.
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
// NOTE(review): presumably the connection ID of the shared event port; its use is not visible
// in this part of the file.
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
// Base value of the per-channel event port ID built by `get_child_event_port_id`.
const EVENT_PORT_ID: u32 = 2;
// NOTE(review): presumably the synic message type used for vmbus messages; confirm at the use
// site.
const VMBUS_MESSAGE_TYPE: u32 = 1;

// NOTE(review): presumably the limit compared against `ServerTaskInner::hvsock_requests`;
// confirm at the use site.
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
109
/// Handle to a running vmbus server.
///
/// Produced by `VmbusServerBuilder::build`, which spawns the "vmbus server" task and registers
/// the synic message port(s). The methods on this type forward requests to that task.
#[derive(Inspect)]
pub struct VmbusServer {
    /// Sends `VmbusRequest`s to the server task. Inspection is forwarded to the task as
    /// `VmbusRequest::Inspect`.
    #[inspect(flatten, send = "VmbusRequest::Inspect")]
    task_send: mesh::Sender<VmbusRequest>,
    /// Shared control object returned by `VmbusServer::control`.
    #[inspect(skip)]
    control: Arc<VmbusServerControl>,
    /// Value returned by `add_message_port` for the main message port, kept for as long as the
    /// server exists (presumably dropping it removes the port; confirm against the synic
    /// implementation).
    #[inspect(skip)]
    _message_port: Box<dyn Sync + Send>,
    /// Value returned by `add_message_port` for the multiclient message port. Only present for
    /// a VTL0 server that does not use control plane redirection.
    #[inspect(skip)]
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    /// The spawned server task. It yields the `ServerTask` when its run loop finishes.
    #[inspect(skip)]
    task: Task<ServerTask>,
}
123
/// Builder used to configure and create a `VmbusServer`.
///
/// Each field is set by the builder method of the same name unless noted otherwise.
pub struct VmbusServerBuilder<T: SpawnDriver> {
    /// Driver used to spawn the server task.
    spawner: T,
    /// Synic access used to create message and event ports.
    synic: Arc<dyn SynicPortAccess>,
    /// Guest memory used by the server.
    gm: GuestMemory,
    /// Separate guest memory instance for confidential channels, if any.
    private_gm: Option<GuestMemory>,
    /// The VTL this server instance serves. Defaults to VTL0.
    vtl: Vtl,
    /// Send/receive pair used to handle hvsocket requests.
    hvsock_notify: Option<HvsockServerChannelHalf>,
    /// Send/receive pair that is notified of server requests (used by the Underhill relay).
    server_relay: Option<VmbusServerChannelHalf>,
    /// Sender used to tell ProxyIntegration about saved channels.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Sender used to forward unhandled connect requests to another server.
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    /// Receiver for requests coming from another server.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Whether the vmbus control plane is redirected to Underhill.
    use_message_redirect: bool,
    /// Offset applied when generating channel IDs; set through `enable_channel_id_offset`.
    channel_id_offset: u16,
    /// Optional limit on the protocol version offered to the guest.
    max_version: Option<MaxVersionInfo>,
    /// Whether the version limit only applies after the first `Unload` message.
    delay_max_version: bool,
    /// Optional limit on the protocol version accepted when restoring from saved state.
    max_restore_version: Option<MaxVersionInfo>,
    /// Whether MNF support is enabled in the server.
    enable_mnf: bool,
    /// Forces all non-relay channels to use encrypted external memory (testing only).
    force_confidential_external_memory: bool,
    /// Whether messages are sent to the partition even while the server is stopped.
    send_messages_while_stopped: bool,
    /// Delay before unsticking a channel after it has been opened; `None` disables it.
    channel_unstick_delay: Option<Duration>,
    /// Whether the offer's channel order value is the primary key when assigning channel IDs.
    use_absolute_channel_order: bool,
}
146
#[derive(mesh::MeshPayload)]
/// The request to send to the proxy to set or clear its saved state cache.
pub enum SavedStateRequest {
    /// Set the saved state cache to the provided server saved state. The request can fail.
    Set(FailableRpc<Box<channels::SavedState>, ()>),
    /// Clear the saved state cache.
    Clear(Rpc<(), ()>),
}
153
/// The server side of the connection between a vmbus server and a relay.
///
/// Created together with the matching `RelayChannelHalf` by `RelayChannel::new`.
pub struct ServerChannelHalf<Request, Response> {
    /// Sends requests to the relay; paired with `RelayChannelHalf::request_receive`.
    request_send: mesh::Sender<Request>,
    /// Receives the relay's responses; paired with `RelayChannelHalf::response_send`.
    response_receive: mesh::Receiver<Response>,
}
159
/// The relay side of a connection between a vmbus server and a relay.
///
/// Created together with the matching `ServerChannelHalf` by `RelayChannel::new`.
pub struct RelayChannelHalf<Request, Response> {
    /// Receives the requests sent by the server half.
    pub request_receive: mesh::Receiver<Request>,
    /// Sends responses back to the server half.
    pub response_send: mesh::Sender<Response>,
}
165
/// A connection between a vmbus server and a relay.
///
/// Holds both ends so that each can be handed to its respective owner.
pub struct RelayChannel<Request, Response> {
    /// The end to give to the relay.
    pub relay_half: RelayChannelHalf<Request, Response>,
    /// The end to give to the vmbus server.
    pub server_half: ServerChannelHalf<Request, Response>,
}
171
172impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
173    /// Creates a new channel between the vmbus server and a relay.
174    pub fn new() -> Self {
175        let (request_send, request_receive) = mesh::channel();
176        let (response_send, response_receive) = mesh::channel();
177        Self {
178            relay_half: RelayChannelHalf {
179                request_receive,
180                response_send,
181            },
182            server_half: ServerChannelHalf {
183                request_send,
184                response_receive,
185            },
186        }
187    }
188}
189
/// Server half of the channel carrying `ModifyRelayRequest`/`ModifyRelayResponse` messages.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Relay half of the channel carrying `ModifyRelayRequest`/`ModifyRelayResponse` messages.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Both halves of the channel carrying `ModifyRelayRequest`/`ModifyRelayResponse` messages.
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyRelayResponse>;
/// Server half of the channel carrying `HvsockConnectRequest`/`HvsockConnectResult` messages.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay half of the channel carrying `HvsockConnectRequest`/`HvsockConnectResult` messages.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Both halves of the channel carrying `HvsockConnectRequest`/`HvsockConnectResult` messages.
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
196
/// A request from the server to the relay to modify connection state.
///
/// The version and use_interrupt_page fields can only be present if this request was sent for an
/// InitiateContact message from the guest.
///
/// If `version` is `Some`, the relay must respond with either `ModifyRelayResponse::Supported` or
/// `ModifyRelayResponse::Unsupported`. If `version` is `None`, the relay must respond with
/// `ModifyRelayResponse::Modified`.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// The requested protocol version as a raw `u32`, if a version change was requested.
    pub version: Option<u32>,
    /// The requested change to the monitor pages.
    pub monitor_page: Update<MonitorPageGpas>,
    /// The requested change to interrupt page usage: `Some(true)` when an interrupt page is
    /// being set, `Some(false)` when it is being reset, and `None` when it is unchanged.
    pub use_interrupt_page: Option<bool>,
}
211
/// A response from the relay to a ModifyRelayRequest from the server.
///
/// When no relay is configured, `VmbusServerBuilder::build` installs a default responder that
/// answers every request itself: `Supported` with a successful state and all feature flag bits
/// set when a version was requested, and `Modified` with a successful state otherwise.
#[derive(Debug, Copy, Clone)]
pub enum ModifyRelayResponse {
    /// The requested version change is supported, and the relay completed the connection
    /// modification with the specified status. All of the feature flags supported by the relay host
    /// are included, regardless of what features were requested.
    Supported(protocol::ConnectionState, protocol::FeatureFlags),
    /// A version change was requested but the relay host doesn't support that version. This
    /// response cannot be returned for a request with no version change set.
    Unsupported,
    /// The connection modification completed with the specified status. This response must be sent
    /// if no version change was requested.
    Modified(protocol::ConnectionState),
}
226
227impl From<ModifyConnectionRequest> for ModifyRelayRequest {
228    fn from(value: ModifyConnectionRequest) -> Self {
229        Self {
230            version: value.version.map(|v| v.version as u32),
231            monitor_page: value.monitor_page,
232            use_interrupt_page: match value.interrupt_page {
233                Update::Unchanged => None,
234                Update::Reset => Some(false),
235                Update::Set(_) => Some(true),
236            },
237        }
238    }
239}
240
/// Requests sent from a `VmbusServer` handle to the server task.
#[derive(Debug)]
enum VmbusRequest {
    /// Reset the vmbus channel state; sent by `VmbusServer::reset`.
    Reset(Rpc<(), ()>),
    /// Inspect the server task; sent through the `Inspect` implementation of `VmbusServer`.
    Inspect(inspect::Deferred),
    /// Save the server state; sent by `VmbusServer::save`.
    Save(Rpc<(), SavedState>),
    /// Restore the server from saved state; sent by `VmbusServer::restore`.
    Restore(Rpc<Box<SavedState>, Result<(), RestoreError>>),
    /// Start the control plane; sent by `VmbusServer::start` without waiting for completion.
    Start,
    /// Stop the control plane; sent by `VmbusServer::stop`.
    Stop(Rpc<(), ()>),
}
250
/// The information needed to offer a channel to the server task.
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    /// The offer parameters passed to the server's `offer_channel`.
    pub params: OfferParamsInternal,
    /// The interrupt delivered when the channel's guest-to-host event is handled.
    pub event: Interrupt,
    /// Sender the server uses to issue `ChannelRequest`s for this channel.
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receiver for `ChannelServerRequest`s (presumably sent by the channel's owner; confirm).
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
258
/// Requests received by the server task on its offer channel.
#[expect(clippy::large_enum_variant)]
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offer a new channel described by the `OfferInfo`. The request can fail.
    Offer(FailableRpc<OfferInfo, ()>),
    // NOTE(review): presumably forces a reset of the server state; the handler is not visible
    // in this part of the file.
    ForceReset(Rpc<(), ()>),
}
265
266struct ChannelEvent(Interrupt);
267
268impl EventPort for ChannelEvent {
269    fn handle_event(&self, _flag: u16) {
270        self.0.deliver();
271    }
272
273    fn os_event(&self) -> Option<&Event> {
274        self.0.event()
275    }
276}
277
/// The saved state of a `VmbusServer`, as produced by `VmbusServer::save` and consumed by
/// `VmbusServer::restore`.
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    // The saved state of the channel state machine.
    #[mesh(1)]
    pub server: channels::SavedState,
    // Indicates whether the lost synic bug is fixed in the version that produced this state.
    // Defaults to false. During restore, if this field is not true, unstick_channels() is
    // called to mitigate the issue.
    #[mesh(2)]
    pub lost_synic_bug_fixed: bool,
}
289
// Connection ID of the standard message port, used by a VTL0 server without control plane
// redirection.
const MESSAGE_CONNECTION_ID: u32 = 1;
// Connection ID of the multiclient message port. It is also the base value of the
// server-specific connection IDs built by `get_child_message_connection_id`.
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
292
293impl<T: SpawnDriver + Clone> VmbusServerBuilder<T> {
294    /// Creates a new builder for `VmbusServer` with the default options.
295    pub fn new(spawner: T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
296        Self {
297            spawner,
298            synic,
299            gm,
300            private_gm: None,
301            vtl: Vtl::Vtl0,
302            hvsock_notify: None,
303            server_relay: None,
304            saved_state_notify: None,
305            external_server: None,
306            external_requests: None,
307            use_message_redirect: false,
308            channel_id_offset: 0,
309            max_version: None,
310            delay_max_version: false,
311            max_restore_version: None,
312            enable_mnf: false,
313            force_confidential_external_memory: false,
314            send_messages_while_stopped: false,
315            channel_unstick_delay: Some(Duration::from_millis(100)),
316            use_absolute_channel_order: false,
317        }
318    }
319
320    /// Sets a separate guest memory instance to use for channels that are confidential (non-relay
321    /// channels in Underhill on a hardware isolated VM). This is not relevant for a non-Underhill
322    /// VmBus server.
323    pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
324        self.private_gm = private_gm;
325        self
326    }
327
328    /// Sets the VTL that this instance will serve.
329    pub fn vtl(mut self, vtl: Vtl) -> Self {
330        self.vtl = vtl;
331        self
332    }
333
334    /// Sets a send/receive pair used to handle hvsocket requests.
335    pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
336        self.hvsock_notify = hvsock_notify;
337        self
338    }
339
340    /// Sets a send channel used to enlighten ProxyIntegration about saved channels.
341    pub fn saved_state_notify(
342        mut self,
343        saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
344    ) -> Self {
345        self.saved_state_notify = saved_state_notify;
346        self
347    }
348
349    /// Sets a send/receive pair that will be notified of server requests. This is used by the
350    /// Underhill relay.
351    pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
352        self.server_relay = server_relay;
353        self
354    }
355
356    /// Sets a receiver that receives requests from another server.
357    pub fn external_requests(
358        mut self,
359        external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
360    ) -> Self {
361        self.external_requests = external_requests;
362        self
363    }
364
365    /// Sets a sender used to forward unhandled connect requests (which used a different VTL)
366    /// to another server.
367    pub fn external_server(
368        mut self,
369        external_server: Option<mesh::Sender<InitiateContactRequest>>,
370    ) -> Self {
371        self.external_server = external_server;
372        self
373    }
374
375    /// Sets a value which indicates whether the vmbus control plane is redirected to Underhill.
376    pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
377        self.use_message_redirect = use_message_redirect;
378        self
379    }
380
381    /// Tells the server to use an offset when generating channel IDs to void collisions with
382    /// another vmbus server.
383    ///
384    /// N.B. This should only be used by the Underhill vmbus server.
385    pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
386        self.channel_id_offset = if enable { 1024 } else { 0 };
387        self
388    }
389
390    /// Tells the server to limit the protocol version offered to the guest.
391    ///
392    /// N.B. This is used for testing older protocols without requiring a specific guest OS.
393    pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
394        self.max_version = max_version;
395        self
396    }
397
398    /// Delay limiting the maximum version until after the first `Unload` message.
399    ///
400    /// N.B. This is used to enable the use of versions older than `Version::Win10` with Uefi boot,
401    ///      since that's the oldest version the Uefi client supports.
402    pub fn delay_max_version(mut self, delay: bool) -> Self {
403        self.delay_max_version = delay;
404        self
405    }
406
407    /// Tells the server to limit the protocol version accepted when restoring from saved state.
408    ///
409    /// This is configured separately from [`Self::max_version`] so that the version limit enforced
410    /// during restore can differ from the limit used for new connections. This allows a feature to
411    /// be available for rollback from a future version that has enabled it by default.
412    pub fn max_restore_version(mut self, max_restore_version: Option<MaxVersionInfo>) -> Self {
413        self.max_restore_version = max_restore_version;
414        self
415    }
416
417    /// Enable MNF support in the server.
418    ///
419    /// N.B. Enabling this has no effect if the synic does not support mapping monitor pages.
420    pub fn enable_mnf(mut self, enable: bool) -> Self {
421        self.enable_mnf = enable;
422        self
423    }
424
425    /// Force all non-relay channels to use encrypted external memory. Used for testing purposes
426    /// only.
427    pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
428        self.force_confidential_external_memory = force;
429        self
430    }
431
432    /// Send messages to the partition even while stopped, which can cause
433    /// corrupted synic states across VM reset.
434    ///
435    /// This option is used to prevent messages from getting into the queue, for
436    /// saved state compatibility with release/2411. It can be removed once that
437    /// release is no longer supported.
438    pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
439        self.send_messages_while_stopped = send;
440        self
441    }
442
443    /// Sets the delay before unsticking a vmbus channel after it has been opened.
444    ///
445    /// This option provides a work around for guests that ignore interrupts before they receive the
446    /// OpenResult message, by triggering an interrupt after the channel has been opened.
447    ///
448    /// If not set, the default is 100ms. If set to `None`, no interrupt will be triggered.
449    pub fn channel_unstick_delay(mut self, delay: Option<Duration>) -> Self {
450        self.channel_unstick_delay = delay;
451        self
452    }
453
454    /// Sets whether the channel order value provided in an offer is the primary way of ordering
455    /// channels when assigning channel IDs, rather than the default behavior of ordering by
456    /// interface ID first.
457    pub fn use_absolute_channel_order(mut self, assign: bool) -> Self {
458        self.use_absolute_channel_order = assign;
459        self
460    }
461
    /// Creates a new instance of the server.
    ///
    /// This registers the synic message port(s), constructs the channel state machine with the
    /// configured options, and spawns the "vmbus server" task.
    ///
    /// When the object is dropped, all channels will be closed and revoked
    /// automatically.
    ///
    /// # Errors
    ///
    /// Returns an error if a synic message port or the guest message port cannot be created.
    pub fn build(self) -> anyhow::Result<VmbusServer> {
        #[expect(clippy::disallowed_methods)] // TODO
        let (message_send, message_recv) = mpsc::channel(64);
        let message_sender = Arc::new(MessageSender {
            send: message_send.clone(),
            multiclient: self.use_message_redirect,
        });

        // With control plane redirection, messages use the redirect VTL and SINT instead of the
        // server's own VTL and the standard vmbus SINT.
        let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
            (REDIRECT_VTL, REDIRECT_SINT)
        } else {
            (self.vtl, VMBUS_SINT)
        };

        // Use the standard connection ID only for a VTL0 server without control plane
        // redirection. Any other server uses a server-specific connection ID derived from its
        // SINT and VTL.
        let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            MESSAGE_CONNECTION_ID
        } else {
            // TODO: This ID should be using the correct target VP, but that is not known until
            //       InitiateContact.
            VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
        };

        let _message_port = self
            .synic
            .add_message_port(connection_id, redirect_vtl, message_sender)
            .context("failed to create vmbus synic ports")?;

        // If this server is for VTL0, it is also responsible for the multiclient message port.
        // N.B. If control plane redirection is enabled, the redirected message port is used for
        //      multiclient and no separate multiclient port is created.
        let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            let multiclient_message_sender = Arc::new(MessageSender {
                send: message_send,
                multiclient: true,
            });

            Some(
                self.synic
                    .add_message_port(
                        MULTICLIENT_MESSAGE_CONNECTION_ID,
                        self.vtl,
                        multiclient_message_sender,
                    )
                    .context("failed to create vmbus synic ports")?,
            )
        } else {
            None
        };

        // The control object shares the sending side of the offer channel; the server task
        // owns the receiving side.
        let (offer_send, offer_recv) = mesh::mpsc_channel();
        let control = Arc::new(VmbusServerControl {
            mem: self.gm.clone(),
            private_mem: self.private_gm.clone(),
            send: offer_send,
            use_event: self.synic.prefer_os_events(),
            force_confidential_external_memory: self.force_confidential_external_memory,
        });

        let mut server = channels::Server::new(
            self.vtl,
            connection_id,
            self.channel_id_offset,
            self.use_absolute_channel_order,
        );

        // If MNF is handled by this server and this is a paravisor for an isolated VM, the monitor
        // pages must be allocated by the server, not the guest, since the guest will provide shared
        // pages which can't be used in this case. If the guest doesn't support server-specified
        // monitor pages, MNF will be disabled for all channels for that connection.
        server.set_require_server_allocated_mnf(self.enable_mnf && self.private_gm.is_some());

        // If requested, limit the maximum protocol version and feature flags.
        if let Some(version) = self.max_version {
            server.set_compatibility_version(version, self.delay_max_version);
        }

        if let Some(version) = self.max_restore_version {
            server.set_restore_compatibility_version(version);
        }

        // If no relay was specified, use a default one that reports success for every request.
        let (relay_request_send, relay_response_recv) =
            if let Some(server_relay) = self.server_relay {
                let r = server_relay.response_receive.boxed().fuse();
                (server_relay.request_send, r)
            } else {
                let (req_send, req_recv) = mesh::channel();
                let resp_recv = req_recv
                    .map(|req: ModifyRelayRequest| {
                        // Map to the correct response type for the request.
                        if req.version.is_some() {
                            ModifyRelayResponse::Supported(
                                protocol::ConnectionState::SUCCESSFUL,
                                protocol::FeatureFlags::from_bits(u32::MAX),
                            )
                        } else {
                            ModifyRelayResponse::Modified(protocol::ConnectionState::SUCCESSFUL)
                        }
                    })
                    .boxed()
                    .fuse();
                (req_send, resp_recv)
            };

        // If no hvsock notifier was specified, use a default one that always sends an error response.
        let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
            let r = hvsock_notify.response_receive.boxed().fuse();
            (hvsock_notify.request_send, r)
        } else {
            let (req_send, req_recv) = mesh::channel();
            let resp_recv = req_recv
                .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
                .boxed()
                .fuse();
            (req_send, resp_recv)
        };

        // The server starts out stopped, with no channels and no outstanding requests.
        let inner = ServerTaskInner {
            running: false,
            send_messages_while_stopped: self.send_messages_while_stopped,
            gm: self.gm,
            private_gm: self.private_gm,
            vtl: self.vtl,
            redirect_vtl,
            redirect_sint,
            message_port: self
                .synic
                .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
            synic: self.synic,
            hvsock_requests: 0,
            hvsock_send,
            saved_state_notify: self.saved_state_notify,
            channels: HashMap::new(),
            channel_responses: FuturesUnordered::new(),
            relay_send: relay_request_send,
            external_server_send: self.external_server,
            channel_bitmap: None,
            shared_event_port: None,
            reset_done: Vec::new(),
            mnf_support: self.enable_mnf.then(MnfSupport::default),
        };

        let (task_send, task_recv) = mesh::channel();
        let mut server_task = ServerTask {
            driver: Box::new(self.spawner.clone()),
            server,
            task_recv,
            offer_recv,
            message_recv,
            server_request_recv: SelectAll::new(),
            inner,
            external_requests: self.external_requests,
            next_seq: 0,
            perform_post_restore_on_start: false,
            unstick_on_start: false,
            channel_unstickers: FuturesUnordered::new(),
            channel_unstick_delay: self.channel_unstick_delay,
        };

        // The task hands the `ServerTask` back when its run loop completes.
        let task = self.spawner.spawn("vmbus server", async move {
            server_task.run(relay_response_recv, hvsock_recv).await;
            server_task
        });

        Ok(VmbusServer {
            task_send,
            control,
            _message_port,
            _multiclient_message_port,
            task,
        })
    }
639}
640
641impl VmbusServer {
    /// Creates a new builder for `VmbusServer` with the default options.
    ///
    /// This is a convenience wrapper around `VmbusServerBuilder::new`; `spawner` is used to
    /// spawn the server task, `synic` to create the synic ports, and `gm` as the guest memory.
    pub fn builder<T: SpawnDriver + Clone>(
        spawner: T,
        synic: Arc<dyn SynicPortAccess>,
        gm: GuestMemory,
    ) -> VmbusServerBuilder<T> {
        VmbusServerBuilder::new(spawner, synic, gm)
    }
650
651    pub async fn save(&self) -> SavedState {
652        self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
653    }
654
655    pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
656        self.task_send
657            .call(VmbusRequest::Restore, Box::new(state))
658            .await
659            .unwrap()
660    }
661
662    /// Stop the control plane.
663    pub async fn stop(&self) {
664        self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
665    }
666
    /// Starts the control plane.
    ///
    /// The start request is only queued to the server task; this does not wait for the task to
    /// process it.
    pub fn start(&self) {
        self.task_send.send(VmbusRequest::Start);
    }
671
672    /// Resets the vmbus channel state.
673    pub async fn reset(&self) {
674        tracing::debug!("resetting channel state");
675        self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
676    }
677
678    /// Tears down the vmbus control plane.
679    pub async fn shutdown(self) {
680        drop(self.task_send);
681        let _ = self.task.await;
682    }
683
684    /// Returns an object that can be used to offer channels.
685    pub fn control(&self) -> Arc<VmbusServerControl> {
686        self.control.clone()
687    }
688
689    /// Returns the message connection ID to use for a communication from the guest for servers
690    /// that use a non-standard SINT or VTL.
691    fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
692        MULTICLIENT_MESSAGE_CONNECTION_ID
693            | (vtl as u32) << 22
694            | vp_index << 8
695            | (sint_index as u32) << 4
696    }
697
698    fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
699        EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
700    }
701}
702
// NOTE(review): presumably the per-channel information handed back when a channel is restored;
// the code that builds and consumes this is not visible in this part of the file.
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    // Open data for the channel; presumably `None` if the channel was not open (confirm).
    open_data: Option<OpenData>,
    // One entry per GPADL: the GPADL ID, a `u16` (presumably the buffer count; confirm), and a
    // list of `u64` values (presumably the GPADL's page ranges; confirm).
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    // Interrupt associated with the channel, if any.
    interrupt: Option<Interrupt>,
}
709
/// A message delivered to the server task through its message channel.
#[derive(Default)]
pub struct SynicMessage {
    // The raw message payload.
    data: Vec<u8>,
    // Presumably mirrors the `multiclient` flag of the `MessageSender` the message arrived on;
    // confirm against `MessageSender`.
    multiclient: bool,
    // NOTE(review): the meaning of `trusted` is not established in this part of the file.
    trusted: bool,
}
716
/// Information used by a server that supports MNF.
///
/// Created with its default value (no allocated page) when MNF is enabled in the builder.
#[derive(Default)]
struct MnfSupport {
    // Presumably the monitor pages allocated by the server rather than the guest, if any;
    // confirm at the use site.
    allocated_monitor_page: Option<MonitorPageGpas>,
}
722
/// Disambiguates offer instances that may have reused the same offer ID.
#[derive(Debug, Clone, Copy)]
struct OfferInstanceId {
    /// The offer ID assigned by the server when the channel was offered.
    offer_id: OfferId,
    /// The sequence number taken from `ServerTask::next_seq` when the channel was offered.
    seq: u64,
}
729
/// The state owned by the spawned "vmbus server" task.
struct ServerTask {
    /// A clone of the driver that was used to spawn the task.
    driver: Box<dyn Driver>,
    /// The channel state machine.
    server: channels::Server,
    /// Requests from the `VmbusServer` handle.
    task_recv: mesh::Receiver<VmbusRequest>,
    /// Offer requests sent through the `VmbusServerControl`.
    offer_recv: mesh::Receiver<OfferRequest>,
    /// Messages sent by the synic message ports' `MessageSender`s.
    message_recv: mpsc::Receiver<SynicMessage>,
    /// Per-offer streams of `ChannelServerRequest`s, tagged with the offer instance they belong
    /// to.
    server_request_recv:
        SelectAll<TaggedStream<OfferInstanceId, mesh::Receiver<ChannelServerRequest>>>,
    /// State that is passed to the channel state machine as its notifier.
    inner: ServerTaskInner,
    /// Requests received from another server, if configured.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Next value for [`Channel::seq`].
    next_seq: u64,
    // Starts out false; presumably set by restore so that post-restore work runs on the next
    // start (confirm against the start/restore handlers).
    perform_post_restore_on_start: bool,
    // Starts out false; presumably requests unsticking channels on the next start (confirm).
    unstick_on_start: bool,
    /// Pending futures that each resolve to the offer instance to unstick.
    channel_unstickers: FuturesUnordered<Pin<Box<dyn Send + Future<Output = OfferInstanceId>>>>,
    /// The configured delay before unsticking an opened channel; `None` disables it.
    channel_unstick_delay: Option<Duration>,
}
747
/// The part of the server task state that is handed to the channel state machine via
/// `with_notifier`.
struct ServerTaskInner {
    /// Whether the server is running. A newly built server is not running.
    running: bool,
    /// Whether messages are sent to the partition even while stopped.
    send_messages_while_stopped: bool,
    /// Guest memory.
    gm: GuestMemory,
    /// Separate guest memory for confidential channels, if configured.
    private_gm: Option<GuestMemory>,
    /// Synic access used to create ports.
    synic: Arc<dyn SynicPortAccess>,
    /// The VTL this server serves.
    vtl: Vtl,
    /// The VTL used for messages: `REDIRECT_VTL` with control plane redirection, otherwise the
    /// same as `vtl`.
    redirect_vtl: Vtl,
    /// The SINT used for messages: `REDIRECT_SINT` with control plane redirection, otherwise
    /// `VMBUS_SINT`.
    redirect_sint: u8,
    /// Guest message port created for `redirect_vtl` and `redirect_sint`, initially targeting
    /// VP 0.
    message_port: Box<dyn GuestMessagePort>,
    // Starts at 0; presumably the number of outstanding hvsock requests, bounded by
    // `MAX_CONCURRENT_HVSOCK_REQUESTS` (confirm at the use site).
    hvsock_requests: usize,
    /// Sender for hvsock connect requests.
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    /// Sender used to notify the proxy of saved state, if configured.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// The offered channels, keyed by offer ID.
    channels: HashMap<OfferId, Channel>,
    /// Pending channel requests. Each future yields an offer ID, a `u64` (presumably the
    /// channel's sequence number; confirm), and the response or RPC error.
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    /// Sender used to forward connect requests to another server, if configured.
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    /// Sender for connection modification requests to the relay (or to the default responder
    /// when no relay is configured).
    relay_send: mesh::Sender<ModifyRelayRequest>,
    // Initially `None`; presumably set when the guest uses an interrupt page (confirm).
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    // Initially `None`; presumably holds the shared event port while one is registered
    // (confirm).
    shared_event_port: Option<Box<dyn Send>>,
    // Initially empty; presumably the reset RPCs to complete once a reset finishes (confirm).
    reset_done: Vec<Rpc<(), ()>>,
    /// Stores information needed to support MNF. If `None`, this server doesn't support MNF (in
    /// the case of OpenHCL, that means it will be handled by the relay host).
    mnf_support: Option<MnfSupport>,
}
774
/// The result of a request that was sent to a channel, as collected in
/// `ServerTaskInner::channel_responses`.
// NOTE(review): the payload meanings below are inferred from the variant names and types;
// confirm against the code that produces and consumes these responses.
#[derive(Debug)]
enum ChannelResponse {
    // Response to an open request; the flag presumably indicates success.
    Open(bool),
    // Response to a close request.
    Close,
    // Response to a GPADL request for the given GPADL; the flag presumably indicates success.
    Gpadl(GpadlId, bool),
    // Response to a GPADL teardown request for the given GPADL.
    TeardownGpadl(GpadlId),
    // Response to a modify request; the value is presumably a status code.
    Modify(i32),
}
783
/// Tracks whether an unstick operation is pending for a channel. A newly offered channel
/// starts in `None`.
// NOTE(review): the transitions between these states are not visible in this part of the file;
// the descriptions below are inferred from the variant names.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ChannelUnstickState {
    // No unstick operation is pending.
    None,
    // Presumably: an unstick operation has been queued.
    Queued,
    // Presumably: the queued unstick operation must be queued again.
    NeedsRequeue,
}
790
/// The server task's bookkeeping for a single offered channel.
struct Channel {
    /// The key identifying the offer, taken from the offer parameters.
    key: OfferKey,
    /// Sender used to issue requests to the channel; taken from `OfferInfo::request_send`.
    send: mesh::Sender<ChannelRequest>,
    /// Sequence number assigned when the channel was offered.
    seq: u64,
    /// Whether the channel is closed, open, or closing. A newly offered channel is closed.
    state: ChannelState,
    /// The channel's GPADLs.
    gpadls: Arc<GpadlMap>,
    /// Event port wrapper around the interrupt supplied with the offer.
    guest_to_host_event: Arc<ChannelEvent>,
    /// The offer flags, taken from the offer parameters.
    flags: protocol::OfferFlags,
    // A channel can be reserved no matter what state it is in. This allows the message port for a
    // reserved channel to remain available even if the channel is closed, so the guest can read the
    // close reserved channel response. The reserved state is cleared when the channel is revoked,
    // reopened, or the guest sends an unload message.
    reserved_state: ReservedState,
    /// Whether an unstick operation is pending for this channel.
    unstick_state: ChannelUnstickState,
}
806
/// The reserved-channel state of a `Channel`. A newly offered channel has no message port and
/// a target of VP 0, SINT 0.
struct ReservedState {
    // The message port of the reserved channel; `None` when no port is held (presumably when
    // the channel is not reserved; confirm).
    message_port: Option<Box<dyn GuestMessagePort>>,
    // The VP and SINT that messages for the reserved channel are sent to.
    target: ConnectionTarget,
}
811
/// The state held for a channel while it is open.
struct ChannelOpenState {
    // The parameters the channel was opened with.
    open_params: OpenParams,
    // Held only to keep it alive while the channel is open (presumably the registered event
    // port for guest-to-host signals; confirm at the open path).
    _event_port: Box<dyn Send>,
    // The port used to signal the guest. `None` means the channel interrupt is disabled, in
    // which case the target VP cannot be changed.
    guest_event_port: Option<Box<dyn GuestEventPort>>,
    // The interrupt used for host-to-guest signals.
    host_to_guest_interrupt: Interrupt,
}
818
819impl ChannelOpenState {
820    fn set_event_port_target_vp(&mut self, vp: u32) -> anyhow::Result<()> {
821        let Some(guest_event_port) = self.guest_event_port.as_mut() else {
822            anyhow::bail!("cannot set target VP if the channel interrupt is disabled");
823        };
824
825        guest_event_port.set_target_vp(vp)?;
826        Ok(())
827    }
828}
829
/// The open/close state of a `Channel`.
enum ChannelState {
    // The channel is not open. This is the state of a newly offered channel.
    Closed,
    // The channel is open; the boxed state holds the open parameters and event ports.
    Open(Box<ChannelOpenState>),
    // Presumably: a close has been requested but has not completed yet (confirm at the close
    // path).
    Closing,
}
835
836impl ServerTask {
837    fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
838        let key = info.params.key();
839        let flags = info.params.flags;
840
841        if self.inner.mnf_support.is_some() && self.inner.synic.monitor_support().is_some() {
842            // If this server is handling MnF, ignore any relayed monitor IDs but still enable MnF
843            // for those channels.
844            // N.B. Since this can only happen in OpenHCL, which emulates MnF, the latency is
845            //      ignored.
846            if info.params.use_mnf.is_relayed() {
847                info.params.use_mnf = MnfUsage::Enabled {
848                    latency: Duration::ZERO,
849                }
850            }
851        } else if info.params.use_mnf.is_enabled() {
852            // If the server is not handling MnF, disable it for the channel. This does not affect
853            // channels with a relayed monitor ID.
854            info.params.use_mnf = MnfUsage::Disabled;
855        }
856
857        let offer_id = self
858            .server
859            .with_notifier(&mut self.inner)
860            .offer_channel(info.params)
861            .context("channel offer failed")?;
862
863        tracing::debug!(?offer_id, %key, "offered channel");
864
865        let seq = self.next_seq;
866        self.next_seq += 1;
867        self.inner.channels.insert(
868            offer_id,
869            Channel {
870                key,
871                send: info.request_send,
872                state: ChannelState::Closed,
873                gpadls: GpadlMap::new(),
874                guest_to_host_event: Arc::new(ChannelEvent(info.event)),
875                seq,
876                flags,
877                reserved_state: ReservedState {
878                    message_port: None,
879                    target: ConnectionTarget { vp: 0, sint: 0 },
880                },
881                unstick_state: ChannelUnstickState::None,
882            },
883        );
884
885        self.server_request_recv.push(TaggedStream::new(
886            OfferInstanceId { offer_id, seq },
887            info.server_request_recv,
888        ));
889
890        Ok(())
891    }
892
893    fn handle_revoke(&mut self, id: OfferInstanceId) {
894        // The channel may or may not exist in the map depending on whether it's been explicitly
895        // revoked before being dropped.
896        if let Some(channel) = self.inner.channels.get(&id.offer_id) {
897            if channel.seq == id.seq {
898                tracing::info!(?id.offer_id, key = %channel.key, "revoking channel");
899                self.inner.channels.remove(&id.offer_id);
900                self.server
901                    .with_notifier(&mut self.inner)
902                    .revoke_channel(id.offer_id);
903            }
904        }
905    }
906
907    fn handle_response(
908        &mut self,
909        offer_id: OfferId,
910        seq: u64,
911        response: Result<ChannelResponse, RpcError>,
912    ) {
913        // Validate the sequence to ensure the response is not for a revoked channel.
914        let channel = self
915            .inner
916            .channels
917            .get(&offer_id)
918            .filter(|channel| channel.seq == seq);
919
920        if let Some(channel) = channel {
921            match response {
922                Ok(response) => match response {
923                    ChannelResponse::Open(result) => self.handle_open(offer_id, result),
924                    ChannelResponse::Close => self.handle_close(offer_id),
925                    ChannelResponse::Gpadl(gpadl_id, ok) => {
926                        self.handle_gpadl_create(offer_id, gpadl_id, ok)
927                    }
928                    ChannelResponse::TeardownGpadl(gpadl_id) => {
929                        self.handle_gpadl_teardown(offer_id, gpadl_id)
930                    }
931                    ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
932                },
933                Err(err) => {
934                    tracing::error!(
935                        key = %channel.key,
936                        error = &err as &dyn std::error::Error,
937                        "channel response failure, channel is in inconsistent state until revoked"
938                    );
939                }
940            }
941        } else {
942            tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
943        }
944    }
945
946    fn handle_open(&mut self, offer_id: OfferId, success: bool) {
947        let status = if success {
948            let channel = self
949                .inner
950                .channels
951                .get_mut(&offer_id)
952                .expect("channel exists");
953
954            // Some guests ignore interrupts before they receive the OpenResult message. To avoid
955            // a potential hang, signal the channel after a delay if needed.
956            if let Some(delay) = self.channel_unstick_delay {
957                if channel.unstick_state == ChannelUnstickState::None {
958                    channel.unstick_state = ChannelUnstickState::Queued;
959                    let seq = channel.seq;
960                    let mut timer = PolledTimer::new(&self.driver);
961                    self.channel_unstickers.push(Box::pin(async move {
962                        timer.sleep(delay).await;
963                        OfferInstanceId { offer_id, seq }
964                    }));
965                } else {
966                    channel.unstick_state = ChannelUnstickState::NeedsRequeue;
967                }
968            }
969
970            0
971        } else {
972            protocol::STATUS_UNSUCCESSFUL
973        };
974
975        self.server
976            .with_notifier(&mut self.inner)
977            .open_complete(offer_id, status);
978    }
979
980    fn handle_close(&mut self, offer_id: OfferId) {
981        let channel = self
982            .inner
983            .channels
984            .get_mut(&offer_id)
985            .expect("channel still exists");
986
987        match &mut channel.state {
988            ChannelState::Closing => {
989                tracing::debug!(?offer_id, key = %channel.key, "closing channel");
990                channel.state = ChannelState::Closed;
991                self.server
992                    .with_notifier(&mut self.inner)
993                    .close_complete(offer_id);
994            }
995            _ => {
996                tracing::error!(?offer_id, key = %channel.key, "invalid close channel response");
997            }
998        };
999    }
1000
1001    fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
1002        let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
1003        self.server
1004            .with_notifier(&mut self.inner)
1005            .gpadl_create_complete(offer_id, gpadl_id, status);
1006    }
1007
    /// Reports to the server that the channel has finished tearing down the given GPADL.
    fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
        self.server
            .with_notifier(&mut self.inner)
            .gpadl_teardown_complete(offer_id, gpadl_id);
    }
1013
    /// Reports to the server that the channel has finished a modify request, passing along the
    /// status code the channel returned.
    fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
        self.server
            .with_notifier(&mut self.inner)
            .modify_channel_complete(offer_id, status);
    }
1019
1020    fn handle_restore_channel(
1021        &mut self,
1022        offer_id: OfferId,
1023        open: bool,
1024    ) -> anyhow::Result<RestoreResult> {
1025        let gpadls = self.server.channel_gpadls(offer_id);
1026
1027        // If the channel is opened, handle that before calling into channels so that failure can
1028        // be handled before the channel is marked restored.
1029        let open_request = open
1030            .then(|| -> anyhow::Result<_> {
1031                let params = self.server.get_restore_open_params(offer_id)?;
1032                let (channel, interrupt) = self.inner.open_channel(offer_id, &params)?;
1033                Ok(OpenRequest::new(
1034                    params.open_data,
1035                    interrupt,
1036                    self.server
1037                        .get_version()
1038                        .expect("must be connected")
1039                        .feature_flags,
1040                    channel.flags,
1041                ))
1042            })
1043            .transpose()?;
1044
1045        self.server
1046            .with_notifier(&mut self.inner)
1047            .restore_channel(offer_id, open_request.is_some())?;
1048
1049        let channel = self.inner.channels.get_mut(&offer_id).unwrap();
1050        for gpadl in &gpadls {
1051            if let Ok(buf) = MultiPagedRangeBuf::from_range_buffer(
1052                gpadl.request.count.into(),
1053                gpadl.request.buf.clone(),
1054            ) {
1055                channel.gpadls.add(gpadl.request.id, buf);
1056            }
1057        }
1058
1059        let result = RestoreResult {
1060            open_request,
1061            gpadls,
1062        };
1063        Ok(result)
1064    }
1065
    /// Handles a control request sent to the server task.
    ///
    /// - `Reset`: forwarded to `handle_reset`.
    /// - `Inspect`: reports task state; writing to the `unstick_channels` field triggers
    ///   `unstick_channels`.
    /// - `Save`: returns the server's saved state.
    /// - `Restore`: restores the server from saved state, first forwarding that state to the
    ///   saved state notifier if one is configured.
    /// - `Stop`: marks the server as not running.
    /// - `Start`: marks the server as running and performs any post-restore work that was
    ///   deferred until start.
    async fn handle_request(&mut self, request: VmbusRequest) {
        tracing::debug!(?request, "handle_request");
        match request {
            VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
            VmbusRequest::Inspect(deferred) => {
                deferred.respond(|resp| {
                    resp.field("message_port", &self.inner.message_port)
                        .field("running", self.inner.running)
                        .field("hvsock_requests", self.inner.hvsock_requests)
                        .field("channel_unstick_delay", self.channel_unstick_delay)
                        .field_mut_with("unstick_channels", |v| {
                            // Accepted written values are "force" (always wake), or a boolean,
                            // where `true` wakes only channels that look like they need it.
                            // Without a written value the field reads back as `false`.
                            let v: inspect::ValueKind = if let Some(v) = v {
                                if v == "force" {
                                    self.unstick_channels(true);
                                    v.into()
                                } else {
                                    let v =
                                        v.parse().ok().context("expected false, true, or force")?;
                                    if v {
                                        self.unstick_channels(false);
                                    }
                                    v.into()
                                }
                            } else {
                                false.into()
                            };
                            anyhow::Ok(v)
                        })
                        .merge(&self.server.with_notifier(&mut self.inner));
                });
            }
            VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
                server: self.server.save(),
                lost_synic_bug_fixed: true,
            }),
            VmbusRequest::Restore(rpc) => {
                rpc.handle(async |state| {
                    // Saved states without the fix flag get their channels unstuck on start.
                    self.unstick_on_start = !state.lost_synic_bug_fixed;
                    self.perform_post_restore_on_start = true;
                    if let Some(sender) = &self.inner.saved_state_notify {
                        tracing::trace!("sending saved state to proxy");
                        if let Err(err) = sender
                            .call_failable(SavedStateRequest::Set, Box::new(state.server.clone()))
                            .await
                        {
                            tracing::error!(
                                err = &err as &dyn std::error::Error,
                                "failed to restore proxy saved state"
                            );
                            return Err(RestoreError::ServerError(err.into()));
                        }
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .restore(state.server)
                })
                .await
            }
            VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
                if self.inner.running {
                    self.inner.running = false;
                }
            }),
            VmbusRequest::Start => {
                // Starting an already running server is a no-op.
                if !self.inner.running {
                    self.inner.running = true;
                    if self.perform_post_restore_on_start {
                        if let Some(sender) = self.inner.saved_state_notify.as_ref() {
                            // Indicate to the proxy that the server is starting and that it should
                            // clear its saved state cache.
                            tracing::trace!("sending clear saved state message to proxy");
                            sender
                                .call(SavedStateRequest::Clear, ())
                                .await
                                .expect("failed to clear proxy saved state");
                        }

                        self.server
                            .with_notifier(&mut self.inner)
                            .revoke_unclaimed_channels();

                        self.perform_post_restore_on_start = false;
                    }

                    if self.unstick_on_start {
                        tracing::info!(
                            "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
                        );
                        self.unstick_channels(false);
                        self.unstick_on_start = false;
                    }
                }
            }
        }
    }
1162
1163    fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
1164        let needs_reset = self.inner.reset_done.is_empty();
1165        self.inner.reset_done.push(rpc);
1166        if needs_reset {
1167            self.server.with_notifier(&mut self.inner).reset();
1168        }
1169    }
1170
1171    fn handle_relay_response(&mut self, response: ModifyRelayResponse) {
1172        // Convert to a matching ModifyConnectionResponse.
1173        let response = match response {
1174            ModifyRelayResponse::Supported(state, features) => {
1175                // Provide the server-allocated monitor page to the server only if they're actually being
1176                // used (if not, they may still be allocated from a previous connection).
1177                let allocated_monitor_gpas = self
1178                    .inner
1179                    .mnf_support
1180                    .as_ref()
1181                    .and_then(|mnf| mnf.allocated_monitor_page);
1182
1183                ModifyConnectionResponse::Supported(state, features, allocated_monitor_gpas)
1184            }
1185            ModifyRelayResponse::Unsupported => ModifyConnectionResponse::Unsupported,
1186            ModifyRelayResponse::Modified(state) => ModifyConnectionResponse::Modified(state),
1187        };
1188
1189        self.server
1190            .with_notifier(&mut self.inner)
1191            .complete_modify_connection(response);
1192    }
1193
    /// Handles the result of an hvsock connect request by releasing its slot in the outstanding
    /// request count and forwarding the result to the server.
    ///
    /// # Panics
    ///
    /// Panics if no hvsock requests are recorded as outstanding.
    fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
        tracing::debug!(?result, "hvsock connect result");
        // Every result must correspond to a request that was counted earlier.
        assert_ne!(self.inner.hvsock_requests, 0);
        self.inner.hvsock_requests -= 1;

        self.server
            .with_notifier(&mut self.inner)
            .send_tl_connect_result(result);
    }
1203
1204    fn handle_synic_message(&mut self, message: SynicMessage) {
1205        match self
1206            .server
1207            .with_notifier(&mut self.inner)
1208            .handle_synic_message(message)
1209        {
1210            Ok(()) => {}
1211            Err(err) => {
1212                tracing::warn!(
1213                    error = &err as &dyn std::error::Error,
1214                    "synic message error"
1215                );
1216            }
1217        }
1218    }
1219
    /// Handles a request forwarded by a different vmbus server. This is used to forward requests
    /// for different VTLs to different servers.
    ///
    /// The forwarded request is an initiate contact request, which is passed to this server
    /// unchanged.
    ///
    /// N.B. This uses the same mechanism as the HCL server relay, so all requests, even the ones
    ///      meant for the primary server, are forwarded. In that case the primary server depends
    ///      on this server to send back a response so it can continue handling it.
    fn handle_external_request(&mut self, request: InitiateContactRequest) {
        self.server
            .with_notifier(&mut self.inner)
            .initiate_contact(request);
    }
1231
    /// Runs the server task's main loop.
    ///
    /// Each iteration waits for one event from the task's input sources and handles it. Sources
    /// that should only be serviced while the VM is running (and no reset is in progress) are
    /// wrapped in an `OptionFuture` that is empty otherwise, so their events stay queued.
    ///
    /// The loop ends when the task request channel fails to receive, or when the `select!`
    /// reports that every source has completed.
    async fn run(
        &mut self,
        mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyRelayResponse> + Unpin,
        mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
    ) {
        loop {
            // Create an OptionFuture for each event that should only be handled
            // while the VM is running. In other cases, leave the events in
            // their respective queues.

            let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
            let mut external_requests = OptionFuture::from(
                running_not_resetting
                    .then(|| {
                        self.external_requests
                            .as_mut()
                            .map(|r| r.select_next_some())
                    })
                    .flatten(),
            );

            // Try to send any pending messages while the VM is running.
            let has_pending_messages = self.server.has_pending_messages();
            let message_port = self.inner.message_port.as_mut();
            let mut flush_pending_messages =
                OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
                    poll_fn(|cx| {
                        self.server.poll_flush_pending_messages(|msg| {
                            message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
                        })
                    })
                    .fuse()
                }));

            // Only handle new incoming messages if there are no outgoing messages pending, and not
            // too many hvsock requests outstanding. This puts a bound on the resources used by the
            // guest.
            let mut message_recv = OptionFuture::from(
                (running_not_resetting
                    && !has_pending_messages
                    && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
                    .then(|| self.message_recv.select_next_some()),
            );

            // Accept channel responses until stopped or when resetting.
            let mut channel_response = OptionFuture::from(
                (self.inner.running || !self.inner.reset_done.is_empty())
                    .then(|| self.inner.channel_responses.select_next_some()),
            );

            // Accept hvsock connect responses while the VM is running.
            let mut hvsock_response =
                OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));

            // Handle expired unstick timers while the VM is running.
            let mut channel_unstickers = OptionFuture::from(
                running_not_resetting.then(|| self.channel_unstickers.select_next_some()),
            );

            futures::select! { // merge semantics
                r = self.task_recv.recv().fuse() => {
                    if let Ok(request) = r {
                        self.handle_request(request).await;
                    } else {
                        // The request channel can no longer receive, so stop the task.
                        break;
                    }
                }
                r = self.offer_recv.select_next_some() => {
                    match r {
                        OfferRequest::Offer(rpc) => {
                            rpc.handle_failable_sync(|request| { self.handle_offer(request) })
                        },
                        OfferRequest::ForceReset(rpc) => {
                            self.handle_reset(rpc);
                        }
                    }
                }
                r = self.server_request_recv.select_next_some() => {
                    match r {
                        (id, Some(request)) => match request {
                            ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
                                self.handle_restore_channel(id.offer_id, open)
                            }),
                            ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
                                self.handle_revoke(id);
                            })
                        },
                        // The channel's request stream ended, which is treated as a revoke.
                        (id, None) => self.handle_revoke(id),
                    }
                }
                r = channel_response => {
                    let (id, seq, response) = r.unwrap();
                    self.handle_response(id, seq, response);
                }
                r = relay_response_recv.select_next_some() => {
                    self.handle_relay_response(r);
                },
                r = hvsock_response => {
                    self.handle_tl_connect_result(r.unwrap());
                }
                data = message_recv => {
                    let data = data.unwrap();
                    self.handle_synic_message(data);
                }
                r = external_requests => {
                    let r = r.unwrap();
                    self.handle_external_request(r);
                }
                r = channel_unstickers => {
                    self.unstick_channel_by_id(r.unwrap());
                }
                _r = flush_pending_messages => {}
                complete => break,
            }
        }
    }
1347
1348    /// Wakes the guest and optionally the host for every open channel. If `force`, always wakes
1349    /// them. If `!force`, only wake for rings that are in the state where a notification is
1350    /// expected.
1351    fn unstick_channels(&self, force: bool) {
1352        let Some(version) = self.server.get_version() else {
1353            tracing::warn!("cannot unstick when not connected");
1354            return;
1355        };
1356
1357        for channel in self.inner.channels.values() {
1358            let gm = self.inner.get_gm_for_channel(version, channel);
1359            if let Err(err) = Self::unstick_channel(gm, channel, force, true) {
1360                tracing::warn!(
1361                    channel = %channel.key,
1362                    error = err.as_ref() as &dyn std::error::Error,
1363                    "could not unstick channel"
1364                );
1365            }
1366        }
1367    }
1368
1369    /// Wakes the guest for the specified channel if it's open and the rings are in a state where
1370    /// notification is expected.
1371    fn unstick_channel_by_id(&mut self, id: OfferInstanceId) {
1372        let Some(version) = self.server.get_version() else {
1373            tracelimit::warn_ratelimited!("cannot unstick when not connected");
1374            return;
1375        };
1376
1377        if let Some(channel) = self.inner.channels.get_mut(&id.offer_id) {
1378            if channel.seq != id.seq {
1379                // The channel was revoked.
1380                return;
1381            }
1382
1383            // The channel was closed and reopened before the delay expired, so wait again to ensure
1384            // we don't signal too early.
1385            if channel.unstick_state == ChannelUnstickState::NeedsRequeue {
1386                channel.unstick_state = ChannelUnstickState::Queued;
1387                let mut timer = PolledTimer::new(&self.driver);
1388                let delay = self.channel_unstick_delay.unwrap();
1389                self.channel_unstickers.push(Box::pin(async move {
1390                    timer.sleep(delay).await;
1391                    id
1392                }));
1393
1394                return;
1395            }
1396
1397            channel.unstick_state = ChannelUnstickState::None;
1398            let gm = select_gm_for_channel(
1399                &self.inner.gm,
1400                self.inner.private_gm.as_ref(),
1401                version,
1402                channel,
1403            );
1404            if let Err(err) = Self::unstick_channel(gm, channel, false, false) {
1405                tracelimit::warn_ratelimited!(
1406                    channel = %channel.key,
1407                    error = err.as_ref() as &dyn std::error::Error,
1408                    "could not unstick channel"
1409                );
1410            }
1411        }
1412    }
1413
1414    fn unstick_channel(
1415        gm: &GuestMemory,
1416        channel: &Channel,
1417        force: bool,
1418        unstick_host: bool,
1419    ) -> anyhow::Result<()> {
1420        if let ChannelState::Open(state) = &channel.state {
1421            if force {
1422                tracing::info!(channel = %channel.key, "waking host and guest");
1423                if unstick_host {
1424                    channel.guest_to_host_event.0.deliver();
1425                }
1426                state.host_to_guest_interrupt.deliver();
1427                return Ok(());
1428            }
1429
1430            let gpadl = channel
1431                .gpadls
1432                .clone()
1433                .view()
1434                .map(state.open_params.open_data.ring_gpadl_id)
1435                .context("couldn't find ring gpadl")?;
1436
1437            let aligned = AlignedGpadlView::new(gpadl)
1438                .ok()
1439                .context("ring not aligned")?;
1440            let (in_gpadl, out_gpadl) = aligned
1441                .split(state.open_params.open_data.ring_offset)
1442                .ok()
1443                .context("couldn't split ring")?;
1444
1445            if let Err(err) = Self::unstick_incoming_ring(
1446                gm,
1447                channel,
1448                in_gpadl,
1449                unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1450                &state.host_to_guest_interrupt,
1451            ) {
1452                tracelimit::warn_ratelimited!(
1453                    channel = %channel.key,
1454                    error = err.as_ref() as &dyn std::error::Error,
1455                    "could not unstick incoming ring"
1456                );
1457            }
1458            if let Err(err) = Self::unstick_outgoing_ring(
1459                gm,
1460                channel,
1461                out_gpadl,
1462                unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1463                &state.host_to_guest_interrupt,
1464            ) {
1465                tracelimit::warn_ratelimited!(
1466                    channel = %channel.key,
1467                    error = err.as_ref() as &dyn std::error::Error,
1468                    "could not unstick outgoing ring"
1469                );
1470            }
1471        }
1472        Ok(())
1473    }
1474
1475    fn unstick_incoming_ring(
1476        gm: &GuestMemory,
1477        channel: &Channel,
1478        in_gpadl: AlignedGpadlView,
1479        guest_to_host_event: Option<&ChannelEvent>,
1480        host_to_guest_interrupt: &Interrupt,
1481    ) -> anyhow::Result<()> {
1482        let control_page = lock_gpn_with_subrange(gm, in_gpadl.gpns()[0])?;
1483        if let Some(guest_to_host_event) = guest_to_host_event {
1484            if ring::reader_needs_signal(control_page.pages()[0]) {
1485                tracelimit::info_ratelimited!(channel = %channel.key, "waking host for incoming ring");
1486                guest_to_host_event.0.deliver();
1487            }
1488        }
1489
1490        let ring_size = gpadl_ring_size(&in_gpadl).try_into()?;
1491        if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1492            tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for incoming ring");
1493            host_to_guest_interrupt.deliver();
1494        }
1495        Ok(())
1496    }
1497
1498    fn unstick_outgoing_ring(
1499        gm: &GuestMemory,
1500        channel: &Channel,
1501        out_gpadl: AlignedGpadlView,
1502        guest_to_host_event: Option<&ChannelEvent>,
1503        host_to_guest_interrupt: &Interrupt,
1504    ) -> anyhow::Result<()> {
1505        let control_page = lock_gpn_with_subrange(gm, out_gpadl.gpns()[0])?;
1506        if ring::reader_needs_signal(control_page.pages()[0]) {
1507            tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for outgoing ring");
1508            host_to_guest_interrupt.deliver();
1509        }
1510
1511        if let Some(guest_to_host_event) = guest_to_host_event {
1512            let ring_size = gpadl_ring_size(&out_gpadl).try_into()?;
1513            if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1514                tracelimit::info_ratelimited!(channel = %channel.key, "waking host for outgoing ring");
1515                guest_to_host_event.0.deliver();
1516            }
1517        }
1518        Ok(())
1519    }
1520}
1521
1522impl Notifier for ServerTaskInner {
1523    fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
1524        let channel = self
1525            .channels
1526            .get_mut(&offer_id)
1527            .expect("channel does not exist");
1528
1529        fn handle<I: 'static + Send, R: 'static + Send>(
1530            offer_id: OfferId,
1531            channel: &Channel,
1532            req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
1533            input: I,
1534            f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
1535        ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
1536        {
1537            let recv = channel.send.call(req, input);
1538            let seq = channel.seq;
1539            Box::pin(async move {
1540                let r = recv.await.map(f);
1541                (offer_id, seq, r)
1542            })
1543        }
1544
1545        let response = match action {
1546            channels::Action::Open(open_params, version) => {
1547                let seq = channel.seq;
1548                let key = channel.key;
1549                match self.open_channel(offer_id, &open_params) {
1550                    Ok((channel, interrupt)) => handle(
1551                        offer_id,
1552                        channel,
1553                        ChannelRequest::Open,
1554                        OpenRequest::new(
1555                            open_params.open_data,
1556                            interrupt,
1557                            version.feature_flags,
1558                            channel.flags,
1559                        ),
1560                        ChannelResponse::Open,
1561                    ),
1562                    Err(err) => {
1563                        tracelimit::error_ratelimited!(
1564                            err = err.as_ref() as &dyn std::error::Error,
1565                            ?offer_id,
1566                            %key,
1567                            "could not open channel",
1568                        );
1569
1570                        // Return an error response to the channels module if the open_channel call
1571                        // failed.
1572                        Box::pin(future::ready((
1573                            offer_id,
1574                            seq,
1575                            Ok(ChannelResponse::Open(false)),
1576                        )))
1577                    }
1578                }
1579            }
1580            channels::Action::Close => {
1581                if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1582                    if let ChannelState::Open(ref state) = channel.state {
1583                        channel_bitmap.unregister_channel(state.open_params.event_flag);
1584                    }
1585                }
1586
1587                channel.state = ChannelState::Closing;
1588                handle(offer_id, channel, ChannelRequest::Close, (), |()| {
1589                    ChannelResponse::Close
1590                })
1591            }
1592            channels::Action::Gpadl(gpadl_id, count, buf) => {
1593                channel.gpadls.add(
1594                    gpadl_id,
1595                    MultiPagedRangeBuf::from_range_buffer(count.into(), buf.clone()).unwrap(),
1596                );
1597                handle(
1598                    offer_id,
1599                    channel,
1600                    ChannelRequest::Gpadl,
1601                    GpadlRequest {
1602                        id: gpadl_id,
1603                        count,
1604                        buf,
1605                    },
1606                    move |r| ChannelResponse::Gpadl(gpadl_id, r),
1607                )
1608            }
1609            channels::Action::TeardownGpadl {
1610                gpadl_id,
1611                post_restore,
1612            } => {
1613                if !post_restore {
1614                    channel.gpadls.remove(gpadl_id, Box::new(|| ()));
1615                }
1616
1617                handle(
1618                    offer_id,
1619                    channel,
1620                    ChannelRequest::TeardownGpadl,
1621                    gpadl_id,
1622                    move |()| ChannelResponse::TeardownGpadl(gpadl_id),
1623                )
1624            }
1625            channels::Action::Modify { target_vp } => {
1626                let ChannelState::Open(state) = &mut channel.state else {
1627                    unreachable!();
1628                };
1629
1630                if let Err(err) = state.set_event_port_target_vp(target_vp) {
1631                    tracelimit::error_ratelimited!(
1632                        error = err.as_ref() as &dyn std::error::Error,
1633                        channel = %channel.key,
1634                        "could not modify channel",
1635                    );
1636
1637                    // Send an immediate error response.
1638                    let seq = channel.seq;
1639                    Box::pin(async move {
1640                        (
1641                            offer_id,
1642                            seq,
1643                            Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
1644                        )
1645                    })
1646                } else {
1647                    handle(
1648                        offer_id,
1649                        channel,
1650                        ChannelRequest::Modify,
1651                        ModifyRequest::TargetVp { target_vp },
1652                        ChannelResponse::Modify,
1653                    )
1654                }
1655            }
1656        };
1657        self.channel_responses.push(response);
1658    }
1659
1660    fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
1661        self.map_interrupt_page(request.interrupt_page)
1662            .context("Failed to map interrupt page.")?;
1663
1664        self.set_monitor_page(&mut request)
1665            .context("Failed to map monitor page.")?;
1666
1667        if let Some(vp) = request.target_message_vp {
1668            self.message_port.set_target_vp(vp)?;
1669        }
1670
1671        if request.notify_relay {
1672            self.relay_send.send(request.into());
1673        }
1674
1675        Ok(())
1676    }
1677
1678    fn forward_unhandled(&mut self, request: InitiateContactRequest) {
1679        if let Some(external_server) = &self.external_server_send {
1680            external_server.send(request);
1681        } else {
1682            tracing::warn!(?request, "nowhere to forward unhandled request")
1683        }
1684    }
1685
1686    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
1687        let channel = self.channels.get(&offer_id).expect("should exist");
1688        let mut resp = req.respond();
1689        // Only inspect ring buffers if we have an active connection; during
1690        // disconnect/reset, `version` may be None even if an individual channel
1691        // is still in the Open state, and we must not panic during inspect
1692        // (e.g., timeout diagnostics rely on inspect succeeding).
1693        if let (ChannelState::Open(state), Some(version)) = (&channel.state, version) {
1694            let mem = self.get_gm_for_channel(version, channel);
1695            inspect_rings(
1696                &mut resp,
1697                mem,
1698                channel.gpadls.clone(),
1699                &state.open_params.open_data,
1700            );
1701        }
1702    }
1703
1704    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
1705        // If the server is paused, queue all messages, to avoid affecting synic
1706        // state during/after it has been saved or reset.
1707        //
1708        // Note that messages to reserved channels or custom targets will be
1709        // dropped. However, such messages should only be sent in response to
1710        // guest requests, which should not be processed while the server is
1711        // paused.
1712        //
1713        // FUTURE: it would be better to ensure that no messages are generated
1714        // by operations that run while the server is paused. E.g., defer
1715        // sending offer or revoke messages for new or revoked offers. This
1716        // would prevent the queue from growing without bound.
1717        if !self.running && !self.send_messages_while_stopped {
1718            if !matches!(target, MessageTarget::Default) {
1719                tracelimit::error_ratelimited!(?target, "dropping message while paused");
1720            }
1721            return false;
1722        }
1723
1724        let mut port_storage;
1725        let port = match target {
1726            MessageTarget::Default => self.message_port.as_mut(),
1727            MessageTarget::ReservedChannel(offer_id, target) => {
1728                if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
1729                    port.as_mut()
1730                } else {
1731                    // Updating the port failed, so there is no way to send the message.
1732                    return true;
1733                }
1734            }
1735            MessageTarget::Custom(target) => {
1736                port_storage = match self.synic.new_guest_message_port(
1737                    self.redirect_vtl,
1738                    target.vp,
1739                    target.sint,
1740                ) {
1741                    Ok(port) => port,
1742                    Err(err) => {
1743                        tracing::error!(
1744                            ?err,
1745                            ?self.redirect_vtl,
1746                            ?target,
1747                            "could not create message port"
1748                        );
1749
1750                        // There is no way to send the message.
1751                        return true;
1752                    }
1753                };
1754                port_storage.as_mut()
1755            }
1756        };
1757
1758        // If this returns Pending, the channels module will queue the message and the ServerTask
1759        // main loop will try to send it again later.
1760        matches!(
1761            port.poll_post_message(
1762                &mut std::task::Context::from_waker(std::task::Waker::noop()),
1763                VMBUS_MESSAGE_TYPE,
1764                message.data()
1765            ),
1766            Poll::Ready(())
1767        )
1768    }
1769
1770    fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
1771        tracing::debug!(?request, "received hvsock connect request");
1772        self.hvsock_requests += 1;
1773        self.hvsock_send.send(*request);
1774    }
1775
1776    fn reset_complete(&mut self) {
1777        if let Some(monitor) = self.synic.monitor_support() {
1778            if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
1779                tracing::warn!(?err, "resetting monitor page failed")
1780            }
1781        }
1782
1783        self.unreserve_channels();
1784        for done in self.reset_done.drain(..) {
1785            done.complete(());
1786        }
1787    }
1788
1789    fn unload_complete(&mut self) {
1790        self.unreserve_channels();
1791    }
1792}
1793
1794impl ServerTaskInner {
1795    fn open_channel(
1796        &mut self,
1797        offer_id: OfferId,
1798        open_params: &OpenParams,
1799    ) -> anyhow::Result<(&mut Channel, Interrupt)> {
1800        let channel = self
1801            .channels
1802            .get_mut(&offer_id)
1803            .expect("channel does not exist");
1804
1805        // Always register with the channel bitmap; if Win7, this may be unnecessary.
1806        if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1807            channel_bitmap.register_channel(
1808                open_params.event_flag,
1809                channel.guest_to_host_event.0.clone(),
1810            );
1811        }
1812        // Always set up an event port; if V1, this will be unused.
1813        // N.B. The event port must be created before the device is notified of the open by the
1814        //      caller. The device may begin communicating with the guest immediately when it is
1815        //      notified, so the event port must exist so that the guest can send interrupts.
1816        let event_port = self
1817            .synic
1818            .add_event_port(
1819                open_params.connection_id,
1820                self.vtl,
1821                channel.guest_to_host_event.clone(),
1822                open_params.monitor_info,
1823            )
1824            .context("failed to create guest-to-host event port")?;
1825
1826        // For pre-Win8 guests, the host-to-guest event always targets vp 0 and the channel
1827        // bitmap is used instead of the event flag.
1828        let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
1829            (Some(0), 0)
1830        } else {
1831            (open_params.open_data.target_vp, open_params.event_flag)
1832        };
1833
1834        let (guest_event_port, interrupt) = if let Some(target_vp) = target_vp {
1835            let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
1836                (self.redirect_vtl, self.redirect_sint)
1837            } else {
1838                (self.vtl, VMBUS_SINT)
1839            };
1840
1841            let guest_event_port = self.synic.new_guest_event_port(
1842                VmbusServer::get_child_event_port_id(open_params.channel_id, VMBUS_SINT, self.vtl),
1843                target_vtl,
1844                target_vp,
1845                target_sint,
1846                event_flag,
1847                open_params.monitor_info,
1848            )?;
1849
1850            let interrupt = ChannelBitmap::create_interrupt(
1851                &self.channel_bitmap,
1852                guest_event_port.interrupt(),
1853                open_params.event_flag,
1854            );
1855
1856            (Some(guest_event_port), interrupt)
1857        } else {
1858            // Use a dummy interrupt which does nothing.
1859            (None, Interrupt::null())
1860        };
1861
1862        // Delete any previously reserved state.
1863        channel.reserved_state.message_port = None;
1864
1865        // If the channel is reserved, create a message port for it.
1866        if let Some(target) = open_params.reserved_target {
1867            channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
1868                self.redirect_vtl,
1869                target.vp,
1870                target.sint,
1871            )?);
1872
1873            channel.reserved_state.target = target;
1874        }
1875
1876        channel.state = ChannelState::Open(Box::new(ChannelOpenState {
1877            open_params: *open_params,
1878            _event_port: event_port,
1879            guest_event_port,
1880            host_to_guest_interrupt: interrupt.clone(),
1881        }));
1882        Ok((channel, interrupt))
1883    }
1884
1885    /// If the client specified an interrupt page, map it into host memory and
1886    /// set up the shared event port.
1887    fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
1888        let interrupt_page = match interrupt_page {
1889            Update::Unchanged => return Ok(()),
1890            Update::Reset => {
1891                self.channel_bitmap = None;
1892                self.shared_event_port = None;
1893                return Ok(());
1894            }
1895            Update::Set(interrupt_page) => interrupt_page,
1896        };
1897
1898        assert_ne!(interrupt_page, 0);
1899
1900        if interrupt_page % PAGE_SIZE as u64 != 0 {
1901            anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
1902        }
1903
1904        // Use a subrange to access the interrupt page to give GuestMemory's without a full mapping
1905        // a chance to create one.
1906        let interrupt_page = lock_page_with_subrange(&self.gm, interrupt_page)?;
1907        let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
1908        self.channel_bitmap = Some(channel_bitmap.clone());
1909
1910        // Create the shared event port for pre-Win8 guests.
1911        let interrupt = Interrupt::from_fn(move || {
1912            channel_bitmap.handle_shared_interrupt();
1913        });
1914
1915        self.shared_event_port = Some(self.synic.add_event_port(
1916            SHARED_EVENT_CONNECTION_ID,
1917            self.vtl,
1918            Arc::new(ChannelEvent(interrupt)),
1919            None,
1920        )?);
1921
1922        Ok(())
1923    }
1924
1925    fn set_monitor_page(&mut self, request: &mut ModifyConnectionRequest) -> anyhow::Result<()> {
1926        let monitor_page = match request.monitor_page {
1927            Update::Unchanged => return Ok(()),
1928            Update::Reset => None,
1929            Update::Set(value) => Some(value),
1930        };
1931
1932        // TODO: can this check be moved into channels.rs?
1933        if self.channels.iter().any(|(_, c)| {
1934            matches!(
1935                &c.state,
1936                ChannelState::Open(state) if state.open_params.monitor_info.is_some()
1937            )
1938        }) {
1939            anyhow::bail!("attempt to change monitor page while open channels using mnf");
1940        }
1941
1942        // Check if the server is handling MNF.
1943        // N.B. If the server is not handling MNF, there is currently no way to request
1944        //      server-allocated monitor pages from the relay host.
1945        if let Some(mnf_support) = self.mnf_support.as_mut() {
1946            if let Some(monitor) = self.synic.monitor_support() {
1947                mnf_support.allocated_monitor_page = None;
1948
1949                if let Some(version) = request.version {
1950                    if version.feature_flags.server_specified_monitor_pages() {
1951                        if let Some(monitor_page) = monitor.allocate_monitor_page(self.vtl)? {
1952                            tracelimit::info_ratelimited!(
1953                                ?monitor_page,
1954                                "using server-allocated monitor pages"
1955                            );
1956                            mnf_support.allocated_monitor_page = Some(monitor_page);
1957                        }
1958                    }
1959                }
1960
1961                // If no monitor page was allocated above, use the one provided by the client.
1962                if mnf_support.allocated_monitor_page.is_none() {
1963                    if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
1964                        anyhow::bail!(
1965                            "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
1966                        );
1967                    }
1968                }
1969            }
1970
1971            // If MNF is configured to be handled by this server (even if it's not actually
1972            // supported by the synic), don't forward the pages to the relay.
1973            request.monitor_page = Update::Unchanged;
1974        }
1975
1976        Ok(())
1977    }
1978
1979    fn get_reserved_channel_message_port(
1980        &mut self,
1981        offer_id: OfferId,
1982        new_target: ConnectionTarget,
1983    ) -> Option<&mut Box<dyn GuestMessagePort>> {
1984        let channel = self
1985            .channels
1986            .get_mut(&offer_id)
1987            .expect("channel does not exist");
1988
1989        assert!(
1990            channel.reserved_state.message_port.is_some(),
1991            "channel is not reserved"
1992        );
1993
1994        // On close, the guest may have changed the message target it wants to use for the close
1995        // response. If so, update the message port.
1996        if channel.reserved_state.target.sint != new_target.sint {
1997            // Destroy the old port before creating the new one.
1998            channel.reserved_state.message_port = None;
1999            let message_port = self
2000                .synic
2001                .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
2002                .inspect_err(|err| {
2003                    tracing::error!(
2004                        key = %channel.key,
2005                        ?err,
2006                        ?self.redirect_vtl,
2007                        ?new_target,
2008                        "could not create reserved channel message port"
2009                    )
2010                })
2011                .ok()?;
2012
2013            channel.reserved_state.message_port = Some(message_port);
2014            channel.reserved_state.target = new_target;
2015        } else if channel.reserved_state.target.vp != new_target.vp {
2016            let message_port = channel.reserved_state.message_port.as_mut().unwrap();
2017
2018            // The vp has changed, but the SINT is the same. Just update the vp. If this fails,
2019            // ignore it and just send to the old vp.
2020            if let Err(err) = message_port.set_target_vp(new_target.vp) {
2021                tracing::error!(
2022                    key = %channel.key,
2023                    ?err,
2024                    ?self.redirect_vtl,
2025                    ?new_target,
2026                    "could not update reserved channel message port"
2027                );
2028            }
2029
2030            channel.reserved_state.target = new_target;
2031            return Some(message_port);
2032        }
2033
2034        Some(channel.reserved_state.message_port.as_mut().unwrap())
2035    }
2036
2037    fn unreserve_channels(&mut self) {
2038        // Unreserve all closed channels.
2039        for channel in self.channels.values_mut() {
2040            if let ChannelState::Closed = channel.state {
2041                channel.reserved_state.message_port = None;
2042            }
2043        }
2044    }
2045
2046    fn get_gm_for_channel(&self, version: VersionInfo, channel: &Channel) -> &GuestMemory {
2047        select_gm_for_channel(&self.gm, self.private_gm.as_ref(), version, channel)
2048    }
2049}
2050
2051fn select_gm_for_channel<'a>(
2052    gm: &'a GuestMemory,
2053    private_gm: Option<&'a GuestMemory>,
2054    version: VersionInfo,
2055    channel: &Channel,
2056) -> &'a GuestMemory {
2057    if channel.flags.confidential_ring_buffer() && version.feature_flags.confidential_channels() {
2058        if let Some(private_gm) = private_gm {
2059            return private_gm;
2060        }
2061    }
2062
2063    gm
2064}
2065
2066/// Control point for [`VmbusServer`], allowing callers to offer channels.
2067#[derive(Clone)]
2068pub struct VmbusServerControl {
2069    mem: GuestMemory,
2070    private_mem: Option<GuestMemory>,
2071    send: mesh::Sender<OfferRequest>,
2072    use_event: bool,
2073    force_confidential_external_memory: bool,
2074}
2075
2076impl VmbusServerControl {
2077    /// Offers a channel to the vmbus server, where the flags and user_defined data are already set.
2078    /// This is used by the relay to forward the host's parameters.
2079    pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
2080        let flags = offer_info.params.flags;
2081        self.send
2082            .call_failable(OfferRequest::Offer, offer_info)
2083            .await?;
2084        Ok(OfferResources::new(
2085            self.mem.clone(),
2086            if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
2087                self.private_mem.clone()
2088            } else {
2089                None
2090            },
2091        ))
2092    }
2093
2094    /// Force reset all channels and protocol state, without requiring the
2095    /// server to be paused.
2096    pub async fn force_reset(&self) -> anyhow::Result<()> {
2097        self.send
2098            .call(OfferRequest::ForceReset, ())
2099            .await
2100            .context("vmbus server is gone")
2101    }
2102
2103    async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2104        let mut offer_info = OfferInfo {
2105            params: request.params.into(),
2106            event: request.event,
2107            request_send: request.request_send,
2108            server_request_recv: request.server_request_recv,
2109        };
2110
2111        if self.force_confidential_external_memory {
2112            tracing::warn!(
2113                key = %offer_info.params.key(),
2114                "forcing confidential external memory for channel"
2115            );
2116
2117            offer_info
2118                .params
2119                .flags
2120                .set_confidential_external_memory(true);
2121        }
2122
2123        self.offer_core(offer_info).await
2124    }
2125}
2126
2127/// Inspects the specified ring buffer state by directly accessing guest memory.
2128fn inspect_rings(
2129    resp: &mut inspect::Response<'_>,
2130    gm: &GuestMemory,
2131    gpadl_map: Arc<GpadlMap>,
2132    open_data: &OpenData,
2133) -> Option<()> {
2134    let gpadl = gpadl_map
2135        .view()
2136        .map(GpadlId(open_data.ring_gpadl_id.0))
2137        .ok()?;
2138
2139    let aligned = AlignedGpadlView::new(gpadl).ok()?;
2140    let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
2141    resp.child("incoming_ring", |req| inspect_ring(req, &in_gpadl, gm));
2142    resp.child("outgoing_ring", |req| inspect_ring(req, &out_gpadl, gm));
2143    Some(())
2144}
2145
2146/// Inspects the incoming or outgoing ring buffer by directly accessing guest memory.
2147fn inspect_ring(req: inspect::Request<'_>, gpadl: &AlignedGpadlView, gm: &GuestMemory) {
2148    let mut resp = req.respond();
2149
2150    resp.hex("ring_size", gpadl_ring_size(gpadl));
2151
2152    // Lock just the control page. Use a subrange to allow a GuestMemory without a full mapping to
2153    // create one.
2154    if let Ok(pages) = lock_gpn_with_subrange(gm, gpadl.gpns()[0]) {
2155        ring::inspect_ring(pages.pages()[0], &mut resp);
2156    }
2157}
2158
2159fn gpadl_ring_size(gpadl: &AlignedGpadlView) -> usize {
2160    // Data size excluding the control page.
2161    (gpadl.gpns().len() - 1) * PAGE_SIZE
2162}
2163
2164/// Helper to create a subrange before locking a single page.
2165///
2166/// This allows us to lock a page in a `GuestMemory` that doesn't have a full mapping, but can
2167/// create one for a subrange.
2168fn lock_page_with_subrange(gm: &GuestMemory, offset: u64) -> anyhow::Result<guestmem::LockedPages> {
2169    Ok(gm
2170        .lockable_subrange(offset, PAGE_SIZE as u64)?
2171        .lock_gpns(false, &[0])?)
2172}
2173
2174/// Helper to create a subrange before locking a single page from a gpn.
2175///
2176/// This allows us to lock a page in a `GuestMemory` that doesn't have a full mapping, but can
2177/// create one for a subrange.
2178fn lock_gpn_with_subrange(gm: &GuestMemory, gpn: u64) -> anyhow::Result<guestmem::LockedPages> {
2179    lock_page_with_subrange(gm, gpn * PAGE_SIZE as u64)
2180}
2181
2182pub(crate) struct MessageSender {
2183    send: mpsc::Sender<SynicMessage>,
2184    multiclient: bool,
2185}
2186
2187impl MessageSender {
2188    fn poll_handle_message(
2189        &self,
2190        cx: &mut std::task::Context<'_>,
2191        msg: &[u8],
2192        trusted: bool,
2193    ) -> Poll<Result<(), SendError>> {
2194        let mut send = self.send.clone();
2195        ready!(send.poll_ready(cx))?;
2196        send.start_send(SynicMessage {
2197            data: msg.to_vec(),
2198            multiclient: self.multiclient,
2199            trusted,
2200        })?;
2201
2202        Poll::Ready(Ok(()))
2203    }
2204}
2205
2206impl MessagePort for MessageSender {
2207    fn poll_handle_message(
2208        &self,
2209        cx: &mut std::task::Context<'_>,
2210        msg: &[u8],
2211        trusted: bool,
2212    ) -> Poll<()> {
2213        if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
2214            tracelimit::error_ratelimited!(
2215                error = &err as &dyn std::error::Error,
2216                "failed to send message"
2217            );
2218        }
2219
2220        Poll::Ready(())
2221    }
2222}
2223
2224#[async_trait]
2225impl ParentBus for VmbusServerControl {
2226    async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2227        self.offer(request).await
2228    }
2229
2230    fn clone_bus(&self) -> Box<dyn ParentBus> {
2231        Box::new(self.clone())
2232    }
2233
2234    fn use_event(&self) -> bool {
2235        self.use_event
2236    }
2237}