1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8pub mod channels;
9pub mod hvsock;
10mod monitor;
11mod proxyintegration;
12#[cfg(test)]
13mod tests;
14
/// Alias for `guid::Guid`, exposed from this crate.
pub type Guid = guid::Guid;
17
18use anyhow::Context;
19use async_trait::async_trait;
20use channel_bitmap::ChannelBitmap;
21use channels::ConnectionTarget;
22pub use channels::InitiateContactRequest;
23use channels::MessageTarget;
24pub use channels::MnfUsage;
25use channels::ModifyConnectionRequest;
26use channels::ModifyConnectionResponse;
27use channels::Notifier;
28use channels::OfferId;
29pub use channels::OfferParamsInternal;
30use channels::OpenParams;
31use channels::RestoreError;
32pub use channels::Update;
33use futures::FutureExt;
34use futures::StreamExt;
35use futures::channel::mpsc;
36use futures::channel::mpsc::SendError;
37use futures::future::OptionFuture;
38use futures::future::poll_fn;
39use futures::stream::SelectAll;
40use guestmem::GuestMemory;
41use hvdef::Vtl;
42use inspect::Inspect;
43use mesh::payload::Protobuf;
44use mesh::rpc::FailableRpc;
45use mesh::rpc::Rpc;
46use mesh::rpc::RpcError;
47use mesh::rpc::RpcSend;
48use pal_async::driver::Driver;
49use pal_async::driver::SpawnDriver;
50use pal_async::task::Task;
51use pal_async::timer::PolledTimer;
52use pal_event::Event;
53#[cfg(windows)]
54pub use proxyintegration::ProxyIntegration;
55#[cfg(windows)]
56pub use proxyintegration::ProxyServerInfo;
57use ring::PAGE_SIZE;
58use std::collections::HashMap;
59use std::future;
60use std::future::Future;
61use std::pin::Pin;
62use std::sync::Arc;
63use std::task::Poll;
64use std::task::ready;
65use std::time::Duration;
66use unicycle::FuturesUnordered;
67use vmbus_channel::bus::ChannelRequest;
68use vmbus_channel::bus::ChannelServerRequest;
69use vmbus_channel::bus::GpadlRequest;
70use vmbus_channel::bus::ModifyRequest;
71use vmbus_channel::bus::OfferInput;
72use vmbus_channel::bus::OfferKey;
73use vmbus_channel::bus::OfferResources;
74use vmbus_channel::bus::OpenData;
75use vmbus_channel::bus::OpenRequest;
76use vmbus_channel::bus::ParentBus;
77use vmbus_channel::bus::RestoreResult;
78use vmbus_channel::gpadl::GpadlMap;
79use vmbus_channel::gpadl_ring::AlignedGpadlView;
80use vmbus_core::HvsockConnectRequest;
81use vmbus_core::HvsockConnectResult;
82use vmbus_core::MaxVersionInfo;
83use vmbus_core::OutgoingMessage;
84use vmbus_core::TaggedStream;
85use vmbus_core::VMBUS_SINT;
86use vmbus_core::VersionInfo;
87use vmbus_core::protocol;
88pub use vmbus_core::protocol::GpadlId;
89#[cfg(windows)]
90use vmbus_proxy::ProxyHandle;
91use vmbus_ring as ring;
92use vmbus_ring::gparange::MultiPagedRangeBuf;
93use vmcore::interrupt::Interrupt;
94use vmcore::save_restore::SavedStateRoot;
95use vmcore::synic::EventPort;
96use vmcore::synic::GuestEventPort;
97use vmcore::synic::GuestMessagePort;
98use vmcore::synic::MessagePort;
99use vmcore::synic::MonitorPageGpas;
100use vmcore::synic::SynicPortAccess;
101
/// SINT used for the server's message ports when message redirection is enabled (see
/// `VmbusServerBuilder::build`).
pub const REDIRECT_SINT: u8 = 7;
/// VTL used for the server's message ports when message redirection is enabled (see
/// `VmbusServerBuilder::build`).
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
// NOTE(review): not referenced in the portion of the file reviewed here; presumably the
// connection ID of the shared event port — confirm against its use site.
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
/// Base value OR-ed into the port ID computed by `VmbusServer::get_child_event_port_id`.
const EVENT_PORT_ID: u32 = 2;
/// Message type passed to `poll_post_message` when the server task flushes pending messages.
const VMBUS_MESSAGE_TYPE: u32 = 1;

/// Limit on `hvsock_requests`: the server task only reads further synic messages while the
/// number of outstanding hvsock requests is below this value.
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
109
/// Handle to a spawned vmbus server task.
///
/// Created by `VmbusServerBuilder::build`. Requests are forwarded to the task over
/// `task_send`.
#[derive(Inspect)]
pub struct VmbusServer {
    /// Sender for `VmbusRequest`s to the server task; inspect requests are forwarded through it
    /// as `VmbusRequest::Inspect`.
    #[inspect(flatten, send = "VmbusRequest::Inspect")]
    task_send: mesh::Sender<VmbusRequest>,
    /// Shared control object returned by `control()`.
    #[inspect(skip)]
    control: Arc<VmbusServerControl>,
    /// Value returned by `add_message_port` for the primary message port. Never read; held so
    /// that it is dropped together with this handle (presumably unregistering the port —
    /// confirm against `SynicPortAccess`).
    #[inspect(skip)]
    _message_port: Box<dyn Sync + Send>,
    /// Same as `_message_port`, for the additional multiclient port that `build` registers only
    /// for VTL0 without message redirection.
    #[inspect(skip)]
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    /// The spawned server task; awaited by `shutdown`.
    #[inspect(skip)]
    task: Task<ServerTask>,
}
123
/// Builder that collects the configuration for a `VmbusServer`; consumed by `build`.
pub struct VmbusServerBuilder<T: SpawnDriver> {
    /// Driver used to spawn the server task; a clone is also boxed into the task as its driver.
    spawner: T,
    /// Synic access used to register message ports and create the guest message port.
    synic: Arc<dyn SynicPortAccess>,
    /// Guest memory handed to the server control object and to the server task.
    gm: GuestMemory,
    /// Optional second guest memory object. When present together with `enable_mnf`, `build`
    /// tells the server to require a server-allocated MNF.
    private_gm: Option<GuestMemory>,
    /// VTL the server is created for (defaults to `Vtl::Vtl0`).
    vtl: Vtl,
    /// Optional channel pair for forwarding hvsock connect requests.
    hvsock_notify: Option<HvsockServerChannelHalf>,
    /// Optional channel pair for forwarding connection-modify requests to a relay.
    server_relay: Option<VmbusServerChannelHalf>,
    /// Optional sender notified with `SavedStateRequest`s on restore and on start.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Optional sender stored in the task as `external_server_send`.
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    /// Optional receiver of `InitiateContactRequest`s processed by the server task.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Whether the message port uses `REDIRECT_VTL` / `REDIRECT_SINT`.
    use_message_redirect: bool,
    /// Channel ID offset passed to `channels::Server::new` (0 or 1024).
    channel_id_offset: u16,
    /// Optional version limit applied with `set_compatibility_version`.
    max_version: Option<MaxVersionInfo>,
    /// Second argument to `set_compatibility_version`.
    delay_max_version: bool,
    /// Optional version limit applied with `set_restore_compatibility_version`.
    max_restore_version: Option<MaxVersionInfo>,
    /// Whether MNF support state is created for the task.
    enable_mnf: bool,
    /// Forwarded to the server control object.
    force_confidential_external_memory: bool,
    /// Forwarded to the server task's inner state.
    send_messages_while_stopped: bool,
    /// Delay used by the channel "unstick" timer queued after a successful open; `None`
    /// disables queuing the timer.
    channel_unstick_delay: Option<Duration>,
    /// Forwarded to `channels::Server::new`.
    use_absolute_channel_order: bool,
}
146
/// Requests sent over the `saved_state_notify` channel.
#[derive(mesh::MeshPayload)]
pub enum SavedStateRequest {
    /// Sent while handling a restore, carrying a copy of the server's saved state. A failure
    /// reply fails the restore.
    Set(FailableRpc<Box<channels::SavedState>, ()>),
    /// Sent on the first start after a restore.
    Clear(Rpc<(), ()>),
}
153
/// The server's end of a `RelayChannel`: sends requests and receives responses.
pub struct ServerChannelHalf<Request, Response> {
    /// Sender for requests going to the relay.
    request_send: mesh::Sender<Request>,
    /// Receiver for responses coming back from the relay.
    response_receive: mesh::Receiver<Response>,
}
159
/// The relay's end of a `RelayChannel`: receives requests and sends responses.
pub struct RelayChannelHalf<Request, Response> {
    /// Receiver for requests sent by the server half.
    pub request_receive: mesh::Receiver<Request>,
    /// Sender for responses going back to the server half.
    pub response_send: mesh::Sender<Response>,
}
165
/// A connected pair of request/response channel halves (see `RelayChannel::new`).
pub struct RelayChannel<Request, Response> {
    /// The half that receives requests and sends responses.
    pub relay_half: RelayChannelHalf<Request, Response>,
    /// The half that sends requests and receives responses.
    pub server_half: ServerChannelHalf<Request, Response>,
}
171
172impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
173 pub fn new() -> Self {
175 let (request_send, request_receive) = mesh::channel();
176 let (response_send, response_receive) = mesh::channel();
177 Self {
178 relay_half: RelayChannelHalf {
179 request_receive,
180 response_send,
181 },
182 server_half: ServerChannelHalf {
183 request_send,
184 response_receive,
185 },
186 }
187 }
188}
189
/// Server half of the channel carrying `ModifyRelayRequest` / `ModifyRelayResponse`.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Relay half of the channel carrying `ModifyRelayRequest` / `ModifyRelayResponse`.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Both halves of the channel carrying `ModifyRelayRequest` / `ModifyRelayResponse`.
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyRelayResponse>;
/// Server half of the channel carrying `HvsockConnectRequest` / `HvsockConnectResult`.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay half of the channel carrying `HvsockConnectRequest` / `HvsockConnectResult`.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Both halves of the channel carrying `HvsockConnectRequest` / `HvsockConnectResult`.
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
196
/// Connection-modify request sent to the relay; built from a `ModifyConnectionRequest` via
/// `From`.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// The requested version as a raw `u32`, if the originating request carried one.
    pub version: Option<u32>,
    /// Monitor page update, copied unchanged from the originating request.
    pub monitor_page: Update<MonitorPageGpas>,
    /// `None` if the interrupt page is unchanged, `Some(false)` if it is being reset, and
    /// `Some(true)` if it is being set.
    pub use_interrupt_page: Option<bool>,
}
211
/// Relay's answer to a `ModifyRelayRequest`; translated to a `ModifyConnectionResponse` by the
/// server task.
#[derive(Debug, Copy, Clone)]
pub enum ModifyRelayResponse {
    /// Carries a connection state and feature flags. The built-in responder used when no relay
    /// is configured returns this for requests that carry a version.
    Supported(protocol::ConnectionState, protocol::FeatureFlags),
    /// Translated to `ModifyConnectionResponse::Unsupported`.
    Unsupported,
    /// Carries a connection state. The built-in responder returns this for requests without a
    /// version.
    Modified(protocol::ConnectionState),
}
226
227impl From<ModifyConnectionRequest> for ModifyRelayRequest {
228 fn from(value: ModifyConnectionRequest) -> Self {
229 Self {
230 version: value.version.map(|v| v.version as u32),
231 monitor_page: value.monitor_page,
232 use_interrupt_page: match value.interrupt_page {
233 Update::Unchanged => None,
234 Update::Reset => Some(false),
235 Update::Set(_) => Some(true),
236 },
237 }
238 }
239}
240
/// Requests sent from a `VmbusServer` handle to its server task.
#[derive(Debug)]
enum VmbusRequest {
    /// Reset the server; the RPC is queued in `reset_done`.
    Reset(Rpc<(), ()>),
    /// Respond to an inspect request.
    Inspect(inspect::Deferred),
    /// Produce a `SavedState`.
    Save(Rpc<(), SavedState>),
    /// Restore from a `SavedState`.
    Restore(Rpc<Box<SavedState>, Result<(), RestoreError>>),
    /// Mark the task as running (no reply).
    Start,
    /// Mark the task as not running and reply once done.
    Stop(Rpc<(), ()>),
}
250
/// Everything the server task needs to offer a channel (see `ServerTask::handle_offer`).
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    /// Offer parameters passed to the server's `offer_channel`.
    pub params: OfferParamsInternal,
    /// Interrupt wrapped in the channel's `ChannelEvent` (stored as `guest_to_host_event`).
    pub event: Interrupt,
    /// Sender for `ChannelRequest`s, stored as the channel's `send`.
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receiver of `ChannelServerRequest`s polled by the server task; its closure is treated as
    /// a revoke of the offer.
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
258
/// Requests received by the server task on its offer channel.
#[expect(clippy::large_enum_variant)]
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offer a new channel; fails if the offer is rejected.
    Offer(FailableRpc<OfferInfo, ()>),
    /// Reset the server; handled the same way as `VmbusRequest::Reset`.
    ForceReset(Rpc<(), ()>),
}
265
/// Newtype around an `Interrupt` so it can be used as an `EventPort`.
struct ChannelEvent(Interrupt);
267
268impl EventPort for ChannelEvent {
269 fn handle_event(&self, _flag: u16) {
270 self.0.deliver();
271 }
272
273 fn os_event(&self) -> Option<&Event> {
274 self.0.event()
275 }
276}
277
/// Saved state of the vmbus server, as produced by `VmbusServer::save`.
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    /// Saved state of the inner `channels::Server`.
    #[mesh(1)]
    pub server: channels::SavedState,
    /// Always written as `true` on save. When it is `false` on restore, the server task calls
    /// `unstick_channels` on the next start.
    #[mesh(2)]
    pub lost_synic_bug_fixed: bool,
}
289
/// Connection ID of the primary message port when serving VTL0 without message redirection.
const MESSAGE_CONNECTION_ID: u32 = 1;
/// Connection ID of the additional multiclient message port; also the base value of the IDs
/// computed by `VmbusServer::get_child_message_connection_id`.
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
292
impl<T: SpawnDriver + Clone> VmbusServerBuilder<T> {
    /// Creates a builder with default settings: VTL0, no relays or notifiers, no message
    /// redirection, no channel ID offset, no version limits, MNF disabled, and a channel
    /// unstick delay of 100 ms.
    pub fn new(spawner: T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
        Self {
            spawner,
            synic,
            gm,
            private_gm: None,
            vtl: Vtl::Vtl0,
            hvsock_notify: None,
            server_relay: None,
            saved_state_notify: None,
            external_server: None,
            external_requests: None,
            use_message_redirect: false,
            channel_id_offset: 0,
            max_version: None,
            delay_max_version: false,
            max_restore_version: None,
            enable_mnf: false,
            force_confidential_external_memory: false,
            send_messages_while_stopped: false,
            channel_unstick_delay: Some(Duration::from_millis(100)),
            use_absolute_channel_order: false,
        }
    }

    /// Sets the optional private guest memory object. It is handed to the server control and
    /// the server task; combined with `enable_mnf(true)` it makes the server require a
    /// server-allocated MNF.
    pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
        self.private_gm = private_gm;
        self
    }

    /// Sets the VTL the server is created for.
    pub fn vtl(mut self, vtl: Vtl) -> Self {
        self.vtl = vtl;
        self
    }

    /// Sets the channel used to forward hvsock connect requests. When left as `None`, `build`
    /// installs a responder that answers every request with
    /// `HvsockConnectResult::from_request(&r, false)`.
    pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
        self.hvsock_notify = hvsock_notify;
        self
    }

    /// Sets the sender that receives `SavedStateRequest::Set` on restore and
    /// `SavedStateRequest::Clear` on the first start after a restore.
    pub fn saved_state_notify(
        mut self,
        saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    ) -> Self {
        self.saved_state_notify = saved_state_notify;
        self
    }

    /// Sets the channel used to forward connection-modify requests to a relay. When left as
    /// `None`, `build` installs a responder that reports success for every request.
    pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
        self.server_relay = server_relay;
        self
    }

    /// Sets the receiver of externally supplied `InitiateContactRequest`s, which the server
    /// task handles while it is running and not resetting.
    pub fn external_requests(
        mut self,
        external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    ) -> Self {
        self.external_requests = external_requests;
        self
    }

    /// Sets the sender stored in the server task as `external_server_send`.
    // NOTE(review): its use is outside the portion of the file reviewed here; presumably it
    // forwards initiate-contact requests to another server — confirm.
    pub fn external_server(
        mut self,
        external_server: Option<mesh::Sender<InitiateContactRequest>>,
    ) -> Self {
        self.external_server = external_server;
        self
    }

    /// Chooses whether the message port uses `REDIRECT_VTL` / `REDIRECT_SINT` instead of the
    /// configured VTL and `VMBUS_SINT`.
    pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
        self.use_message_redirect = use_message_redirect;
        self
    }

    /// Sets the channel ID offset passed to `channels::Server::new`: 1024 when enabled, 0
    /// otherwise.
    pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
        self.channel_id_offset = if enable { 1024 } else { 0 };
        self
    }

    /// Sets the version limit that `build` applies with `set_compatibility_version`.
    pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
        self.max_version = max_version;
        self
    }

    /// Sets the flag passed alongside `max_version` to `set_compatibility_version`; it has no
    /// effect when no maximum version is configured.
    pub fn delay_max_version(mut self, delay: bool) -> Self {
        self.delay_max_version = delay;
        self
    }

    /// Sets the version limit that `build` applies with `set_restore_compatibility_version`.
    pub fn max_restore_version(mut self, max_restore_version: Option<MaxVersionInfo>) -> Self {
        self.max_restore_version = max_restore_version;
        self
    }

    /// Enables or disables MNF support state in the server task.
    pub fn enable_mnf(mut self, enable: bool) -> Self {
        self.enable_mnf = enable;
        self
    }

    /// Sets the flag of the same name on the server control object.
    pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
        self.force_confidential_external_memory = force;
        self
    }

    /// Sets the flag of the same name on the server task's inner state.
    pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
        self.send_messages_while_stopped = send;
        self
    }

    /// Sets the delay of the timer queued after a channel opens successfully; `None` disables
    /// queuing it.
    pub fn channel_unstick_delay(mut self, delay: Option<Duration>) -> Self {
        self.channel_unstick_delay = delay;
        self
    }

    /// Sets the flag of the same name passed to `channels::Server::new`.
    pub fn use_absolute_channel_order(mut self, assign: bool) -> Self {
        self.use_absolute_channel_order = assign;
        self
    }

    /// Registers the synic message port(s), constructs the channel server and its task state,
    /// spawns the server task, and returns the `VmbusServer` handle.
    ///
    /// # Errors
    ///
    /// Fails if a message port cannot be registered or the guest message port cannot be
    /// created.
    pub fn build(self) -> anyhow::Result<VmbusServer> {
        // Bounded queue carrying incoming synic messages from the message port(s) to the task.
        #[expect(clippy::disallowed_methods)] let (message_send, message_recv) = mpsc::channel(64);
        let message_sender = Arc::new(MessageSender {
            send: message_send.clone(),
            multiclient: self.use_message_redirect,
        });

        // With redirection, messages use the fixed redirect VTL/SINT rather than the
        // configured VTL and the standard vmbus SINT.
        let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
            (REDIRECT_VTL, REDIRECT_SINT)
        } else {
            (self.vtl, VMBUS_SINT)
        };

        // VTL0 without redirection uses the fixed message connection ID; every other
        // configuration derives the ID from the VTL and SINT (for VP 0).
        let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            MESSAGE_CONNECTION_ID
        } else {
            VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
        };

        let _message_port = self
            .synic
            .add_message_port(connection_id, redirect_vtl, message_sender)
            .context("failed to create vmbus synic ports")?;

        // VTL0 without redirection additionally listens on the multiclient connection ID;
        // messages arriving there are tagged as multiclient.
        let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            let multiclient_message_sender = Arc::new(MessageSender {
                send: message_send,
                multiclient: true,
            });

            Some(
                self.synic
                    .add_message_port(
                        MULTICLIENT_MESSAGE_CONNECTION_ID,
                        self.vtl,
                        multiclient_message_sender,
                    )
                    .context("failed to create vmbus synic ports")?,
            )
        } else {
            None
        };

        let (offer_send, offer_recv) = mesh::mpsc_channel();
        let control = Arc::new(VmbusServerControl {
            mem: self.gm.clone(),
            private_mem: self.private_gm.clone(),
            send: offer_send,
            use_event: self.synic.prefer_os_events(),
            force_confidential_external_memory: self.force_confidential_external_memory,
        });

        let mut server = channels::Server::new(
            self.vtl,
            connection_id,
            self.channel_id_offset,
            self.use_absolute_channel_order,
        );

        // A server-allocated MNF is required only when MNF is enabled and private guest
        // memory is available.
        server.set_require_server_allocated_mnf(self.enable_mnf && self.private_gm.is_some());

        if let Some(version) = self.max_version {
            server.set_compatibility_version(version, self.delay_max_version);
        }

        if let Some(version) = self.max_restore_version {
            server.set_restore_compatibility_version(version);
        }

        // Use the configured relay if there is one; otherwise loop requests back through a
        // local responder that reports success.
        let (relay_request_send, relay_response_recv) =
            if let Some(server_relay) = self.server_relay {
                let r = server_relay.response_receive.boxed().fuse();
                (server_relay.request_send, r)
            } else {
                let (req_send, req_recv) = mesh::channel();
                let resp_recv = req_recv
                    .map(|req: ModifyRelayRequest| {
                        if req.version.is_some() {
                            // Version requests are reported as supported with every feature
                            // flag bit set.
                            ModifyRelayResponse::Supported(
                                protocol::ConnectionState::SUCCESSFUL,
                                protocol::FeatureFlags::from_bits(u32::MAX),
                            )
                        } else {
                            ModifyRelayResponse::Modified(protocol::ConnectionState::SUCCESSFUL)
                        }
                    })
                    .boxed()
                    .fuse();
                (req_send, resp_recv)
            };

        // Same pattern for hvsock: without a configured notifier, every connect request is
        // answered locally via `from_request(&r, false)`.
        let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
            let r = hvsock_notify.response_receive.boxed().fuse();
            (hvsock_notify.request_send, r)
        } else {
            let (req_send, req_recv) = mesh::channel();
            let resp_recv = req_recv
                .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
                .boxed()
                .fuse();
            (req_send, resp_recv)
        };

        let inner = ServerTaskInner {
            // The task starts stopped; `VmbusServer::start` sets this.
            running: false,
            send_messages_while_stopped: self.send_messages_while_stopped,
            gm: self.gm,
            private_gm: self.private_gm,
            vtl: self.vtl,
            redirect_vtl,
            redirect_sint,
            // Must be created before `self.synic` is moved into the field below.
            message_port: self
                .synic
                .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
            synic: self.synic,
            hvsock_requests: 0,
            hvsock_send,
            saved_state_notify: self.saved_state_notify,
            channels: HashMap::new(),
            channel_responses: FuturesUnordered::new(),
            relay_send: relay_request_send,
            external_server_send: self.external_server,
            channel_bitmap: None,
            shared_event_port: None,
            reset_done: Vec::new(),
            mnf_support: self.enable_mnf.then(MnfSupport::default),
        };

        let (task_send, task_recv) = mesh::channel();
        let mut server_task = ServerTask {
            driver: Box::new(self.spawner.clone()),
            server,
            task_recv,
            offer_recv,
            message_recv,
            server_request_recv: SelectAll::new(),
            inner,
            external_requests: self.external_requests,
            next_seq: 0,
            perform_post_restore_on_start: false,
            unstick_on_start: false,
            channel_unstickers: FuturesUnordered::new(),
            channel_unstick_delay: self.channel_unstick_delay,
        };

        // The task returns its own state when `run` finishes, so it is available to whoever
        // awaits the task.
        let task = self.spawner.spawn("vmbus server", async move {
            server_task.run(relay_response_recv, hvsock_recv).await;
            server_task
        });

        Ok(VmbusServer {
            task_send,
            control,
            _message_port,
            _multiclient_message_port,
            task,
        })
    }
}
640
641impl VmbusServer {
642 pub fn builder<T: SpawnDriver + Clone>(
644 spawner: T,
645 synic: Arc<dyn SynicPortAccess>,
646 gm: GuestMemory,
647 ) -> VmbusServerBuilder<T> {
648 VmbusServerBuilder::new(spawner, synic, gm)
649 }
650
651 pub async fn save(&self) -> SavedState {
652 self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
653 }
654
655 pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
656 self.task_send
657 .call(VmbusRequest::Restore, Box::new(state))
658 .await
659 .unwrap()
660 }
661
662 pub async fn stop(&self) {
664 self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
665 }
666
667 pub fn start(&self) {
669 self.task_send.send(VmbusRequest::Start);
670 }
671
672 pub async fn reset(&self) {
674 tracing::debug!("resetting channel state");
675 self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
676 }
677
678 pub async fn shutdown(self) {
680 drop(self.task_send);
681 let _ = self.task.await;
682 }
683
684 pub fn control(&self) -> Arc<VmbusServerControl> {
686 self.control.clone()
687 }
688
689 fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
692 MULTICLIENT_MESSAGE_CONNECTION_ID
693 | (vtl as u32) << 22
694 | vp_index << 8
695 | (sint_index as u32) << 4
696 }
697
698 fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
699 EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
700 }
701}
702
/// Per-channel information exchanged on restore.
// NOTE(review): not used in the portion of the file reviewed here; the field descriptions
// below are inferred from the types only — confirm against the use site.
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    /// Open data, presumably present only if the channel was open.
    open_data: Option<OpenData>,
    /// Restored GPADLs as (ID, `u16`, `Vec<u64>`) tuples; the meaning of the last two elements
    /// is not visible here.
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    /// Optional interrupt associated with the channel.
    interrupt: Option<Interrupt>,
}
709
/// A message received from a synic message port, queued to the server task and passed to the
/// server's `handle_synic_message`.
#[derive(Default)]
pub struct SynicMessage {
    /// Raw message bytes.
    data: Vec<u8>,
    /// Presumably mirrors the `multiclient` flag of the `MessageSender` that received the
    /// message — confirm against `MessageSender`.
    multiclient: bool,
    // NOTE(review): where this flag is set is outside the portion of the file reviewed here.
    trusted: bool,
}
716
/// State kept by the server task when MNF is enabled on the builder.
#[derive(Default)]
struct MnfSupport {
    /// Monitor page GPAs reported back in `ModifyConnectionResponse::Supported`, if any.
    allocated_monitor_page: Option<MonitorPageGpas>,
}
722
/// Identifies one specific offer: the offer ID plus the sequence number assigned when the
/// channel was offered. The sequence number lets the task ignore events for an offer ID whose
/// channel has since been replaced.
#[derive(Debug, Clone, Copy)]
struct OfferInstanceId {
    /// ID returned by the server's `offer_channel`.
    offer_id: OfferId,
    /// Value of `ServerTask::next_seq` at the time of the offer.
    seq: u64,
}
729
/// State owned by the spawned vmbus server task.
struct ServerTask {
    /// Driver used to create timers.
    driver: Box<dyn Driver>,
    /// The channel state machine; driven through `with_notifier(&mut self.inner)`.
    server: channels::Server,
    /// Requests from the `VmbusServer` handle; the task exits when this channel closes.
    task_recv: mesh::Receiver<VmbusRequest>,
    /// Channel offers and forced resets.
    offer_recv: mesh::Receiver<OfferRequest>,
    /// Incoming synic messages from the message port(s).
    message_recv: mpsc::Receiver<SynicMessage>,
    /// Per-channel request streams, tagged with the offer instance they belong to.
    server_request_recv:
        SelectAll<TaggedStream<OfferInstanceId, mesh::Receiver<ChannelServerRequest>>>,
    /// State that is also handed to the server as its notifier.
    inner: ServerTaskInner,
    /// Optional stream of externally supplied initiate-contact requests.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Sequence number assigned to the next offered channel.
    next_seq: u64,
    /// Set on restore; makes the next start notify the saved-state listener and revoke
    /// unclaimed channels.
    perform_post_restore_on_start: bool,
    /// Set on restore when the saved state lacks `lost_synic_bug_fixed`; makes the next start
    /// call `unstick_channels`.
    unstick_on_start: bool,
    /// Pending unstick timers, each resolving to the offer instance to unstick.
    channel_unstickers: FuturesUnordered<Pin<Box<dyn Send + Future<Output = OfferInstanceId>>>>,
    /// Delay for unstick timers; `None` disables queuing them.
    channel_unstick_delay: Option<Duration>,
}
747
/// The part of the server task state that is passed to `channels::Server::with_notifier`.
struct ServerTaskInner {
    /// Whether the server is started; set by `Start` and cleared by `Stop`.
    running: bool,
    /// Copied from the builder option of the same name.
    send_messages_while_stopped: bool,
    /// Guest memory.
    gm: GuestMemory,
    /// Optional private guest memory.
    private_gm: Option<GuestMemory>,
    /// Synic access used to create ports.
    synic: Arc<dyn SynicPortAccess>,
    /// VTL the server was built for.
    vtl: Vtl,
    /// VTL used for the message ports (`REDIRECT_VTL` when redirecting, otherwise `vtl`).
    redirect_vtl: Vtl,
    /// SINT used for the message ports (`REDIRECT_SINT` when redirecting, otherwise
    /// `VMBUS_SINT`).
    redirect_sint: u8,
    /// Port used to post messages to the guest.
    message_port: Box<dyn GuestMessagePort>,
    /// Number of hvsock connect requests awaiting a result.
    hvsock_requests: usize,
    /// Sender for hvsock connect requests.
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    /// Optional listener for saved-state notifications.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Currently offered channels, keyed by offer ID.
    channels: HashMap<OfferId, Channel>,
    /// In-flight channel requests; each resolves to the offer ID, the channel's sequence
    /// number, and the response.
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    /// Optional sender set from the builder's `external_server`.
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    /// Sender for connection-modify requests to the relay (or the built-in responder).
    relay_send: mesh::Sender<ModifyRelayRequest>,
    /// Initially `None`.
    // NOTE(review): presumably set when the guest supplies an interrupt page — confirm.
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    /// Initially `None`.
    // NOTE(review): where this is populated is outside the portion of the file reviewed here.
    shared_event_port: Option<Box<dyn Send>>,
    /// Reset RPCs waiting for the reset to finish; non-empty while a reset is in progress.
    reset_done: Vec<Rpc<(), ()>>,
    /// MNF state; `Some` only when MNF was enabled on the builder.
    mnf_support: Option<MnfSupport>,
}
774
/// Result of a request sent to a channel, dispatched by `ServerTask::handle_response`.
#[derive(Debug)]
enum ChannelResponse {
    /// Open finished; `true` on success.
    Open(bool),
    /// Close finished.
    Close,
    /// GPADL creation finished for the given ID; `true` on success.
    Gpadl(GpadlId, bool),
    /// GPADL teardown finished for the given ID.
    TeardownGpadl(GpadlId),
    /// Modify finished with the given status code.
    Modify(i32),
}
783
/// Tracks whether an unstick timer is pending for a channel.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ChannelUnstickState {
    /// No timer is pending.
    None,
    /// A timer has been queued.
    Queued,
    /// The channel was opened again while a timer was already pending.
    NeedsRequeue,
}
790
/// Server-task state for one offered channel.
struct Channel {
    /// Offer key, used in log messages.
    key: OfferKey,
    /// Sender for requests to the channel's owner.
    send: mesh::Sender<ChannelRequest>,
    /// Sequence number assigned when the channel was offered.
    seq: u64,
    /// Open/closed state.
    state: ChannelState,
    /// GPADLs known for this channel.
    gpadls: Arc<GpadlMap>,
    /// Event port wrapping the interrupt supplied with the offer.
    guest_to_host_event: Arc<ChannelEvent>,
    /// Offer flags copied from the offer parameters.
    flags: protocol::OfferFlags,
    /// Initialized with no message port and a target of VP 0 / SINT 0.
    reserved_state: ReservedState,
    /// Whether an unstick timer is pending for this channel.
    unstick_state: ChannelUnstickState,
}
806
/// Message port and target kept per channel.
// NOTE(review): presumably used for reserved channels; the use site is outside the portion of
// the file reviewed here — confirm.
struct ReservedState {
    /// Optional dedicated guest message port.
    message_port: Option<Box<dyn GuestMessagePort>>,
    /// Target VP and SINT.
    target: ConnectionTarget,
}
811
/// State that exists only while a channel is open.
struct ChannelOpenState {
    /// Parameters the channel was opened with.
    open_params: OpenParams,
    /// Never read; held so that it is dropped with the open state.
    _event_port: Box<dyn Send>,
    /// Guest event port; `None` means the channel interrupt is disabled (see
    /// `set_event_port_target_vp`).
    guest_event_port: Option<Box<dyn GuestEventPort>>,
    /// Interrupt used to signal the guest.
    host_to_guest_interrupt: Interrupt,
}
818
819impl ChannelOpenState {
820 fn set_event_port_target_vp(&mut self, vp: u32) -> anyhow::Result<()> {
821 let Some(guest_event_port) = self.guest_event_port.as_mut() else {
822 anyhow::bail!("cannot set target VP if the channel interrupt is disabled");
823 };
824
825 guest_event_port.set_target_vp(vp)?;
826 Ok(())
827 }
828}
829
/// Open/closed state of a channel.
enum ChannelState {
    /// Not open; the state of a newly offered channel and of one whose close has completed.
    Closed,
    /// Open, with the associated open state.
    Open(Box<ChannelOpenState>),
    /// A close is in progress; a close response moves the channel to `Closed`.
    Closing,
}
835
836impl ServerTask {
837 fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
838 let key = info.params.key();
839 let flags = info.params.flags;
840
841 if self.inner.mnf_support.is_some() && self.inner.synic.monitor_support().is_some() {
842 if info.params.use_mnf.is_relayed() {
847 info.params.use_mnf = MnfUsage::Enabled {
848 latency: Duration::ZERO,
849 }
850 }
851 } else if info.params.use_mnf.is_enabled() {
852 info.params.use_mnf = MnfUsage::Disabled;
855 }
856
857 let offer_id = self
858 .server
859 .with_notifier(&mut self.inner)
860 .offer_channel(info.params)
861 .context("channel offer failed")?;
862
863 tracing::debug!(?offer_id, %key, "offered channel");
864
865 let seq = self.next_seq;
866 self.next_seq += 1;
867 self.inner.channels.insert(
868 offer_id,
869 Channel {
870 key,
871 send: info.request_send,
872 state: ChannelState::Closed,
873 gpadls: GpadlMap::new(),
874 guest_to_host_event: Arc::new(ChannelEvent(info.event)),
875 seq,
876 flags,
877 reserved_state: ReservedState {
878 message_port: None,
879 target: ConnectionTarget { vp: 0, sint: 0 },
880 },
881 unstick_state: ChannelUnstickState::None,
882 },
883 );
884
885 self.server_request_recv.push(TaggedStream::new(
886 OfferInstanceId { offer_id, seq },
887 info.server_request_recv,
888 ));
889
890 Ok(())
891 }
892
893 fn handle_revoke(&mut self, id: OfferInstanceId) {
894 if let Some(channel) = self.inner.channels.get(&id.offer_id) {
897 if channel.seq == id.seq {
898 tracing::info!(?id.offer_id, key = %channel.key, "revoking channel");
899 self.inner.channels.remove(&id.offer_id);
900 self.server
901 .with_notifier(&mut self.inner)
902 .revoke_channel(id.offer_id);
903 }
904 }
905 }
906
907 fn handle_response(
908 &mut self,
909 offer_id: OfferId,
910 seq: u64,
911 response: Result<ChannelResponse, RpcError>,
912 ) {
913 let channel = self
915 .inner
916 .channels
917 .get(&offer_id)
918 .filter(|channel| channel.seq == seq);
919
920 if let Some(channel) = channel {
921 match response {
922 Ok(response) => match response {
923 ChannelResponse::Open(result) => self.handle_open(offer_id, result),
924 ChannelResponse::Close => self.handle_close(offer_id),
925 ChannelResponse::Gpadl(gpadl_id, ok) => {
926 self.handle_gpadl_create(offer_id, gpadl_id, ok)
927 }
928 ChannelResponse::TeardownGpadl(gpadl_id) => {
929 self.handle_gpadl_teardown(offer_id, gpadl_id)
930 }
931 ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
932 },
933 Err(err) => {
934 tracing::error!(
935 key = %channel.key,
936 error = &err as &dyn std::error::Error,
937 "channel response failure, channel is in inconsistent state until revoked"
938 );
939 }
940 }
941 } else {
942 tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
943 }
944 }
945
946 fn handle_open(&mut self, offer_id: OfferId, success: bool) {
947 let status = if success {
948 let channel = self
949 .inner
950 .channels
951 .get_mut(&offer_id)
952 .expect("channel exists");
953
954 if let Some(delay) = self.channel_unstick_delay {
957 if channel.unstick_state == ChannelUnstickState::None {
958 channel.unstick_state = ChannelUnstickState::Queued;
959 let seq = channel.seq;
960 let mut timer = PolledTimer::new(&self.driver);
961 self.channel_unstickers.push(Box::pin(async move {
962 timer.sleep(delay).await;
963 OfferInstanceId { offer_id, seq }
964 }));
965 } else {
966 channel.unstick_state = ChannelUnstickState::NeedsRequeue;
967 }
968 }
969
970 0
971 } else {
972 protocol::STATUS_UNSUCCESSFUL
973 };
974
975 self.server
976 .with_notifier(&mut self.inner)
977 .open_complete(offer_id, status);
978 }
979
980 fn handle_close(&mut self, offer_id: OfferId) {
981 let channel = self
982 .inner
983 .channels
984 .get_mut(&offer_id)
985 .expect("channel still exists");
986
987 match &mut channel.state {
988 ChannelState::Closing => {
989 tracing::debug!(?offer_id, key = %channel.key, "closing channel");
990 channel.state = ChannelState::Closed;
991 self.server
992 .with_notifier(&mut self.inner)
993 .close_complete(offer_id);
994 }
995 _ => {
996 tracing::error!(?offer_id, key = %channel.key, "invalid close channel response");
997 }
998 };
999 }
1000
1001 fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
1002 let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
1003 self.server
1004 .with_notifier(&mut self.inner)
1005 .gpadl_create_complete(offer_id, gpadl_id, status);
1006 }
1007
1008 fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
1009 self.server
1010 .with_notifier(&mut self.inner)
1011 .gpadl_teardown_complete(offer_id, gpadl_id);
1012 }
1013
1014 fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
1015 self.server
1016 .with_notifier(&mut self.inner)
1017 .modify_channel_complete(offer_id, status);
1018 }
1019
1020 fn handle_restore_channel(
1021 &mut self,
1022 offer_id: OfferId,
1023 open: bool,
1024 ) -> anyhow::Result<RestoreResult> {
1025 let gpadls = self.server.channel_gpadls(offer_id);
1026
1027 let open_request = open
1030 .then(|| -> anyhow::Result<_> {
1031 let params = self.server.get_restore_open_params(offer_id)?;
1032 let (channel, interrupt) = self.inner.open_channel(offer_id, ¶ms)?;
1033 Ok(OpenRequest::new(
1034 params.open_data,
1035 interrupt,
1036 self.server
1037 .get_version()
1038 .expect("must be connected")
1039 .feature_flags,
1040 channel.flags,
1041 ))
1042 })
1043 .transpose()?;
1044
1045 self.server
1046 .with_notifier(&mut self.inner)
1047 .restore_channel(offer_id, open_request.is_some())?;
1048
1049 let channel = self.inner.channels.get_mut(&offer_id).unwrap();
1050 for gpadl in &gpadls {
1051 if let Ok(buf) = MultiPagedRangeBuf::from_range_buffer(
1052 gpadl.request.count.into(),
1053 gpadl.request.buf.clone(),
1054 ) {
1055 channel.gpadls.add(gpadl.request.id, buf);
1056 }
1057 }
1058
1059 let result = RestoreResult {
1060 open_request,
1061 gpadls,
1062 };
1063 Ok(result)
1064 }
1065
 /// Handles one request from the `VmbusServer` handle.
 ///
 /// # Panics
 ///
 /// Panics on `Start` after a restore if the saved-state listener fails to acknowledge the
 /// `Clear` request.
 async fn handle_request(&mut self, request: VmbusRequest) {
     tracing::debug!(?request, "handle_request");
     match request {
         VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
         VmbusRequest::Inspect(deferred) => {
             deferred.respond(|resp| {
                 resp.field("message_port", &self.inner.message_port)
                     .field("running", self.inner.running)
                     .field("hvsock_requests", self.inner.hvsock_requests)
                     .field("channel_unstick_delay", self.channel_unstick_delay)
                     // Writable field: "force" calls `unstick_channels(true)`, a value that
                     // parses as `true` calls `unstick_channels(false)`, and `false` does
                     // nothing. Reading it always yields `false`.
                     .field_mut_with("unstick_channels", |v| {
                         let v: inspect::ValueKind = if let Some(v) = v {
                             if v == "force" {
                                 self.unstick_channels(true);
                                 v.into()
                             } else {
                                 let v =
                                     v.parse().ok().context("expected false, true, or force")?;
                                 if v {
                                     self.unstick_channels(false);
                                 }
                                 v.into()
                             }
                         } else {
                             false.into()
                         };
                         anyhow::Ok(v)
                     })
                     .merge(&self.server.with_notifier(&mut self.inner));
             });
         }
         VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
             server: self.server.save(),
             // Always saved as `true`; only a saved state that lacks it triggers the
             // unstick workaround on restore.
             lost_synic_bug_fixed: true,
         }),
         VmbusRequest::Restore(rpc) => {
             rpc.handle(async |state| {
                 // Both flags take effect on the next `Start`.
                 self.unstick_on_start = !state.lost_synic_bug_fixed;
                 self.perform_post_restore_on_start = true;
                 // Hand a copy of the saved state to the listener before restoring the
                 // server itself; a failure there fails the whole restore.
                 if let Some(sender) = &self.inner.saved_state_notify {
                     tracing::trace!("sending saved state to proxy");
                     if let Err(err) = sender
                         .call_failable(SavedStateRequest::Set, Box::new(state.server.clone()))
                         .await
                     {
                         tracing::error!(
                             err = &err as &dyn std::error::Error,
                             "failed to restore proxy saved state"
                         );
                         return Err(RestoreError::ServerError(err.into()));
                     }
                 }

                 self.server
                     .with_notifier(&mut self.inner)
                     .restore(state.server)
             })
             .await
         }
         VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
             if self.inner.running {
                 self.inner.running = false;
             }
         }),
         VmbusRequest::Start => {
             // Starting an already-running server is a no-op.
             if !self.inner.running {
                 self.inner.running = true;
                 if self.perform_post_restore_on_start {
                     if let Some(sender) = self.inner.saved_state_notify.as_ref() {
                         tracing::trace!("sending clear saved state message to proxy");
                         sender
                             .call(SavedStateRequest::Clear, ())
                             .await
                             .expect("failed to clear proxy saved state");
                     }

                     // Channels present in the saved state that nothing claimed during
                     // restore are revoked now.
                     self.server
                         .with_notifier(&mut self.inner)
                         .revoke_unclaimed_channels();

                     self.perform_post_restore_on_start = false;
                 }

                 if self.unstick_on_start {
                     tracing::info!(
                         "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
                     );
                     self.unstick_channels(false);
                     self.unstick_on_start = false;
                 }
             }
         }
     }
 }
1162
1163 fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
1164 let needs_reset = self.inner.reset_done.is_empty();
1165 self.inner.reset_done.push(rpc);
1166 if needs_reset {
1167 self.server.with_notifier(&mut self.inner).reset();
1168 }
1169 }
1170
1171 fn handle_relay_response(&mut self, response: ModifyRelayResponse) {
1172 let response = match response {
1174 ModifyRelayResponse::Supported(state, features) => {
1175 let allocated_monitor_gpas = self
1178 .inner
1179 .mnf_support
1180 .as_ref()
1181 .and_then(|mnf| mnf.allocated_monitor_page);
1182
1183 ModifyConnectionResponse::Supported(state, features, allocated_monitor_gpas)
1184 }
1185 ModifyRelayResponse::Unsupported => ModifyConnectionResponse::Unsupported,
1186 ModifyRelayResponse::Modified(state) => ModifyConnectionResponse::Modified(state),
1187 };
1188
1189 self.server
1190 .with_notifier(&mut self.inner)
1191 .complete_modify_connection(response);
1192 }
1193
1194 fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
1195 tracing::debug!(?result, "hvsock connect result");
1196 assert_ne!(self.inner.hvsock_requests, 0);
1197 self.inner.hvsock_requests -= 1;
1198
1199 self.server
1200 .with_notifier(&mut self.inner)
1201 .send_tl_connect_result(result);
1202 }
1203
1204 fn handle_synic_message(&mut self, message: SynicMessage) {
1205 match self
1206 .server
1207 .with_notifier(&mut self.inner)
1208 .handle_synic_message(message)
1209 {
1210 Ok(()) => {}
1211 Err(err) => {
1212 tracing::warn!(
1213 error = &err as &dyn std::error::Error,
1214 "synic message error"
1215 );
1216 }
1217 }
1218 }
1219
1220 fn handle_external_request(&mut self, request: InitiateContactRequest) {
1227 self.server
1228 .with_notifier(&mut self.inner)
1229 .initiate_contact(request);
1230 }
1231
    /// Runs the server task's event loop.
    ///
    /// Every iteration rebuilds a set of optional futures, each gated on the current
    /// running/reset state, and then waits in `select!` for whichever event source
    /// becomes ready first. The loop ends when `task_recv` reports an error or when
    /// `select!` reports that every source has completed.
    ///
    /// * `relay_response_recv` - stream of relay answers to modify-connection requests.
    /// * `hvsock_recv` - stream of hvsock connect results.
    async fn run(
        &mut self,
        mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyRelayResponse> + Unpin,
        mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
    ) {
        loop {
            // True only while the server is running and nothing is waiting in `reset_done`.
            let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
            // External requests are only taken while running, and only when an external
            // request stream is configured at all.
            let mut external_requests = OptionFuture::from(
                running_not_resetting
                    .then(|| {
                        self.external_requests
                            .as_mut()
                            .map(|r| r.select_next_some())
                    })
                    .flatten(),
            );

            let has_pending_messages = self.server.has_pending_messages();
            // Borrow the message port up front so the flush closure below can use it while
            // `self.server` is borrowed separately.
            let message_port = self.inner.message_port.as_mut();
            // Only try to flush when running and the server actually has queued messages.
            let mut flush_pending_messages =
                OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
                    poll_fn(|cx| {
                        self.server.poll_flush_pending_messages(|msg| {
                            message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
                        })
                    })
                    .fuse()
                }));

            // New synic messages are accepted only while running, with no messages still
            // pending and with fewer than the maximum number of hvsock requests in flight.
            let mut message_recv = OptionFuture::from(
                (running_not_resetting
                    && !has_pending_messages
                    && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
                    .then(|| self.message_recv.select_next_some()),
            );

            // Channel responses are also processed while `reset_done` is non-empty, even if
            // the server is not running.
            let mut channel_response = OptionFuture::from(
                (self.inner.running || !self.inner.reset_done.is_empty())
                    .then(|| self.inner.channel_responses.select_next_some()),
            );

            let mut hvsock_response =
                OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));

            let mut channel_unstickers = OptionFuture::from(
                running_not_resetting.then(|| self.channel_unstickers.select_next_some()),
            );

            // NOTE(review): the `unwrap` calls on the `OptionFuture` outputs below assume
            // `select!` never runs the arm of an `OptionFuture` built from `None` — confirm
            // against the `futures` documentation.
            futures::select! {
                r = self.task_recv.recv().fuse() => {
                    // A receive error ends the task.
                    if let Ok(request) = r {
                        self.handle_request(request).await;
                    } else {
                        break;
                    }
                }
                r = self.offer_recv.select_next_some() => {
                    match r {
                        OfferRequest::Offer(rpc) => {
                            rpc.handle_failable_sync(|request| { self.handle_offer(request) })
                        },
                        OfferRequest::ForceReset(rpc) => {
                            self.handle_reset(rpc);
                        }
                    }
                }
                r = self.server_request_recv.select_next_some() => {
                    match r {
                        (id, Some(request)) => match request {
                            ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
                                self.handle_restore_channel(id.offer_id, open)
                            }),
                            ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
                                self.handle_revoke(id);
                            })
                        },
                        // An item without a request is handled as a revoke of that offer.
                        (id, None) => self.handle_revoke(id),
                    }
                }
                r = channel_response => {
                    let (id, seq, response) = r.unwrap();
                    self.handle_response(id, seq, response);
                }
                r = relay_response_recv.select_next_some() => {
                    self.handle_relay_response(r);
                },
                r = hvsock_response => {
                    self.handle_tl_connect_result(r.unwrap());
                }
                data = message_recv => {
                    let data = data.unwrap();
                    self.handle_synic_message(data);
                }
                r = external_requests => {
                    let r = r.unwrap();
                    self.handle_external_request(r);
                }
                r = channel_unstickers => {
                    self.unstick_channel_by_id(r.unwrap());
                }
                // Nothing to do once flushed; the next iteration re-evaluates the state.
                _r = flush_pending_messages => {}
                complete => break,
            }
        }
    }
1347
1348 fn unstick_channels(&self, force: bool) {
1352 let Some(version) = self.server.get_version() else {
1353 tracing::warn!("cannot unstick when not connected");
1354 return;
1355 };
1356
1357 for channel in self.inner.channels.values() {
1358 let gm = self.inner.get_gm_for_channel(version, channel);
1359 if let Err(err) = Self::unstick_channel(gm, channel, force, true) {
1360 tracing::warn!(
1361 channel = %channel.key,
1362 error = err.as_ref() as &dyn std::error::Error,
1363 "could not unstick channel"
1364 );
1365 }
1366 }
1367 }
1368
1369 fn unstick_channel_by_id(&mut self, id: OfferInstanceId) {
1372 let Some(version) = self.server.get_version() else {
1373 tracelimit::warn_ratelimited!("cannot unstick when not connected");
1374 return;
1375 };
1376
1377 if let Some(channel) = self.inner.channels.get_mut(&id.offer_id) {
1378 if channel.seq != id.seq {
1379 return;
1381 }
1382
1383 if channel.unstick_state == ChannelUnstickState::NeedsRequeue {
1386 channel.unstick_state = ChannelUnstickState::Queued;
1387 let mut timer = PolledTimer::new(&self.driver);
1388 let delay = self.channel_unstick_delay.unwrap();
1389 self.channel_unstickers.push(Box::pin(async move {
1390 timer.sleep(delay).await;
1391 id
1392 }));
1393
1394 return;
1395 }
1396
1397 channel.unstick_state = ChannelUnstickState::None;
1398 let gm = select_gm_for_channel(
1399 &self.inner.gm,
1400 self.inner.private_gm.as_ref(),
1401 version,
1402 channel,
1403 );
1404 if let Err(err) = Self::unstick_channel(gm, channel, false, false) {
1405 tracelimit::warn_ratelimited!(
1406 channel = %channel.key,
1407 error = err.as_ref() as &dyn std::error::Error,
1408 "could not unstick channel"
1409 );
1410 }
1411 }
1412 }
1413
1414 fn unstick_channel(
1415 gm: &GuestMemory,
1416 channel: &Channel,
1417 force: bool,
1418 unstick_host: bool,
1419 ) -> anyhow::Result<()> {
1420 if let ChannelState::Open(state) = &channel.state {
1421 if force {
1422 tracing::info!(channel = %channel.key, "waking host and guest");
1423 if unstick_host {
1424 channel.guest_to_host_event.0.deliver();
1425 }
1426 state.host_to_guest_interrupt.deliver();
1427 return Ok(());
1428 }
1429
1430 let gpadl = channel
1431 .gpadls
1432 .clone()
1433 .view()
1434 .map(state.open_params.open_data.ring_gpadl_id)
1435 .context("couldn't find ring gpadl")?;
1436
1437 let aligned = AlignedGpadlView::new(gpadl)
1438 .ok()
1439 .context("ring not aligned")?;
1440 let (in_gpadl, out_gpadl) = aligned
1441 .split(state.open_params.open_data.ring_offset)
1442 .ok()
1443 .context("couldn't split ring")?;
1444
1445 if let Err(err) = Self::unstick_incoming_ring(
1446 gm,
1447 channel,
1448 in_gpadl,
1449 unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1450 &state.host_to_guest_interrupt,
1451 ) {
1452 tracelimit::warn_ratelimited!(
1453 channel = %channel.key,
1454 error = err.as_ref() as &dyn std::error::Error,
1455 "could not unstick incoming ring"
1456 );
1457 }
1458 if let Err(err) = Self::unstick_outgoing_ring(
1459 gm,
1460 channel,
1461 out_gpadl,
1462 unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1463 &state.host_to_guest_interrupt,
1464 ) {
1465 tracelimit::warn_ratelimited!(
1466 channel = %channel.key,
1467 error = err.as_ref() as &dyn std::error::Error,
1468 "could not unstick outgoing ring"
1469 );
1470 }
1471 }
1472 Ok(())
1473 }
1474
1475 fn unstick_incoming_ring(
1476 gm: &GuestMemory,
1477 channel: &Channel,
1478 in_gpadl: AlignedGpadlView,
1479 guest_to_host_event: Option<&ChannelEvent>,
1480 host_to_guest_interrupt: &Interrupt,
1481 ) -> anyhow::Result<()> {
1482 let control_page = lock_gpn_with_subrange(gm, in_gpadl.gpns()[0])?;
1483 if let Some(guest_to_host_event) = guest_to_host_event {
1484 if ring::reader_needs_signal(control_page.pages()[0]) {
1485 tracelimit::info_ratelimited!(channel = %channel.key, "waking host for incoming ring");
1486 guest_to_host_event.0.deliver();
1487 }
1488 }
1489
1490 let ring_size = gpadl_ring_size(&in_gpadl).try_into()?;
1491 if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1492 tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for incoming ring");
1493 host_to_guest_interrupt.deliver();
1494 }
1495 Ok(())
1496 }
1497
1498 fn unstick_outgoing_ring(
1499 gm: &GuestMemory,
1500 channel: &Channel,
1501 out_gpadl: AlignedGpadlView,
1502 guest_to_host_event: Option<&ChannelEvent>,
1503 host_to_guest_interrupt: &Interrupt,
1504 ) -> anyhow::Result<()> {
1505 let control_page = lock_gpn_with_subrange(gm, out_gpadl.gpns()[0])?;
1506 if ring::reader_needs_signal(control_page.pages()[0]) {
1507 tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for outgoing ring");
1508 host_to_guest_interrupt.deliver();
1509 }
1510
1511 if let Some(guest_to_host_event) = guest_to_host_event {
1512 let ring_size = gpadl_ring_size(&out_gpadl).try_into()?;
1513 if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1514 tracelimit::info_ratelimited!(channel = %channel.key, "waking host for outgoing ring");
1515 guest_to_host_event.0.deliver();
1516 }
1517 }
1518 Ok(())
1519 }
1520}
1521
1522impl Notifier for ServerTaskInner {
    /// Carries out `action` for the channel identified by `offer_id` and queues a future
    /// on `channel_responses` that resolves to `(offer_id, seq, response)`.
    ///
    /// Most actions are forwarded to the channel over its `send` RPC channel; the queued
    /// future then resolves when that RPC does. Actions that fail locally queue an
    /// already-decided failure response instead.
    ///
    /// # Panics
    ///
    /// Panics if no channel exists for `offer_id`, if a `Gpadl` action's buffer is
    /// rejected by `MultiPagedRangeBuf::from_range_buffer`, or if a `Modify` action
    /// arrives for a channel that is not open.
    fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        // Sends `input` to the channel as the RPC built by `req` and returns a boxed
        // future yielding the offer id, the channel's sequence number captured at send
        // time, and the RPC result mapped through `f`.
        fn handle<I: 'static + Send, R: 'static + Send>(
            offer_id: OfferId,
            channel: &Channel,
            req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
            input: I,
            f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
        ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
        {
            let recv = channel.send.call(req, input);
            let seq = channel.seq;
            Box::pin(async move {
                let r = recv.await.map(f);
                (offer_id, seq, r)
            })
        }

        let response = match action {
            channels::Action::Open(open_params, version) => {
                // Copy these out first: `open_channel` needs `self` mutably, which ends
                // the `channel` borrow taken above.
                let seq = channel.seq;
                let key = channel.key;
                match self.open_channel(offer_id, &open_params) {
                    Ok((channel, interrupt)) => handle(
                        offer_id,
                        channel,
                        ChannelRequest::Open,
                        OpenRequest::new(
                            open_params.open_data,
                            interrupt,
                            version.feature_flags,
                            channel.flags,
                        ),
                        ChannelResponse::Open,
                    ),
                    Err(err) => {
                        tracelimit::error_ratelimited!(
                            err = err.as_ref() as &dyn std::error::Error,
                            ?offer_id,
                            %key,
                            "could not open channel",
                        );

                        // The channel is never contacted; report the open as failed.
                        Box::pin(future::ready((
                            offer_id,
                            seq,
                            Ok(ChannelResponse::Open(false)),
                        )))
                    }
                }
            }
            channels::Action::Close => {
                // Remove the channel from the channel bitmap, if one is in use and the
                // channel is currently open.
                if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
                    if let ChannelState::Open(ref state) = channel.state {
                        channel_bitmap.unregister_channel(state.open_params.event_flag);
                    }
                }

                channel.state = ChannelState::Closing;
                handle(offer_id, channel, ChannelRequest::Close, (), |()| {
                    ChannelResponse::Close
                })
            }
            channels::Action::Gpadl(gpadl_id, count, buf) => {
                // Record the GPADL locally before telling the channel about it.
                channel.gpadls.add(
                    gpadl_id,
                    MultiPagedRangeBuf::from_range_buffer(count.into(), buf.clone()).unwrap(),
                );
                handle(
                    offer_id,
                    channel,
                    ChannelRequest::Gpadl,
                    GpadlRequest {
                        id: gpadl_id,
                        count,
                        buf,
                    },
                    move |r| ChannelResponse::Gpadl(gpadl_id, r),
                )
            }
            channels::Action::TeardownGpadl {
                gpadl_id,
                post_restore,
            } => {
                // The local GPADL entry is only removed when this is not a post-restore
                // teardown.
                if !post_restore {
                    channel.gpadls.remove(gpadl_id, Box::new(|| ()));
                }

                handle(
                    offer_id,
                    channel,
                    ChannelRequest::TeardownGpadl,
                    gpadl_id,
                    move |()| ChannelResponse::TeardownGpadl(gpadl_id),
                )
            }
            channels::Action::Modify { target_vp } => {
                let ChannelState::Open(state) = &mut channel.state else {
                    unreachable!();
                };

                // Retarget the event port first; the channel is only notified if that
                // succeeds.
                if let Err(err) = state.set_event_port_target_vp(target_vp) {
                    tracelimit::error_ratelimited!(
                        error = err.as_ref() as &dyn std::error::Error,
                        channel = %channel.key,
                        "could not modify channel",
                    );

                    let seq = channel.seq;
                    Box::pin(async move {
                        (
                            offer_id,
                            seq,
                            Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
                        )
                    })
                } else {
                    handle(
                        offer_id,
                        channel,
                        ChannelRequest::Modify,
                        ModifyRequest::TargetVp { target_vp },
                        ChannelResponse::Modify,
                    )
                }
            }
        };
        self.channel_responses.push(response);
    }
1659
1660 fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
1661 self.map_interrupt_page(request.interrupt_page)
1662 .context("Failed to map interrupt page.")?;
1663
1664 self.set_monitor_page(&mut request)
1665 .context("Failed to map monitor page.")?;
1666
1667 if let Some(vp) = request.target_message_vp {
1668 self.message_port.set_target_vp(vp)?;
1669 }
1670
1671 if request.notify_relay {
1672 self.relay_send.send(request.into());
1673 }
1674
1675 Ok(())
1676 }
1677
1678 fn forward_unhandled(&mut self, request: InitiateContactRequest) {
1679 if let Some(external_server) = &self.external_server_send {
1680 external_server.send(request);
1681 } else {
1682 tracing::warn!(?request, "nowhere to forward unhandled request")
1683 }
1684 }
1685
1686 fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
1687 let channel = self.channels.get(&offer_id).expect("should exist");
1688 let mut resp = req.respond();
1689 if let (ChannelState::Open(state), Some(version)) = (&channel.state, version) {
1694 let mem = self.get_gm_for_channel(version, channel);
1695 inspect_rings(
1696 &mut resp,
1697 mem,
1698 channel.gpadls.clone(),
1699 &state.open_params.open_data,
1700 );
1701 }
1702 }
1703
1704 fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
1705 if !self.running && !self.send_messages_while_stopped {
1718 if !matches!(target, MessageTarget::Default) {
1719 tracelimit::error_ratelimited!(?target, "dropping message while paused");
1720 }
1721 return false;
1722 }
1723
1724 let mut port_storage;
1725 let port = match target {
1726 MessageTarget::Default => self.message_port.as_mut(),
1727 MessageTarget::ReservedChannel(offer_id, target) => {
1728 if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
1729 port.as_mut()
1730 } else {
1731 return true;
1733 }
1734 }
1735 MessageTarget::Custom(target) => {
1736 port_storage = match self.synic.new_guest_message_port(
1737 self.redirect_vtl,
1738 target.vp,
1739 target.sint,
1740 ) {
1741 Ok(port) => port,
1742 Err(err) => {
1743 tracing::error!(
1744 ?err,
1745 ?self.redirect_vtl,
1746 ?target,
1747 "could not create message port"
1748 );
1749
1750 return true;
1752 }
1753 };
1754 port_storage.as_mut()
1755 }
1756 };
1757
1758 matches!(
1761 port.poll_post_message(
1762 &mut std::task::Context::from_waker(std::task::Waker::noop()),
1763 VMBUS_MESSAGE_TYPE,
1764 message.data()
1765 ),
1766 Poll::Ready(())
1767 )
1768 }
1769
    /// Forwards an hvsock connect request over `hvsock_send` and counts it as
    /// outstanding.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
        tracing::debug!(?request, "received hvsock connect request");
        // Counted as outstanding until the matching connect result decrements it.
        self.hvsock_requests += 1;
        self.hvsock_send.send(*request);
    }
1775
1776 fn reset_complete(&mut self) {
1777 if let Some(monitor) = self.synic.monitor_support() {
1778 if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
1779 tracing::warn!(?err, "resetting monitor page failed")
1780 }
1781 }
1782
1783 self.unreserve_channels();
1784 for done in self.reset_done.drain(..) {
1785 done.complete(());
1786 }
1787 }
1788
    /// Called when an unload has finished; the only work done here is to unreserve
    /// channels.
    fn unload_complete(&mut self) {
        self.unreserve_channels();
    }
1792}
1793
1794impl ServerTaskInner {
    /// Creates the resources a channel needs to be opened and moves it to the `Open`
    /// state.
    ///
    /// Returns the channel together with the interrupt used to signal the guest.
    ///
    /// # Errors
    ///
    /// Fails if the guest-to-host event port, the guest event port, or the reserved
    /// channel message port cannot be created.
    ///
    /// # Panics
    ///
    /// Panics if no channel exists for `offer_id`.
    fn open_channel(
        &mut self,
        offer_id: OfferId,
        open_params: &OpenParams,
    ) -> anyhow::Result<(&mut Channel, Interrupt)> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        // When a channel bitmap is in use, register the channel's event under its event
        // flag.
        if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
            channel_bitmap.register_channel(
                open_params.event_flag,
                channel.guest_to_host_event.0.clone(),
            );
        }
        // Event port through which the guest signals the host for this channel.
        let event_port = self
            .synic
            .add_event_port(
                open_params.connection_id,
                self.vtl,
                channel.guest_to_host_event.clone(),
                open_params.monitor_info,
            )
            .context("failed to create guest-to-host event port")?;

        // With a channel bitmap, the guest event port always targets VP 0 with event
        // flag 0; otherwise the values from the open request are used.
        let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
            (Some(0), 0)
        } else {
            (open_params.open_data.target_vp, open_params.event_flag)
        };

        let (guest_event_port, interrupt) = if let Some(target_vp) = target_vp {
            // Interrupts go to the redirect VTL/SINT only when the open flags ask for it.
            let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
                (self.redirect_vtl, self.redirect_sint)
            } else {
                (self.vtl, VMBUS_SINT)
            };

            let guest_event_port = self.synic.new_guest_event_port(
                VmbusServer::get_child_event_port_id(open_params.channel_id, VMBUS_SINT, self.vtl),
                target_vtl,
                target_vp,
                target_sint,
                event_flag,
                open_params.monitor_info,
            )?;

            let interrupt = ChannelBitmap::create_interrupt(
                &self.channel_bitmap,
                guest_event_port.interrupt(),
                open_params.event_flag,
            );

            (Some(guest_event_port), interrupt)
        } else {
            // Without a target VP no guest event port is created and the interrupt is
            // `Interrupt::null()`.
            (None, Interrupt::null())
        };

        // Drop any message port left from an earlier reservation.
        channel.reserved_state.message_port = None;

        // A reserved target in the open parameters gets its own message port.
        if let Some(target) = open_params.reserved_target {
            channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
                self.redirect_vtl,
                target.vp,
                target.sint,
            )?);

            channel.reserved_state.target = target;
        }

        channel.state = ChannelState::Open(Box::new(ChannelOpenState {
            open_params: *open_params,
            _event_port: event_port,
            guest_event_port,
            host_to_guest_interrupt: interrupt.clone(),
        }));
        Ok((channel, interrupt))
    }
1884
1885 fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
1888 let interrupt_page = match interrupt_page {
1889 Update::Unchanged => return Ok(()),
1890 Update::Reset => {
1891 self.channel_bitmap = None;
1892 self.shared_event_port = None;
1893 return Ok(());
1894 }
1895 Update::Set(interrupt_page) => interrupt_page,
1896 };
1897
1898 assert_ne!(interrupt_page, 0);
1899
1900 if interrupt_page % PAGE_SIZE as u64 != 0 {
1901 anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
1902 }
1903
1904 let interrupt_page = lock_page_with_subrange(&self.gm, interrupt_page)?;
1907 let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
1908 self.channel_bitmap = Some(channel_bitmap.clone());
1909
1910 let interrupt = Interrupt::from_fn(move || {
1912 channel_bitmap.handle_shared_interrupt();
1913 });
1914
1915 self.shared_event_port = Some(self.synic.add_event_port(
1916 SHARED_EVENT_CONNECTION_ID,
1917 self.vtl,
1918 Arc::new(ChannelEvent(interrupt)),
1919 None,
1920 )?);
1921
1922 Ok(())
1923 }
1924
1925 fn set_monitor_page(&mut self, request: &mut ModifyConnectionRequest) -> anyhow::Result<()> {
1926 let monitor_page = match request.monitor_page {
1927 Update::Unchanged => return Ok(()),
1928 Update::Reset => None,
1929 Update::Set(value) => Some(value),
1930 };
1931
1932 if self.channels.iter().any(|(_, c)| {
1934 matches!(
1935 &c.state,
1936 ChannelState::Open(state) if state.open_params.monitor_info.is_some()
1937 )
1938 }) {
1939 anyhow::bail!("attempt to change monitor page while open channels using mnf");
1940 }
1941
1942 if let Some(mnf_support) = self.mnf_support.as_mut() {
1946 if let Some(monitor) = self.synic.monitor_support() {
1947 mnf_support.allocated_monitor_page = None;
1948
1949 if let Some(version) = request.version {
1950 if version.feature_flags.server_specified_monitor_pages() {
1951 if let Some(monitor_page) = monitor.allocate_monitor_page(self.vtl)? {
1952 tracelimit::info_ratelimited!(
1953 ?monitor_page,
1954 "using server-allocated monitor pages"
1955 );
1956 mnf_support.allocated_monitor_page = Some(monitor_page);
1957 }
1958 }
1959 }
1960
1961 if mnf_support.allocated_monitor_page.is_none() {
1963 if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
1964 anyhow::bail!(
1965 "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
1966 );
1967 }
1968 }
1969 }
1970
1971 request.monitor_page = Update::Unchanged;
1974 }
1975
1976 Ok(())
1977 }
1978
    /// Returns the message port of the reserved channel `offer_id`, first retargeting it
    /// to `new_target` if that differs from the stored target.
    ///
    /// A different SINT replaces the port with a newly created one; a different VP with
    /// the same SINT only updates the existing port's target VP.
    ///
    /// Returns `None` if a replacement port could not be created. In that case the old
    /// port has already been dropped and the channel is left without a message port.
    ///
    /// # Panics
    ///
    /// Panics if the channel does not exist or has no message port.
    fn get_reserved_channel_message_port(
        &mut self,
        offer_id: OfferId,
        new_target: ConnectionTarget,
    ) -> Option<&mut Box<dyn GuestMessagePort>> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        assert!(
            channel.reserved_state.message_port.is_some(),
            "channel is not reserved"
        );

        if channel.reserved_state.target.sint != new_target.sint {
            // Drop the old port before creating its replacement.
            channel.reserved_state.message_port = None;
            let message_port = self
                .synic
                .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
                .inspect_err(|err| {
                    tracing::error!(
                        key = %channel.key,
                        ?err,
                        ?self.redirect_vtl,
                        ?new_target,
                        "could not create reserved channel message port"
                    )
                })
                .ok()?;

            channel.reserved_state.message_port = Some(message_port);
            channel.reserved_state.target = new_target;
        } else if channel.reserved_state.target.vp != new_target.vp {
            let message_port = channel.reserved_state.message_port.as_mut().unwrap();

            // Only the VP changed. A failure to retarget is logged, and the stored target
            // is updated regardless.
            if let Err(err) = message_port.set_target_vp(new_target.vp) {
                tracing::error!(
                    key = %channel.key,
                    ?err,
                    ?self.redirect_vtl,
                    ?new_target,
                    "could not update reserved channel message port"
                );
            }

            channel.reserved_state.target = new_target;
            return Some(message_port);
        }

        Some(channel.reserved_state.message_port.as_mut().unwrap())
    }
2036
2037 fn unreserve_channels(&mut self) {
2038 for channel in self.channels.values_mut() {
2040 if let ChannelState::Closed = channel.state {
2041 channel.reserved_state.message_port = None;
2042 }
2043 }
2044 }
2045
    /// Returns the guest memory to use for `channel`, choosing between this server's
    /// `gm` and `private_gm` via [`select_gm_for_channel`].
    fn get_gm_for_channel(&self, version: VersionInfo, channel: &Channel) -> &GuestMemory {
        select_gm_for_channel(&self.gm, self.private_gm.as_ref(), version, channel)
    }
2049}
2050
2051fn select_gm_for_channel<'a>(
2052 gm: &'a GuestMemory,
2053 private_gm: Option<&'a GuestMemory>,
2054 version: VersionInfo,
2055 channel: &Channel,
2056) -> &'a GuestMemory {
2057 if channel.flags.confidential_ring_buffer() && version.feature_flags.confidential_channels() {
2058 if let Some(private_gm) = private_gm {
2059 return private_gm;
2060 }
2061 }
2062
2063 gm
2064}
2065
/// Cloneable handle used to offer channels to the vmbus server task and to request a
/// forced reset.
#[derive(Clone)]
pub struct VmbusServerControl {
    // Guest memory handed to every offered channel's `OfferResources`.
    mem: GuestMemory,
    // Optional second guest memory, handed out only to offers whose flags request a
    // confidential ring buffer or confidential external memory.
    private_mem: Option<GuestMemory>,
    // Channel carrying offer and force-reset requests to the server task.
    send: mesh::Sender<OfferRequest>,
    // Value reported by `ParentBus::use_event`.
    use_event: bool,
    // When set, `offer` turns on the confidential external memory flag for every offer.
    force_confidential_external_memory: bool,
}
2075
2076impl VmbusServerControl {
2077 pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
2080 let flags = offer_info.params.flags;
2081 self.send
2082 .call_failable(OfferRequest::Offer, offer_info)
2083 .await?;
2084 Ok(OfferResources::new(
2085 self.mem.clone(),
2086 if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
2087 self.private_mem.clone()
2088 } else {
2089 None
2090 },
2091 ))
2092 }
2093
2094 pub async fn force_reset(&self) -> anyhow::Result<()> {
2097 self.send
2098 .call(OfferRequest::ForceReset, ())
2099 .await
2100 .context("vmbus server is gone")
2101 }
2102
2103 async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2104 let mut offer_info = OfferInfo {
2105 params: request.params.into(),
2106 event: request.event,
2107 request_send: request.request_send,
2108 server_request_recv: request.server_request_recv,
2109 };
2110
2111 if self.force_confidential_external_memory {
2112 tracing::warn!(
2113 key = %offer_info.params.key(),
2114 "forcing confidential external memory for channel"
2115 );
2116
2117 offer_info
2118 .params
2119 .flags
2120 .set_confidential_external_memory(true);
2121 }
2122
2123 self.offer_core(offer_info).await
2124 }
2125}
2126
2127fn inspect_rings(
2129 resp: &mut inspect::Response<'_>,
2130 gm: &GuestMemory,
2131 gpadl_map: Arc<GpadlMap>,
2132 open_data: &OpenData,
2133) -> Option<()> {
2134 let gpadl = gpadl_map
2135 .view()
2136 .map(GpadlId(open_data.ring_gpadl_id.0))
2137 .ok()?;
2138
2139 let aligned = AlignedGpadlView::new(gpadl).ok()?;
2140 let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
2141 resp.child("incoming_ring", |req| inspect_ring(req, &in_gpadl, gm));
2142 resp.child("outgoing_ring", |req| inspect_ring(req, &out_gpadl, gm));
2143 Some(())
2144}
2145
2146fn inspect_ring(req: inspect::Request<'_>, gpadl: &AlignedGpadlView, gm: &GuestMemory) {
2148 let mut resp = req.respond();
2149
2150 resp.hex("ring_size", gpadl_ring_size(gpadl));
2151
2152 if let Ok(pages) = lock_gpn_with_subrange(gm, gpadl.gpns()[0]) {
2155 ring::inspect_ring(pages.pages()[0], &mut resp);
2156 }
2157}
2158
2159fn gpadl_ring_size(gpadl: &AlignedGpadlView) -> usize {
2160 (gpadl.gpns().len() - 1) * PAGE_SIZE
2162}
2163
2164fn lock_page_with_subrange(gm: &GuestMemory, offset: u64) -> anyhow::Result<guestmem::LockedPages> {
2169 Ok(gm
2170 .lockable_subrange(offset, PAGE_SIZE as u64)?
2171 .lock_gpns(false, &[0])?)
2172}
2173
2174fn lock_gpn_with_subrange(gm: &GuestMemory, gpn: u64) -> anyhow::Result<guestmem::LockedPages> {
2179 lock_page_with_subrange(gm, gpn * PAGE_SIZE as u64)
2180}
2181
/// Message port implementation that wraps each incoming message in a [`SynicMessage`]
/// and queues it on an mpsc channel.
pub(crate) struct MessageSender {
    // Channel on which the wrapped messages are queued.
    send: mpsc::Sender<SynicMessage>,
    // Copied into the `multiclient` field of every message sent.
    multiclient: bool,
}
2186
2187impl MessageSender {
2188 fn poll_handle_message(
2189 &self,
2190 cx: &mut std::task::Context<'_>,
2191 msg: &[u8],
2192 trusted: bool,
2193 ) -> Poll<Result<(), SendError>> {
2194 let mut send = self.send.clone();
2195 ready!(send.poll_ready(cx))?;
2196 send.start_send(SynicMessage {
2197 data: msg.to_vec(),
2198 multiclient: self.multiclient,
2199 trusted,
2200 })?;
2201
2202 Poll::Ready(Ok(()))
2203 }
2204}
2205
2206impl MessagePort for MessageSender {
2207 fn poll_handle_message(
2208 &self,
2209 cx: &mut std::task::Context<'_>,
2210 msg: &[u8],
2211 trusted: bool,
2212 ) -> Poll<()> {
2213 if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
2214 tracelimit::error_ratelimited!(
2215 error = &err as &dyn std::error::Error,
2216 "failed to send message"
2217 );
2218 }
2219
2220 Poll::Ready(())
2221 }
2222}
2223
/// Lets a [`VmbusServerControl`] act as the parent bus for channel offers.
#[async_trait]
impl ParentBus for VmbusServerControl {
    /// Offers the channel described by `request` through this control's `offer`.
    async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
        self.offer(request).await
    }

    /// Returns a boxed clone of this control.
    fn clone_bus(&self) -> Box<dyn ParentBus> {
        Box::new(self.clone())
    }

    /// Returns the `use_event` setting this control was created with.
    fn use_event(&self) -> bool {
        self.use_event
    }
}