1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8pub mod channels;
9pub mod event;
10pub mod hvsock;
11mod monitor;
12mod proxyintegration;
13#[cfg(test)]
14mod tests;
15
/// Alias for [`guid::Guid`], exposed under this crate's namespace.
pub type Guid = guid::Guid;
18
19use anyhow::Context;
20use async_trait::async_trait;
21use channel_bitmap::ChannelBitmap;
22use channels::ConnectionTarget;
23pub use channels::InitiateContactRequest;
24use channels::MessageTarget;
25pub use channels::MnfUsage;
26use channels::ModifyConnectionRequest;
27use channels::ModifyConnectionResponse;
28use channels::Notifier;
29use channels::OfferId;
30pub use channels::OfferParamsInternal;
31use channels::OpenParams;
32use channels::RestoreError;
33pub use channels::Update;
34use futures::FutureExt;
35use futures::StreamExt;
36use futures::channel::mpsc;
37use futures::channel::mpsc::SendError;
38use futures::future::OptionFuture;
39use futures::future::poll_fn;
40use futures::stream::SelectAll;
41use guestmem::GuestMemory;
42use hvdef::Vtl;
43use inspect::Inspect;
44use mesh::payload::Protobuf;
45use mesh::rpc::FailableRpc;
46use mesh::rpc::Rpc;
47use mesh::rpc::RpcError;
48use mesh::rpc::RpcSend;
49use pal_async::driver::Driver;
50use pal_async::driver::SpawnDriver;
51use pal_async::task::Task;
52use pal_async::timer::PolledTimer;
53use pal_event::Event;
54#[cfg(windows)]
55pub use proxyintegration::ProxyIntegration;
56#[cfg(windows)]
57pub use proxyintegration::ProxyServerInfo;
58use ring::PAGE_SIZE;
59use std::collections::HashMap;
60use std::future;
61use std::future::Future;
62use std::pin::Pin;
63use std::sync::Arc;
64use std::task::Poll;
65use std::task::ready;
66use std::time::Duration;
67use unicycle::FuturesUnordered;
68use vmbus_channel::bus::ChannelRequest;
69use vmbus_channel::bus::ChannelServerRequest;
70use vmbus_channel::bus::GpadlRequest;
71use vmbus_channel::bus::ModifyRequest;
72use vmbus_channel::bus::OfferInput;
73use vmbus_channel::bus::OfferKey;
74use vmbus_channel::bus::OfferResources;
75use vmbus_channel::bus::OpenData;
76use vmbus_channel::bus::OpenRequest;
77use vmbus_channel::bus::ParentBus;
78use vmbus_channel::bus::RestoreResult;
79use vmbus_channel::gpadl::GpadlMap;
80use vmbus_channel::gpadl_ring::AlignedGpadlView;
81use vmbus_core::HvsockConnectRequest;
82use vmbus_core::HvsockConnectResult;
83use vmbus_core::MaxVersionInfo;
84use vmbus_core::OutgoingMessage;
85use vmbus_core::TaggedStream;
86use vmbus_core::VMBUS_SINT;
87use vmbus_core::VersionInfo;
88use vmbus_core::protocol;
89pub use vmbus_core::protocol::GpadlId;
90#[cfg(windows)]
91use vmbus_proxy::ProxyHandle;
92use vmbus_ring as ring;
93use vmbus_ring::gparange::MultiPagedRangeBuf;
94use vmcore::interrupt::Interrupt;
95use vmcore::save_restore::SavedStateRoot;
96use vmcore::synic::EventPort;
97use vmcore::synic::GuestEventPort;
98use vmcore::synic::GuestMessagePort;
99use vmcore::synic::MessagePort;
100use vmcore::synic::MonitorPageGpas;
101use vmcore::synic::SynicPortAccess;
102
/// SINT on which the server receives VMBus messages when message redirection
/// is enabled.
pub const REDIRECT_SINT: u8 = 7;
/// VTL at which the server receives VMBus messages when message redirection
/// is enabled.
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
// NOTE(review): not referenced in the visible part of this file; presumably the
// connection ID used for the shared event port — confirm.
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
/// Base value OR-ed into per-channel event port IDs
/// (see `VmbusServer::get_child_event_port_id`).
const EVENT_PORT_ID: u32 = 2;
/// Message type passed to `poll_post_message` when flushing pending VMBus
/// messages to the guest.
const VMBUS_MESSAGE_TYPE: u32 = 1;

/// Maximum number of outstanding hvsock connect requests. While this many are
/// pending, the server task stops receiving new synic messages.
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
110
/// Handle to a VMBus server.
///
/// The server runs in a spawned task ([`ServerTask`]). This handle sends it
/// control requests and holds the handles returned when its synic message
/// ports were registered.
#[derive(Inspect)]
pub struct VmbusServer {
    /// Sends control requests to the server task; inspection is forwarded to
    /// the task as `VmbusRequest::Inspect`.
    #[inspect(flatten, send = "VmbusRequest::Inspect")]
    task_send: mesh::Sender<VmbusRequest>,
    /// Shared control object returned by [`VmbusServer::control`].
    #[inspect(skip)]
    control: Arc<VmbusServerControl>,
    /// Handle for the primary message port registration (presumably
    /// unregisters the port when dropped — confirm).
    #[inspect(skip)]
    _message_port: Box<dyn Sync + Send>,
    /// Handle for the multiclient message port registration, if one was
    /// created.
    #[inspect(skip)]
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    /// The spawned server task; it yields the `ServerTask` once its run loop
    /// exits.
    #[inspect(skip)]
    task: Task<ServerTask>,
}
124
/// Builder for [`VmbusServer`]. See the setter methods for each option.
pub struct VmbusServerBuilder<T: SpawnDriver> {
    /// Spawner for the server task; a clone also becomes the task's driver.
    spawner: T,
    /// Synic used to register message ports and create the guest message port.
    synic: Arc<dyn SynicPortAccess>,
    /// Guest memory.
    gm: GuestMemory,
    /// Optional private guest memory.
    private_gm: Option<GuestMemory>,
    /// VTL the server runs at.
    vtl: Vtl,
    /// Optional channel for forwarding hvsock connect requests.
    hvsock_notify: Option<HvsockServerChannelHalf>,
    /// Optional channel for relaying connection modifications.
    server_relay: Option<VmbusServerChannelHalf>,
    /// Optional receiver of saved-state set/clear requests.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Optional sender of contact requests to an external server.
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    /// Optional additional source of contact requests.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Whether to receive messages on `REDIRECT_SINT` at `REDIRECT_VTL`.
    use_message_redirect: bool,
    /// Channel ID offset passed to the channel server (0 or 1024).
    channel_id_offset: u16,
    /// Optional compatibility (maximum) version for the channel server.
    max_version: Option<MaxVersionInfo>,
    /// Passed with `max_version` to the channel server.
    delay_max_version: bool,
    /// Whether monitored notification (MNF) support is enabled.
    enable_mnf: bool,
    /// Passed through to [`VmbusServerControl`].
    force_confidential_external_memory: bool,
    /// Stored in the server task state.
    send_messages_while_stopped: bool,
    /// Delay before an unstick pass after a channel opens; `None` disables it.
    channel_unstick_delay: Option<Duration>,
    /// Passed to the channel server constructor.
    use_absolute_channel_order: bool,
}
146
/// Requests sent to the `saved_state_notify` receiver configured on the
/// builder (presumably a proxy — confirm).
#[derive(mesh::MeshPayload)]
pub enum SavedStateRequest {
    /// Sent during restore with the channel server's saved state. A failure
    /// aborts the restore.
    Set(FailableRpc<Box<channels::SavedState>, ()>),
    /// Sent on the first start after a restore.
    Clear(Rpc<(), ()>),
}
153
/// The server's end of a request/response relay channel: it sends requests and
/// receives responses.
pub struct ServerChannelHalf<Request, Response> {
    /// Sends requests to the relay.
    request_send: mesh::Sender<Request>,
    /// Receives the relay's responses.
    response_receive: mesh::Receiver<Response>,
}
159
/// The relay's end of a request/response relay channel: it receives requests
/// and sends responses.
pub struct RelayChannelHalf<Request, Response> {
    /// Receives requests sent by the server.
    pub request_receive: mesh::Receiver<Request>,
    /// Sends responses back to the server.
    pub response_send: mesh::Sender<Response>,
}
165
/// A connected pair of relay and server halves, created by
/// [`RelayChannel::new`].
pub struct RelayChannel<Request, Response> {
    /// The half to give to the relay.
    pub relay_half: RelayChannelHalf<Request, Response>,
    /// The half to give to the VMBus server.
    pub server_half: ServerChannelHalf<Request, Response>,
}
171
172impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
173 pub fn new() -> Self {
175 let (request_send, request_receive) = mesh::channel();
176 let (response_send, response_receive) = mesh::channel();
177 Self {
178 relay_half: RelayChannelHalf {
179 request_receive,
180 response_send,
181 },
182 server_half: ServerChannelHalf {
183 request_send,
184 response_receive,
185 },
186 }
187 }
188}
189
/// Server half of the channel used to relay connection modifications.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Relay half of the channel used to relay connection modifications.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Connected pair of halves for relaying connection modifications.
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyRelayResponse>;
/// Server half of the channel used to forward hvsock connect requests.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay half of the channel used to forward hvsock connect requests.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Connected pair of halves for forwarding hvsock connect requests.
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
196
/// A connection-level modification forwarded to the relay. It is built from a
/// `ModifyConnectionRequest` via `From`.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// The requested protocol version as a raw `u32`, if the request carries
    /// one.
    pub version: Option<u32>,
    /// Change to the monitor page GPAs.
    pub monitor_page: Update<MonitorPageGpas>,
    /// `Some(true)` if an interrupt page is being set, `Some(false)` if it is
    /// being reset, and `None` if it is unchanged.
    pub use_interrupt_page: Option<bool>,
}
211
/// The relay's answer to a [`ModifyRelayRequest`].
#[derive(Debug, Copy, Clone)]
pub enum ModifyRelayResponse {
    /// Answer to a request carrying a version: the resulting connection state
    /// and the feature flags.
    Supported(protocol::ConnectionState, protocol::FeatureFlags),
    /// The request is not supported (presumably the requested version — confirm).
    Unsupported,
    /// Answer to a request without a version: the resulting connection state.
    Modified(protocol::ConnectionState),
}
226
227impl From<ModifyConnectionRequest> for ModifyRelayRequest {
228 fn from(value: ModifyConnectionRequest) -> Self {
229 Self {
230 version: value.version.map(|v| v.version as u32),
231 monitor_page: value.monitor_page,
232 use_interrupt_page: match value.interrupt_page {
233 Update::Unchanged => None,
234 Update::Reset => Some(false),
235 Update::Set(_) => Some(true),
236 },
237 }
238 }
239}
240
/// Control requests sent from [`VmbusServer`] to the server task.
#[derive(Debug)]
enum VmbusRequest {
    /// Reset channel state. The RPC completes once the reset has finished.
    Reset(Rpc<(), ()>),
    /// Inspect the server task's state.
    Inspect(inspect::Deferred),
    /// Produce the server's saved state.
    Save(Rpc<(), SavedState>),
    /// Restore from saved state.
    Restore(Rpc<Box<SavedState>, Result<(), RestoreError>>),
    /// Mark the server running.
    Start,
    /// Mark the server stopped.
    Stop(Rpc<(), ()>),
}
250
/// Everything needed to offer a channel to the server.
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    /// Offer parameters passed to the channel server.
    pub params: OfferParamsInternal,
    /// Interrupt delivered when the guest signals this channel.
    pub event: Interrupt,
    /// Sends channel requests to the device.
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receives restore/revoke requests for this channel. When it closes, the
    /// channel is revoked.
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
258
/// Requests received by the server task on its offer channel.
#[expect(clippy::large_enum_variant)]
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offer a new channel. Fails if the channel server rejects the offer.
    Offer(FailableRpc<OfferInfo, ()>),
    /// Reset channel state; handled the same way as `VmbusRequest::Reset`.
    ForceReset(Rpc<(), ()>),
}
265
266struct ChannelEvent(Interrupt);
267
268impl EventPort for ChannelEvent {
269 fn handle_event(&self, _flag: u16) {
270 self.0.deliver();
271 }
272
273 fn os_event(&self) -> Option<&Event> {
274 self.0.event()
275 }
276}
277
/// Saved state of the VMBus server.
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    /// Channel server state.
    #[mesh(1)]
    pub server: channels::SavedState,
    /// Always written as `true` by save. If it is `false` on restore, channels
    /// are unstuck on the next start to mitigate the lost synic bug.
    #[mesh(2)]
    pub lost_synic_bug_fixed: bool,
}
289
/// Connection ID of the primary message port for a VTL0 server without
/// message redirection.
const MESSAGE_CONNECTION_ID: u32 = 1;
/// Connection ID of the multiclient message port for a VTL0 server without
/// redirection. It is also the base value for child message connection IDs.
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
292
293impl<T: SpawnDriver + Clone> VmbusServerBuilder<T> {
294 pub fn new(spawner: T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
296 Self {
297 spawner,
298 synic,
299 gm,
300 private_gm: None,
301 vtl: Vtl::Vtl0,
302 hvsock_notify: None,
303 server_relay: None,
304 saved_state_notify: None,
305 external_server: None,
306 external_requests: None,
307 use_message_redirect: false,
308 channel_id_offset: 0,
309 max_version: None,
310 delay_max_version: false,
311 enable_mnf: false,
312 force_confidential_external_memory: false,
313 send_messages_while_stopped: false,
314 channel_unstick_delay: Some(Duration::from_millis(100)),
315 use_absolute_channel_order: false,
316 }
317 }
318
319 pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
323 self.private_gm = private_gm;
324 self
325 }
326
327 pub fn vtl(mut self, vtl: Vtl) -> Self {
329 self.vtl = vtl;
330 self
331 }
332
333 pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
335 self.hvsock_notify = hvsock_notify;
336 self
337 }
338
339 pub fn saved_state_notify(
341 mut self,
342 saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
343 ) -> Self {
344 self.saved_state_notify = saved_state_notify;
345 self
346 }
347
348 pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
351 self.server_relay = server_relay;
352 self
353 }
354
355 pub fn external_requests(
357 mut self,
358 external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
359 ) -> Self {
360 self.external_requests = external_requests;
361 self
362 }
363
364 pub fn external_server(
367 mut self,
368 external_server: Option<mesh::Sender<InitiateContactRequest>>,
369 ) -> Self {
370 self.external_server = external_server;
371 self
372 }
373
374 pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
376 self.use_message_redirect = use_message_redirect;
377 self
378 }
379
380 pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
385 self.channel_id_offset = if enable { 1024 } else { 0 };
386 self
387 }
388
389 pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
393 self.max_version = max_version;
394 self
395 }
396
397 pub fn delay_max_version(mut self, delay: bool) -> Self {
402 self.delay_max_version = delay;
403 self
404 }
405
406 pub fn enable_mnf(mut self, enable: bool) -> Self {
410 self.enable_mnf = enable;
411 self
412 }
413
414 pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
417 self.force_confidential_external_memory = force;
418 self
419 }
420
421 pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
428 self.send_messages_while_stopped = send;
429 self
430 }
431
432 pub fn channel_unstick_delay(mut self, delay: Option<Duration>) -> Self {
439 self.channel_unstick_delay = delay;
440 self
441 }
442
443 pub fn use_absolute_channel_order(mut self, assign: bool) -> Self {
447 self.use_absolute_channel_order = assign;
448 self
449 }
450
451 pub fn build(self) -> anyhow::Result<VmbusServer> {
456 #[expect(clippy::disallowed_methods)] let (message_send, message_recv) = mpsc::channel(64);
458 let message_sender = Arc::new(MessageSender {
459 send: message_send.clone(),
460 multiclient: self.use_message_redirect,
461 });
462
463 let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
464 (REDIRECT_VTL, REDIRECT_SINT)
465 } else {
466 (self.vtl, VMBUS_SINT)
467 };
468
469 let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
472 MESSAGE_CONNECTION_ID
473 } else {
474 VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
477 };
478
479 let _message_port = self
480 .synic
481 .add_message_port(connection_id, redirect_vtl, message_sender)
482 .context("failed to create vmbus synic ports")?;
483
484 let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
488 let multiclient_message_sender = Arc::new(MessageSender {
489 send: message_send,
490 multiclient: true,
491 });
492
493 Some(
494 self.synic
495 .add_message_port(
496 MULTICLIENT_MESSAGE_CONNECTION_ID,
497 self.vtl,
498 multiclient_message_sender,
499 )
500 .context("failed to create vmbus synic ports")?,
501 )
502 } else {
503 None
504 };
505
506 let (offer_send, offer_recv) = mesh::mpsc_channel();
507 let control = Arc::new(VmbusServerControl {
508 mem: self.gm.clone(),
509 private_mem: self.private_gm.clone(),
510 send: offer_send,
511 use_event: self.synic.prefer_os_events(),
512 force_confidential_external_memory: self.force_confidential_external_memory,
513 });
514
515 let mut server = channels::Server::new(
516 self.vtl,
517 connection_id,
518 self.channel_id_offset,
519 self.use_absolute_channel_order,
520 );
521
522 server.set_require_server_allocated_mnf(self.enable_mnf && self.private_gm.is_some());
527
528 if let Some(version) = self.max_version {
530 server.set_compatibility_version(version, self.delay_max_version);
531 }
532 let (relay_request_send, relay_response_recv) =
533 if let Some(server_relay) = self.server_relay {
534 let r = server_relay.response_receive.boxed().fuse();
535 (server_relay.request_send, r)
536 } else {
537 let (req_send, req_recv) = mesh::channel();
538 let resp_recv = req_recv
539 .map(|req: ModifyRelayRequest| {
540 if req.version.is_some() {
542 ModifyRelayResponse::Supported(
543 protocol::ConnectionState::SUCCESSFUL,
544 protocol::FeatureFlags::from_bits(u32::MAX),
545 )
546 } else {
547 ModifyRelayResponse::Modified(protocol::ConnectionState::SUCCESSFUL)
548 }
549 })
550 .boxed()
551 .fuse();
552 (req_send, resp_recv)
553 };
554
555 let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
557 let r = hvsock_notify.response_receive.boxed().fuse();
558 (hvsock_notify.request_send, r)
559 } else {
560 let (req_send, req_recv) = mesh::channel();
561 let resp_recv = req_recv
562 .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
563 .boxed()
564 .fuse();
565 (req_send, resp_recv)
566 };
567
568 let inner = ServerTaskInner {
569 running: false,
570 send_messages_while_stopped: self.send_messages_while_stopped,
571 gm: self.gm,
572 private_gm: self.private_gm,
573 vtl: self.vtl,
574 redirect_vtl,
575 redirect_sint,
576 message_port: self
577 .synic
578 .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
579 synic: self.synic,
580 hvsock_requests: 0,
581 hvsock_send,
582 saved_state_notify: self.saved_state_notify,
583 channels: HashMap::new(),
584 channel_responses: FuturesUnordered::new(),
585 relay_send: relay_request_send,
586 external_server_send: self.external_server,
587 channel_bitmap: None,
588 shared_event_port: None,
589 reset_done: Vec::new(),
590 mnf_support: self.enable_mnf.then(MnfSupport::default),
591 };
592
593 let (task_send, task_recv) = mesh::channel();
594 let mut server_task = ServerTask {
595 driver: Box::new(self.spawner.clone()),
596 server,
597 task_recv,
598 offer_recv,
599 message_recv,
600 server_request_recv: SelectAll::new(),
601 inner,
602 external_requests: self.external_requests,
603 next_seq: 0,
604 perform_post_restore_on_start: false,
605 unstick_on_start: false,
606 channel_unstickers: FuturesUnordered::new(),
607 channel_unstick_delay: self.channel_unstick_delay,
608 };
609
610 let task = self.spawner.spawn("vmbus server", async move {
611 server_task.run(relay_response_recv, hvsock_recv).await;
612 server_task
613 });
614
615 Ok(VmbusServer {
616 task_send,
617 control,
618 _message_port,
619 _multiclient_message_port,
620 task,
621 })
622 }
623}
624
625impl VmbusServer {
626 pub fn builder<T: SpawnDriver + Clone>(
628 spawner: T,
629 synic: Arc<dyn SynicPortAccess>,
630 gm: GuestMemory,
631 ) -> VmbusServerBuilder<T> {
632 VmbusServerBuilder::new(spawner, synic, gm)
633 }
634
635 pub async fn save(&self) -> SavedState {
636 self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
637 }
638
639 pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
640 self.task_send
641 .call(VmbusRequest::Restore, Box::new(state))
642 .await
643 .unwrap()
644 }
645
646 pub async fn stop(&self) {
648 self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
649 }
650
651 pub fn start(&self) {
653 self.task_send.send(VmbusRequest::Start);
654 }
655
656 pub async fn reset(&self) {
658 tracing::debug!("resetting channel state");
659 self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
660 }
661
662 pub async fn shutdown(self) {
664 drop(self.task_send);
665 let _ = self.task.await;
666 }
667
668 pub fn control(&self) -> Arc<VmbusServerControl> {
670 self.control.clone()
671 }
672
673 fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
676 MULTICLIENT_MESSAGE_CONNECTION_ID
677 | (vtl as u32) << 22
678 | vp_index << 8
679 | (sint_index as u32) << 4
680 }
681
682 fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
683 EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
684 }
685}
686
/// Per-channel restore information.
// NOTE(review): not constructed in the visible part of this file.
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    /// Open data, if the channel was open.
    open_data: Option<OpenData>,
    /// Presumably `(id, range count, range buffer)` for each GPADL, like the
    /// GPADL requests used in `handle_restore_channel` — confirm.
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    /// Interrupt, if any.
    interrupt: Option<Interrupt>,
}
693
/// A synic message queued for the server task.
#[derive(Default)]
pub struct SynicMessage {
    /// Raw message bytes.
    data: Vec<u8>,
    /// Presumably set from the receiving port's `MessageSender::multiclient`
    /// flag — confirm.
    multiclient: bool,
    // NOTE(review): meaning not visible here; presumably whether the message
    // came from a trusted source — confirm.
    trusted: bool,
}
700
/// Server-side MNF state. It is present only when MNF support is enabled.
#[derive(Default)]
struct MnfSupport {
    /// Monitor page allocated by the server, if any. It is reported in
    /// `ModifyConnectionResponse::Supported`.
    allocated_monitor_page: Option<MonitorPageGpas>,
}
706
/// Identifies one offer instance: the offer ID plus the sequence number
/// assigned when the channel was offered. Requests whose `seq` does not match
/// the stored channel are ignored.
#[derive(Debug, Clone, Copy)]
struct OfferInstanceId {
    offer_id: OfferId,
    seq: u64,
}
713
/// State owned by the spawned server task.
struct ServerTask {
    /// Driver used to create channel unstick timers.
    driver: Box<dyn Driver>,
    /// Channel server state.
    server: channels::Server,
    /// Control requests from [`VmbusServer`]; the run loop exits when it closes.
    task_recv: mesh::Receiver<VmbusRequest>,
    /// Offer and force-reset requests (the sender lives in `VmbusServerControl`).
    offer_recv: mesh::Receiver<OfferRequest>,
    /// Messages delivered by the registered synic message ports.
    message_recv: mpsc::Receiver<SynicMessage>,
    /// Per-channel restore/revoke request streams, tagged with their offer
    /// instance.
    server_request_recv:
        SelectAll<TaggedStream<OfferInstanceId, mesh::Receiver<ChannelServerRequest>>>,
    /// State shared with the channel server's notifier.
    inner: ServerTaskInner,
    /// Optional additional source of contact requests.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Sequence number for the next offered channel.
    next_seq: u64,
    /// Set by restore. On the next start, the proxy's saved state is cleared
    /// and unclaimed channels are revoked.
    perform_post_restore_on_start: bool,
    /// Set by restore when the saved state lacks the lost-synic fix; channels
    /// are unstuck on the next start.
    unstick_on_start: bool,
    /// Unstick timers queued by `handle_open`; each yields its channel's
    /// instance ID.
    channel_unstickers: FuturesUnordered<Pin<Box<dyn Send + Future<Output = OfferInstanceId>>>>,
    /// Delay before an unstick timer fires; `None` disables unstick timers.
    channel_unstick_delay: Option<Duration>,
}
731
/// Server task state passed to the channel server as its notifier.
struct ServerTaskInner {
    /// Set by `Start`, cleared by `Stop`.
    running: bool,
    /// Configured by `VmbusServerBuilder::send_messages_while_stopped`.
    send_messages_while_stopped: bool,
    gm: GuestMemory,
    private_gm: Option<GuestMemory>,
    synic: Arc<dyn SynicPortAccess>,
    /// VTL the server runs at.
    vtl: Vtl,
    /// VTL and SINT used for the guest message port (the redirect values when
    /// redirection is enabled).
    redirect_vtl: Vtl,
    redirect_sint: u8,
    /// Port used to post VMBus messages to the guest.
    message_port: Box<dyn GuestMessagePort>,
    /// Number of outstanding hvsock connect requests.
    hvsock_requests: usize,
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Currently offered channels.
    channels: HashMap<OfferId, Channel>,
    /// Pending device responses, tagged with the offer ID and sequence number
    /// of the channel that issued them.
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    /// Sends connection modifications to the relay (or the local stub).
    relay_send: mesh::Sender<ModifyRelayRequest>,
    // NOTE(review): not used in the visible part of this file.
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    // NOTE(review): not used in the visible part of this file.
    shared_event_port: Option<Box<dyn Send>>,
    /// RPCs waiting for the in-progress reset; non-empty means a reset is in
    /// progress.
    reset_done: Vec<Rpc<(), ()>>,
    /// Present when MNF support is enabled.
    mnf_support: Option<MnfSupport>,
}
758
/// A device's response to a channel request.
#[derive(Debug)]
enum ChannelResponse {
    /// Open result (true on success).
    Open(bool),
    /// Close completed.
    Close,
    /// GPADL creation result (true on success).
    Gpadl(GpadlId, bool),
    /// GPADL teardown completed.
    TeardownGpadl(GpadlId),
    /// Modify-channel status code.
    Modify(i32),
}
767
/// Tracks whether an unstick timer is queued for a channel.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ChannelUnstickState {
    /// No timer queued.
    None,
    /// A timer is queued.
    Queued,
    /// The channel reopened while a timer was queued.
    NeedsRequeue,
}
774
/// Server task bookkeeping for one offered channel.
struct Channel {
    /// Offer key, used in logs.
    key: OfferKey,
    /// Sends channel requests to the device.
    send: mesh::Sender<ChannelRequest>,
    /// Sequence number of this offer instance.
    seq: u64,
    state: ChannelState,
    /// GPADLs registered for this channel.
    gpadls: Arc<GpadlMap>,
    /// Delivers the channel's guest-to-host interrupt.
    guest_to_host_event: Arc<ChannelEvent>,
    /// Offer flags, copied from the offer parameters.
    flags: protocol::OfferFlags,
    reserved_state: ReservedState,
    unstick_state: ChannelUnstickState,
}
790
/// Message target state, presumably for reserved channels — confirm. It starts
/// with no port and target VP 0, SINT 0.
struct ReservedState {
    message_port: Option<Box<dyn GuestMessagePort>>,
    target: ConnectionTarget,
}
795
/// State held while a channel is open.
struct ChannelOpenState {
    open_params: OpenParams,
    /// Held only to keep the event port alive.
    _event_port: Box<dyn Send>,
    /// Port for signaling the guest; `None` when the channel interrupt is
    /// disabled.
    guest_event_port: Option<Box<dyn GuestEventPort>>,
    host_to_guest_interrupt: Interrupt,
}
802
803impl ChannelOpenState {
804 fn set_event_port_target_vp(&mut self, vp: u32) -> anyhow::Result<()> {
805 let Some(guest_event_port) = self.guest_event_port.as_mut() else {
806 anyhow::bail!("cannot set target VP if the channel interrupt is disabled");
807 };
808
809 guest_event_port.set_target_vp(vp)?;
810 Ok(())
811 }
812}
813
/// Open/close state of a channel as tracked by the server task.
enum ChannelState {
    Closed,
    Open(Box<ChannelOpenState>),
    /// A close is in progress; `handle_close` moves the channel to `Closed`.
    Closing,
}
819
820impl ServerTask {
821 fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
822 let key = info.params.key();
823 let flags = info.params.flags;
824
825 if self.inner.mnf_support.is_some() && self.inner.synic.monitor_support().is_some() {
826 if info.params.use_mnf.is_relayed() {
831 info.params.use_mnf = MnfUsage::Enabled {
832 latency: Duration::ZERO,
833 }
834 }
835 } else if info.params.use_mnf.is_enabled() {
836 info.params.use_mnf = MnfUsage::Disabled;
839 }
840
841 let offer_id = self
842 .server
843 .with_notifier(&mut self.inner)
844 .offer_channel(info.params)
845 .context("channel offer failed")?;
846
847 tracing::debug!(?offer_id, %key, "offered channel");
848
849 let seq = self.next_seq;
850 self.next_seq += 1;
851 self.inner.channels.insert(
852 offer_id,
853 Channel {
854 key,
855 send: info.request_send,
856 state: ChannelState::Closed,
857 gpadls: GpadlMap::new(),
858 guest_to_host_event: Arc::new(ChannelEvent(info.event)),
859 seq,
860 flags,
861 reserved_state: ReservedState {
862 message_port: None,
863 target: ConnectionTarget { vp: 0, sint: 0 },
864 },
865 unstick_state: ChannelUnstickState::None,
866 },
867 );
868
869 self.server_request_recv.push(TaggedStream::new(
870 OfferInstanceId { offer_id, seq },
871 info.server_request_recv,
872 ));
873
874 Ok(())
875 }
876
877 fn handle_revoke(&mut self, id: OfferInstanceId) {
878 if let Some(channel) = self.inner.channels.get(&id.offer_id) {
881 if channel.seq == id.seq {
882 tracing::info!(?id.offer_id, key = %channel.key, "revoking channel");
883 self.inner.channels.remove(&id.offer_id);
884 self.server
885 .with_notifier(&mut self.inner)
886 .revoke_channel(id.offer_id);
887 }
888 }
889 }
890
891 fn handle_response(
892 &mut self,
893 offer_id: OfferId,
894 seq: u64,
895 response: Result<ChannelResponse, RpcError>,
896 ) {
897 let channel = self
899 .inner
900 .channels
901 .get(&offer_id)
902 .filter(|channel| channel.seq == seq);
903
904 if let Some(channel) = channel {
905 match response {
906 Ok(response) => match response {
907 ChannelResponse::Open(result) => self.handle_open(offer_id, result),
908 ChannelResponse::Close => self.handle_close(offer_id),
909 ChannelResponse::Gpadl(gpadl_id, ok) => {
910 self.handle_gpadl_create(offer_id, gpadl_id, ok)
911 }
912 ChannelResponse::TeardownGpadl(gpadl_id) => {
913 self.handle_gpadl_teardown(offer_id, gpadl_id)
914 }
915 ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
916 },
917 Err(err) => {
918 tracing::error!(
919 key = %channel.key,
920 error = &err as &dyn std::error::Error,
921 "channel response failure, channel is in inconsistent state until revoked"
922 );
923 }
924 }
925 } else {
926 tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
927 }
928 }
929
930 fn handle_open(&mut self, offer_id: OfferId, success: bool) {
931 let status = if success {
932 let channel = self
933 .inner
934 .channels
935 .get_mut(&offer_id)
936 .expect("channel exists");
937
938 if let Some(delay) = self.channel_unstick_delay {
941 if channel.unstick_state == ChannelUnstickState::None {
942 channel.unstick_state = ChannelUnstickState::Queued;
943 let seq = channel.seq;
944 let mut timer = PolledTimer::new(&self.driver);
945 self.channel_unstickers.push(Box::pin(async move {
946 timer.sleep(delay).await;
947 OfferInstanceId { offer_id, seq }
948 }));
949 } else {
950 channel.unstick_state = ChannelUnstickState::NeedsRequeue;
951 }
952 }
953
954 0
955 } else {
956 protocol::STATUS_UNSUCCESSFUL
957 };
958
959 self.server
960 .with_notifier(&mut self.inner)
961 .open_complete(offer_id, status);
962 }
963
964 fn handle_close(&mut self, offer_id: OfferId) {
965 let channel = self
966 .inner
967 .channels
968 .get_mut(&offer_id)
969 .expect("channel still exists");
970
971 match &mut channel.state {
972 ChannelState::Closing => {
973 tracing::debug!(?offer_id, key = %channel.key, "closing channel");
974 channel.state = ChannelState::Closed;
975 self.server
976 .with_notifier(&mut self.inner)
977 .close_complete(offer_id);
978 }
979 _ => {
980 tracing::error!(?offer_id, key = %channel.key, "invalid close channel response");
981 }
982 };
983 }
984
985 fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
986 let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
987 self.server
988 .with_notifier(&mut self.inner)
989 .gpadl_create_complete(offer_id, gpadl_id, status);
990 }
991
992 fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
993 self.server
994 .with_notifier(&mut self.inner)
995 .gpadl_teardown_complete(offer_id, gpadl_id);
996 }
997
998 fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
999 self.server
1000 .with_notifier(&mut self.inner)
1001 .modify_channel_complete(offer_id, status);
1002 }
1003
1004 fn handle_restore_channel(
1005 &mut self,
1006 offer_id: OfferId,
1007 open: bool,
1008 ) -> anyhow::Result<RestoreResult> {
1009 let gpadls = self.server.channel_gpadls(offer_id);
1010
1011 let open_request = open
1014 .then(|| -> anyhow::Result<_> {
1015 let params = self.server.get_restore_open_params(offer_id)?;
1016 let (channel, interrupt) = self.inner.open_channel(offer_id, ¶ms)?;
1017 Ok(OpenRequest::new(
1018 params.open_data,
1019 interrupt,
1020 self.server
1021 .get_version()
1022 .expect("must be connected")
1023 .feature_flags,
1024 channel.flags,
1025 ))
1026 })
1027 .transpose()?;
1028
1029 self.server
1030 .with_notifier(&mut self.inner)
1031 .restore_channel(offer_id, open_request.is_some())?;
1032
1033 let channel = self.inner.channels.get_mut(&offer_id).unwrap();
1034 for gpadl in &gpadls {
1035 if let Ok(buf) = MultiPagedRangeBuf::from_range_buffer(
1036 gpadl.request.count.into(),
1037 gpadl.request.buf.clone(),
1038 ) {
1039 channel.gpadls.add(gpadl.request.id, buf);
1040 }
1041 }
1042
1043 let result = RestoreResult {
1044 open_request,
1045 gpadls,
1046 };
1047 Ok(result)
1048 }
1049
    /// Handles one control request from [`VmbusServer`].
    ///
    /// `Save` and `Stop` complete their RPCs immediately. `Reset` queues its RPC
    /// in `reset_done` (see `handle_reset`). `Restore` first forwards the saved
    /// state to `saved_state_notify` (if set), then restores the channel server.
    /// `Start` marks the server running and performs any deferred post-restore
    /// work.
    async fn handle_request(&mut self, request: VmbusRequest) {
        tracing::debug!(?request, "handle_request");
        match request {
            VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
            VmbusRequest::Inspect(deferred) => {
                deferred.respond(|resp| {
                    resp.field("message_port", &self.inner.message_port)
                        .field("running", self.inner.running)
                        .field("hvsock_requests", self.inner.hvsock_requests)
                        .field("channel_unstick_delay", self.channel_unstick_delay)
                        // Writable trigger: "force" calls unstick_channels(true),
                        // "true" calls unstick_channels(false), and "false" does
                        // nothing. Reading reports false.
                        .field_mut_with("unstick_channels", |v| {
                            let v: inspect::ValueKind = if let Some(v) = v {
                                if v == "force" {
                                    self.unstick_channels(true);
                                    v.into()
                                } else {
                                    let v =
                                        v.parse().ok().context("expected false, true, or force")?;
                                    if v {
                                        self.unstick_channels(false);
                                    }
                                    v.into()
                                }
                            } else {
                                false.into()
                            };
                            anyhow::Ok(v)
                        })
                        .merge(&self.server.with_notifier(&mut self.inner));
                });
            }
            VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
                server: self.server.save(),
                lost_synic_bug_fixed: true,
            }),
            VmbusRequest::Restore(rpc) => {
                rpc.handle(async |state| {
                    // Saved state from a server without the lost-synic fix
                    // needs its channels unstuck on the next start.
                    self.unstick_on_start = !state.lost_synic_bug_fixed;
                    self.perform_post_restore_on_start = true;
                    if let Some(sender) = &self.inner.saved_state_notify {
                        tracing::trace!("sending saved state to proxy");
                        if let Err(err) = sender
                            .call_failable(SavedStateRequest::Set, Box::new(state.server.clone()))
                            .await
                        {
                            tracing::error!(
                                err = &err as &dyn std::error::Error,
                                "failed to restore proxy saved state"
                            );
                            return Err(RestoreError::ServerError(err.into()));
                        }
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .restore(state.server)
                })
                .await
            }
            VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
                if self.inner.running {
                    self.inner.running = false;
                }
            }),
            VmbusRequest::Start => {
                if !self.inner.running {
                    self.inner.running = true;
                    // First start after a restore: clear the proxy's saved
                    // state and revoke channels no device claimed.
                    if self.perform_post_restore_on_start {
                        if let Some(sender) = self.inner.saved_state_notify.as_ref() {
                            tracing::trace!("sending clear saved state message to proxy");
                            // Panics if the clear RPC fails.
                            sender
                                .call(SavedStateRequest::Clear, ())
                                .await
                                .expect("failed to clear proxy saved state");
                        }

                        self.server
                            .with_notifier(&mut self.inner)
                            .revoke_unclaimed_channels();

                        self.perform_post_restore_on_start = false;
                    }

                    if self.unstick_on_start {
                        tracing::info!(
                            "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
                        );
                        self.unstick_channels(false);
                        self.unstick_on_start = false;
                    }
                }
            }
        }
    }
1146
1147 fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
1148 let needs_reset = self.inner.reset_done.is_empty();
1149 self.inner.reset_done.push(rpc);
1150 if needs_reset {
1151 self.server.with_notifier(&mut self.inner).reset();
1152 }
1153 }
1154
1155 fn handle_relay_response(&mut self, response: ModifyRelayResponse) {
1156 let response = match response {
1158 ModifyRelayResponse::Supported(state, features) => {
1159 let allocated_monitor_gpas = self
1162 .inner
1163 .mnf_support
1164 .as_ref()
1165 .and_then(|mnf| mnf.allocated_monitor_page);
1166
1167 ModifyConnectionResponse::Supported(state, features, allocated_monitor_gpas)
1168 }
1169 ModifyRelayResponse::Unsupported => ModifyConnectionResponse::Unsupported,
1170 ModifyRelayResponse::Modified(state) => ModifyConnectionResponse::Modified(state),
1171 };
1172
1173 self.server
1174 .with_notifier(&mut self.inner)
1175 .complete_modify_connection(response);
1176 }
1177
1178 fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
1179 tracing::debug!(?result, "hvsock connect result");
1180 assert_ne!(self.inner.hvsock_requests, 0);
1181 self.inner.hvsock_requests -= 1;
1182
1183 self.server
1184 .with_notifier(&mut self.inner)
1185 .send_tl_connect_result(result);
1186 }
1187
1188 fn handle_synic_message(&mut self, message: SynicMessage) {
1189 match self
1190 .server
1191 .with_notifier(&mut self.inner)
1192 .handle_synic_message(message)
1193 {
1194 Ok(()) => {}
1195 Err(err) => {
1196 tracing::warn!(
1197 error = &err as &dyn std::error::Error,
1198 "synic message error"
1199 );
1200 }
1201 }
1202 }
1203
1204 fn handle_external_request(&mut self, request: InitiateContactRequest) {
1211 self.server
1212 .with_notifier(&mut self.inner)
1213 .initiate_contact(request);
1214 }
1215
/// Main loop of the server task.
///
/// Each iteration rebuilds the set of event sources from the current state,
/// waits for one of them to become ready, and handles exactly that event:
///
/// - `task_recv`, `offer_recv`, `server_request_recv`, and
///   `relay_response_recv` are always polled.
/// - Device channel responses are polled while running, or while a reset is
///   in progress (`reset_done` non-empty), presumably so that the reset can
///   complete — confirm.
/// - External InitiateContact requests, flushing of pending outgoing
///   messages, hvsock results, and channel unstick timers are polled only when
///   running and not resetting.
/// - Incoming synic messages are polled only when running and not resetting,
///   and additionally only when no outgoing messages are pending and fewer
///   than `MAX_CONCURRENT_HVSOCK_REQUESTS` hvsock requests are outstanding.
///
/// Disabled sources are wrapped in an empty `OptionFuture`, which `select!`
/// treats as terminated and never selects, so the `unwrap()` calls on those
/// branches' results cannot fail.
///
/// The loop exits when `task_recv` yields an error, or when every source in
/// the `select!` has terminated.
async fn run(
    &mut self,
    mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyRelayResponse> + Unpin,
    mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
) {
    loop {
        // Most sources are gated on this: running, with no reset waiting to
        // complete.
        let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
        let mut external_requests = OptionFuture::from(
            running_not_resetting
                .then(|| {
                    self.external_requests
                        .as_mut()
                        .map(|r| r.select_next_some())
                })
                .flatten(),
        );

        // Sampled once per iteration. It both enables flushing and suppresses
        // receiving new guest messages below.
        let has_pending_messages = self.server.has_pending_messages();
        let message_port = self.inner.message_port.as_mut();
        let mut flush_pending_messages =
            OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
                poll_fn(|cx| {
                    self.server.poll_flush_pending_messages(|msg| {
                        message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
                    })
                })
                .fuse()
            }));

        // Stop reading guest messages while outgoing messages are backed up or
        // too many hvsock requests are in flight.
        let mut message_recv = OptionFuture::from(
            (running_not_resetting
                && !has_pending_messages
                && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
                .then(|| self.message_recv.select_next_some()),
        );

        // Unlike the other sources, device responses are also drained while a
        // reset is in progress.
        let mut channel_response = OptionFuture::from(
            (self.inner.running || !self.inner.reset_done.is_empty())
                .then(|| self.inner.channel_responses.select_next_some()),
        );

        let mut hvsock_response =
            OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));

        let mut channel_unstickers = OptionFuture::from(
            running_not_resetting.then(|| self.channel_unstickers.select_next_some()),
        );

        futures::select! {
            r = self.task_recv.recv().fuse() => {
                if let Ok(request) = r {
                    self.handle_request(request).await;
                } else {
                    break;
                }
            }
            r = self.offer_recv.select_next_some() => {
                match r {
                    OfferRequest::Offer(rpc) => {
                        rpc.handle_failable_sync(|request| { self.handle_offer(request) })
                    },
                    OfferRequest::ForceReset(rpc) => {
                        self.handle_reset(rpc);
                    }
                }
            }
            r = self.server_request_recv.select_next_some() => {
                match r {
                    (id, Some(request)) => match request {
                        ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
                            self.handle_restore_channel(id.offer_id, open)
                        }),
                        ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
                            self.handle_revoke(id);
                        })
                    },
                    // A `None` request is handled the same as a revoke.
                    (id, None) => self.handle_revoke(id),
                }
            }
            r = channel_response => {
                let (id, seq, response) = r.unwrap();
                self.handle_response(id, seq, response);
            }
            r = relay_response_recv.select_next_some() => {
                self.handle_relay_response(r);
            },
            r = hvsock_response => {
                self.handle_tl_connect_result(r.unwrap());
            }
            data = message_recv => {
                let data = data.unwrap();
                self.handle_synic_message(data);
            }
            r = external_requests => {
                let r = r.unwrap();
                self.handle_external_request(r);
            }
            r = channel_unstickers => {
                self.unstick_channel_by_id(r.unwrap());
            }
            _r = flush_pending_messages => {}
            complete => break,
        }
    }
}
1331
1332 fn unstick_channels(&self, force: bool) {
1336 let Some(version) = self.server.get_version() else {
1337 tracing::warn!("cannot unstick when not connected");
1338 return;
1339 };
1340
1341 for channel in self.inner.channels.values() {
1342 let gm = self.inner.get_gm_for_channel(version, channel);
1343 if let Err(err) = Self::unstick_channel(gm, channel, force, true) {
1344 tracing::warn!(
1345 channel = %channel.key,
1346 error = err.as_ref() as &dyn std::error::Error,
1347 "could not unstick channel"
1348 );
1349 }
1350 }
1351 }
1352
1353 fn unstick_channel_by_id(&mut self, id: OfferInstanceId) {
1356 let Some(version) = self.server.get_version() else {
1357 tracelimit::warn_ratelimited!("cannot unstick when not connected");
1358 return;
1359 };
1360
1361 if let Some(channel) = self.inner.channels.get_mut(&id.offer_id) {
1362 if channel.seq != id.seq {
1363 return;
1365 }
1366
1367 if channel.unstick_state == ChannelUnstickState::NeedsRequeue {
1370 channel.unstick_state = ChannelUnstickState::Queued;
1371 let mut timer = PolledTimer::new(&self.driver);
1372 let delay = self.channel_unstick_delay.unwrap();
1373 self.channel_unstickers.push(Box::pin(async move {
1374 timer.sleep(delay).await;
1375 id
1376 }));
1377
1378 return;
1379 }
1380
1381 channel.unstick_state = ChannelUnstickState::None;
1382 let gm = select_gm_for_channel(
1383 &self.inner.gm,
1384 self.inner.private_gm.as_ref(),
1385 version,
1386 channel,
1387 );
1388 if let Err(err) = Self::unstick_channel(gm, channel, false, false) {
1389 tracelimit::warn_ratelimited!(
1390 channel = %channel.key,
1391 error = err.as_ref() as &dyn std::error::Error,
1392 "could not unstick channel"
1393 );
1394 }
1395 }
1396 }
1397
1398 fn unstick_channel(
1399 gm: &GuestMemory,
1400 channel: &Channel,
1401 force: bool,
1402 unstick_host: bool,
1403 ) -> anyhow::Result<()> {
1404 if let ChannelState::Open(state) = &channel.state {
1405 if force {
1406 tracing::info!(channel = %channel.key, "waking host and guest");
1407 if unstick_host {
1408 channel.guest_to_host_event.0.deliver();
1409 }
1410 state.host_to_guest_interrupt.deliver();
1411 return Ok(());
1412 }
1413
1414 let gpadl = channel
1415 .gpadls
1416 .clone()
1417 .view()
1418 .map(state.open_params.open_data.ring_gpadl_id)
1419 .context("couldn't find ring gpadl")?;
1420
1421 let aligned = AlignedGpadlView::new(gpadl)
1422 .ok()
1423 .context("ring not aligned")?;
1424 let (in_gpadl, out_gpadl) = aligned
1425 .split(state.open_params.open_data.ring_offset)
1426 .ok()
1427 .context("couldn't split ring")?;
1428
1429 if let Err(err) = Self::unstick_incoming_ring(
1430 gm,
1431 channel,
1432 in_gpadl,
1433 unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1434 &state.host_to_guest_interrupt,
1435 ) {
1436 tracelimit::warn_ratelimited!(
1437 channel = %channel.key,
1438 error = err.as_ref() as &dyn std::error::Error,
1439 "could not unstick incoming ring"
1440 );
1441 }
1442 if let Err(err) = Self::unstick_outgoing_ring(
1443 gm,
1444 channel,
1445 out_gpadl,
1446 unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1447 &state.host_to_guest_interrupt,
1448 ) {
1449 tracelimit::warn_ratelimited!(
1450 channel = %channel.key,
1451 error = err.as_ref() as &dyn std::error::Error,
1452 "could not unstick outgoing ring"
1453 );
1454 }
1455 }
1456 Ok(())
1457 }
1458
1459 fn unstick_incoming_ring(
1460 gm: &GuestMemory,
1461 channel: &Channel,
1462 in_gpadl: AlignedGpadlView,
1463 guest_to_host_event: Option<&ChannelEvent>,
1464 host_to_guest_interrupt: &Interrupt,
1465 ) -> anyhow::Result<()> {
1466 let control_page = lock_gpn_with_subrange(gm, in_gpadl.gpns()[0])?;
1467 if let Some(guest_to_host_event) = guest_to_host_event {
1468 if ring::reader_needs_signal(control_page.pages()[0]) {
1469 tracelimit::info_ratelimited!(channel = %channel.key, "waking host for incoming ring");
1470 guest_to_host_event.0.deliver();
1471 }
1472 }
1473
1474 let ring_size = gpadl_ring_size(&in_gpadl).try_into()?;
1475 if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1476 tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for incoming ring");
1477 host_to_guest_interrupt.deliver();
1478 }
1479 Ok(())
1480 }
1481
1482 fn unstick_outgoing_ring(
1483 gm: &GuestMemory,
1484 channel: &Channel,
1485 out_gpadl: AlignedGpadlView,
1486 guest_to_host_event: Option<&ChannelEvent>,
1487 host_to_guest_interrupt: &Interrupt,
1488 ) -> anyhow::Result<()> {
1489 let control_page = lock_gpn_with_subrange(gm, out_gpadl.gpns()[0])?;
1490 if ring::reader_needs_signal(control_page.pages()[0]) {
1491 tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for outgoing ring");
1492 host_to_guest_interrupt.deliver();
1493 }
1494
1495 if let Some(guest_to_host_event) = guest_to_host_event {
1496 let ring_size = gpadl_ring_size(&out_gpadl).try_into()?;
1497 if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1498 tracelimit::info_ratelimited!(channel = %channel.key, "waking host for outgoing ring");
1499 guest_to_host_event.0.deliver();
1500 }
1501 }
1502 Ok(())
1503 }
1504}
1505
/// Channel-server callbacks, implemented by the server task's inner state.
impl Notifier for ServerTaskInner {
    /// Carries out `action` for channel `offer_id`.
    ///
    /// Most actions are forwarded to the device over `channel.send`. The
    /// resulting future is pushed onto `channel_responses`, tagged with the
    /// offer ID and the channel's current sequence number, so that the server
    /// loop can route the device's response back (`handle_response(id, seq,
    /// response)`).
    ///
    /// # Panics
    /// Panics if the channel does not exist, or if a `Modify` action arrives
    /// for a channel that is not open.
    fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        // Sends `input` to the device via `req` and returns a future resolving
        // to `(offer_id, seq, f(response))`, where `seq` is the channel's
        // sequence number at send time.
        fn handle<I: 'static + Send, R: 'static + Send>(
            offer_id: OfferId,
            channel: &Channel,
            req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
            input: I,
            f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
        ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
        {
            let recv = channel.send.call(req, input);
            let seq = channel.seq;
            Box::pin(async move {
                let r = recv.await.map(f);
                (offer_id, seq, r)
            })
        }

        let response = match action {
            channels::Action::Open(open_params, version) => {
                // Captured up front: `open_channel` needs `&mut self`, which
                // ends the `channel` borrow above.
                let seq = channel.seq;
                let key = channel.key;
                match self.open_channel(offer_id, &open_params) {
                    Ok((channel, interrupt)) => handle(
                        offer_id,
                        channel,
                        ChannelRequest::Open,
                        OpenRequest::new(
                            open_params.open_data,
                            interrupt,
                            version.feature_flags,
                            channel.flags,
                        ),
                        ChannelResponse::Open,
                    ),
                    Err(err) => {
                        tracelimit::error_ratelimited!(
                            err = err.as_ref() as &dyn std::error::Error,
                            ?offer_id,
                            %key,
                            "could not open channel",
                        );

                        // The device is never contacted; report a failed open
                        // immediately.
                        Box::pin(future::ready((
                            offer_id,
                            seq,
                            Ok(ChannelResponse::Open(false)),
                        )))
                    }
                }
            }
            channels::Action::Close => {
                // Undo the channel bitmap registration made when the channel
                // was opened.
                if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
                    if let ChannelState::Open(ref state) = channel.state {
                        channel_bitmap.unregister_channel(state.open_params.event_flag);
                    }
                }

                channel.state = ChannelState::Closing;
                handle(offer_id, channel, ChannelRequest::Close, (), |()| {
                    ChannelResponse::Close
                })
            }
            channels::Action::Gpadl(gpadl_id, count, buf) => {
                // Keep a local copy of the GPADL so that ring inspection and
                // unsticking can view it.
                // NOTE(review): `unwrap` assumes `count`/`buf` were already
                // validated by the channels module — confirm, since they
                // originate from the guest.
                channel.gpadls.add(
                    gpadl_id,
                    MultiPagedRangeBuf::from_range_buffer(count.into(), buf.clone()).unwrap(),
                );
                handle(
                    offer_id,
                    channel,
                    ChannelRequest::Gpadl,
                    GpadlRequest {
                        id: gpadl_id,
                        count,
                        buf,
                    },
                    move |r| ChannelResponse::Gpadl(gpadl_id, r),
                )
            }
            channels::Action::TeardownGpadl {
                gpadl_id,
                post_restore,
            } => {
                // The local copy is only removed when `post_restore` is not
                // set.
                if !post_restore {
                    channel.gpadls.remove(gpadl_id, Box::new(|| ()));
                }

                handle(
                    offer_id,
                    channel,
                    ChannelRequest::TeardownGpadl,
                    gpadl_id,
                    move |()| ChannelResponse::TeardownGpadl(gpadl_id),
                )
            }
            channels::Action::Modify { target_vp } => {
                // Modify is only expected for open channels.
                let ChannelState::Open(state) = &mut channel.state else {
                    unreachable!();
                };

                if let Err(err) = state.set_event_port_target_vp(target_vp) {
                    tracelimit::error_ratelimited!(
                        error = err.as_ref() as &dyn std::error::Error,
                        channel = %channel.key,
                        "could not modify channel",
                    );

                    // Retargeting failed locally: report failure without
                    // contacting the device.
                    let seq = channel.seq;
                    Box::pin(async move {
                        (
                            offer_id,
                            seq,
                            Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
                        )
                    })
                } else {
                    handle(
                        offer_id,
                        channel,
                        ChannelRequest::Modify,
                        ModifyRequest::TargetVp { target_vp },
                        ChannelResponse::Modify,
                    )
                }
            }
        };
        self.channel_responses.push(response);
    }

    /// Applies the locally handled parts of a ModifyConnection request: the
    /// interrupt page, the monitor page, and the message target VP. Then
    /// forwards the (possibly adjusted) request to the relay when
    /// `notify_relay` is set. `set_monitor_page` may clear
    /// `request.monitor_page` before it is forwarded.
    ///
    /// # Errors
    /// Fails if any of the local updates fail. The relay is not notified in
    /// that case.
    fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
        self.map_interrupt_page(request.interrupt_page)
            .context("Failed to map interrupt page.")?;

        self.set_monitor_page(&mut request)
            .context("Failed to map monitor page.")?;

        if let Some(vp) = request.target_message_vp {
            self.message_port.set_target_vp(vp)?;
        }

        if request.notify_relay {
            self.relay_send.send(request.into());
        }

        Ok(())
    }

    /// Forwards a request the channel server does not handle to the external
    /// server, if one is configured. Otherwise the request is dropped with a
    /// warning.
    fn forward_unhandled(&mut self, request: InitiateContactRequest) {
        if let Some(external_server) = &self.external_server_send {
            external_server.send(request);
        } else {
            tracing::warn!(?request, "nowhere to forward unhandled request")
        }
    }

    /// Adds ring-buffer details for an open channel to an inspect response.
    ///
    /// # Panics
    /// Panics if the channel does not exist, or if it is open while `version`
    /// is `None`.
    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
        let channel = self.channels.get(&offer_id).expect("should exist");
        let mut resp = req.respond();
        if let ChannelState::Open(state) = &channel.state {
            let mem = self.get_gm_for_channel(version.expect("must be connected"), channel);
            inspect_rings(
                &mut resp,
                mem,
                channel.gpadls.clone(),
                &state.open_params.open_data,
            );
        }
    }

    /// Tries to post `message` to the port selected by `target` without
    /// blocking.
    ///
    /// Returns `false` when stopped (unless `send_messages_while_stopped`) or
    /// when the port is not ready (`Pending`). Returns `true` when the message
    /// was posted, and also when no port could be obtained (the message is
    /// dropped after logging).
    /// NOTE(review): presumably `false` tells the channel server to keep the
    /// message pending for a later flush — confirm against `channels::Notifier`.
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
        if !self.running && !self.send_messages_while_stopped {
            // Only non-default targets are logged here.
            if !matches!(target, MessageTarget::Default) {
                tracelimit::error_ratelimited!(?target, "dropping message while paused");
            }
            return false;
        }

        // Owns a port created on demand for a custom target, so that `port`
        // can borrow from it.
        let mut port_storage;
        let port = match target {
            MessageTarget::Default => self.message_port.as_mut(),
            MessageTarget::ReservedChannel(offer_id, target) => {
                if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
                    port.as_mut()
                } else {
                    return true;
                }
            }
            MessageTarget::Custom(target) => {
                port_storage = match self.synic.new_guest_message_port(
                    self.redirect_vtl,
                    target.vp,
                    target.sint,
                ) {
                    Ok(port) => port,
                    Err(err) => {
                        tracing::error!(
                            ?err,
                            ?self.redirect_vtl,
                            ?target,
                            "could not create message port"
                        );

                        return true;
                    }
                };
                port_storage.as_mut()
            }
        };

        // Poll exactly once with a no-op waker: this never waits, and a
        // `Pending` result is reported as `false`.
        matches!(
            port.poll_post_message(
                &mut std::task::Context::from_waker(std::task::Waker::noop()),
                VMBUS_MESSAGE_TYPE,
                message.data()
            ),
            Poll::Ready(())
        )
    }

    /// Counts a new outstanding hvsock connect request and forwards it. The
    /// count is decremented when the matching result is handled.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
        tracing::debug!(?request, "received hvsock connect request");
        self.hvsock_requests += 1;
        self.hvsock_send.send(*request);
    }

    /// Finishes a reset.
    ///
    /// Clears the synic monitor page (if monitor support exists), drops the
    /// reserved message ports of closed channels, and completes every waiter
    /// in `reset_done`.
    fn reset_complete(&mut self) {
        if let Some(monitor) = self.synic.monitor_support() {
            if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
                tracing::warn!(?err, "resetting monitor page failed")
            }
        }

        self.unreserve_channels();
        for done in self.reset_done.drain(..) {
            done.complete(());
        }
    }

    /// Finishes an unload by dropping the reserved message ports of closed
    /// channels.
    fn unload_complete(&mut self) {
        self.unreserve_channels();
    }
}
1773
1774impl ServerTaskInner {
1775 fn open_channel(
1776 &mut self,
1777 offer_id: OfferId,
1778 open_params: &OpenParams,
1779 ) -> anyhow::Result<(&mut Channel, Interrupt)> {
1780 let channel = self
1781 .channels
1782 .get_mut(&offer_id)
1783 .expect("channel does not exist");
1784
1785 if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1787 channel_bitmap.register_channel(
1788 open_params.event_flag,
1789 channel.guest_to_host_event.0.clone(),
1790 );
1791 }
1792 let event_port = self
1797 .synic
1798 .add_event_port(
1799 open_params.connection_id,
1800 self.vtl,
1801 channel.guest_to_host_event.clone(),
1802 open_params.monitor_info,
1803 )
1804 .context("failed to create guest-to-host event port")?;
1805
1806 let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
1809 (Some(0), 0)
1810 } else {
1811 (open_params.open_data.target_vp, open_params.event_flag)
1812 };
1813
1814 let (guest_event_port, interrupt) = if let Some(target_vp) = target_vp {
1815 let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
1816 (self.redirect_vtl, self.redirect_sint)
1817 } else {
1818 (self.vtl, VMBUS_SINT)
1819 };
1820
1821 let guest_event_port = self.synic.new_guest_event_port(
1822 VmbusServer::get_child_event_port_id(open_params.channel_id, VMBUS_SINT, self.vtl),
1823 target_vtl,
1824 target_vp,
1825 target_sint,
1826 event_flag,
1827 open_params.monitor_info,
1828 )?;
1829
1830 let interrupt = ChannelBitmap::create_interrupt(
1831 &self.channel_bitmap,
1832 guest_event_port.interrupt(),
1833 open_params.event_flag,
1834 );
1835
1836 (Some(guest_event_port), interrupt)
1837 } else {
1838 (None, Interrupt::null_event())
1841 };
1842
1843 channel.reserved_state.message_port = None;
1845
1846 if let Some(target) = open_params.reserved_target {
1848 channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
1849 self.redirect_vtl,
1850 target.vp,
1851 target.sint,
1852 )?);
1853
1854 channel.reserved_state.target = target;
1855 }
1856
1857 channel.state = ChannelState::Open(Box::new(ChannelOpenState {
1858 open_params: *open_params,
1859 _event_port: event_port,
1860 guest_event_port,
1861 host_to_guest_interrupt: interrupt.clone(),
1862 }));
1863 Ok((channel, interrupt))
1864 }
1865
1866 fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
1869 let interrupt_page = match interrupt_page {
1870 Update::Unchanged => return Ok(()),
1871 Update::Reset => {
1872 self.channel_bitmap = None;
1873 self.shared_event_port = None;
1874 return Ok(());
1875 }
1876 Update::Set(interrupt_page) => interrupt_page,
1877 };
1878
1879 assert_ne!(interrupt_page, 0);
1880
1881 if interrupt_page % PAGE_SIZE as u64 != 0 {
1882 anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
1883 }
1884
1885 let interrupt_page = lock_page_with_subrange(&self.gm, interrupt_page)?;
1888 let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
1889 self.channel_bitmap = Some(channel_bitmap.clone());
1890
1891 let interrupt = Interrupt::from_fn(move || {
1893 channel_bitmap.handle_shared_interrupt();
1894 });
1895
1896 self.shared_event_port = Some(self.synic.add_event_port(
1897 SHARED_EVENT_CONNECTION_ID,
1898 self.vtl,
1899 Arc::new(ChannelEvent(interrupt)),
1900 None,
1901 )?);
1902
1903 Ok(())
1904 }
1905
1906 fn set_monitor_page(&mut self, request: &mut ModifyConnectionRequest) -> anyhow::Result<()> {
1907 let monitor_page = match request.monitor_page {
1908 Update::Unchanged => return Ok(()),
1909 Update::Reset => None,
1910 Update::Set(value) => Some(value),
1911 };
1912
1913 if self.channels.iter().any(|(_, c)| {
1915 matches!(
1916 &c.state,
1917 ChannelState::Open(state) if state.open_params.monitor_info.is_some()
1918 )
1919 }) {
1920 anyhow::bail!("attempt to change monitor page while open channels using mnf");
1921 }
1922
1923 if let Some(mnf_support) = self.mnf_support.as_mut() {
1927 if let Some(monitor) = self.synic.monitor_support() {
1928 mnf_support.allocated_monitor_page = None;
1929
1930 if let Some(version) = request.version {
1931 if version.feature_flags.server_specified_monitor_pages() {
1932 if let Some(monitor_page) = monitor.allocate_monitor_page(self.vtl)? {
1933 tracelimit::info_ratelimited!(
1934 ?monitor_page,
1935 "using server-allocated monitor pages"
1936 );
1937 mnf_support.allocated_monitor_page = Some(monitor_page);
1938 }
1939 }
1940 }
1941
1942 if mnf_support.allocated_monitor_page.is_none() {
1944 if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
1945 anyhow::bail!(
1946 "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
1947 );
1948 }
1949 }
1950 }
1951
1952 request.monitor_page = Update::Unchanged;
1955 }
1956
1957 Ok(())
1958 }
1959
1960 fn get_reserved_channel_message_port(
1961 &mut self,
1962 offer_id: OfferId,
1963 new_target: ConnectionTarget,
1964 ) -> Option<&mut Box<dyn GuestMessagePort>> {
1965 let channel = self
1966 .channels
1967 .get_mut(&offer_id)
1968 .expect("channel does not exist");
1969
1970 assert!(
1971 channel.reserved_state.message_port.is_some(),
1972 "channel is not reserved"
1973 );
1974
1975 if channel.reserved_state.target.sint != new_target.sint {
1978 channel.reserved_state.message_port = None;
1980 let message_port = self
1981 .synic
1982 .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
1983 .inspect_err(|err| {
1984 tracing::error!(
1985 key = %channel.key,
1986 ?err,
1987 ?self.redirect_vtl,
1988 ?new_target,
1989 "could not create reserved channel message port"
1990 )
1991 })
1992 .ok()?;
1993
1994 channel.reserved_state.message_port = Some(message_port);
1995 channel.reserved_state.target = new_target;
1996 } else if channel.reserved_state.target.vp != new_target.vp {
1997 let message_port = channel.reserved_state.message_port.as_mut().unwrap();
1998
1999 if let Err(err) = message_port.set_target_vp(new_target.vp) {
2002 tracing::error!(
2003 key = %channel.key,
2004 ?err,
2005 ?self.redirect_vtl,
2006 ?new_target,
2007 "could not update reserved channel message port"
2008 );
2009 }
2010
2011 channel.reserved_state.target = new_target;
2012 return Some(message_port);
2013 }
2014
2015 Some(channel.reserved_state.message_port.as_mut().unwrap())
2016 }
2017
2018 fn unreserve_channels(&mut self) {
2019 for channel in self.channels.values_mut() {
2021 if let ChannelState::Closed = channel.state {
2022 channel.reserved_state.message_port = None;
2023 }
2024 }
2025 }
2026
2027 fn get_gm_for_channel(&self, version: VersionInfo, channel: &Channel) -> &GuestMemory {
2028 select_gm_for_channel(&self.gm, self.private_gm.as_ref(), version, channel)
2029 }
2030}
2031
2032fn select_gm_for_channel<'a>(
2033 gm: &'a GuestMemory,
2034 private_gm: Option<&'a GuestMemory>,
2035 version: VersionInfo,
2036 channel: &Channel,
2037) -> &'a GuestMemory {
2038 if channel.flags.confidential_ring_buffer() && version.feature_flags.confidential_channels() {
2039 if let Some(private_gm) = private_gm {
2040 return private_gm;
2041 }
2042 }
2043
2044 gm
2045}
2046
/// Handle used to offer channels to the vmbus server task and to force it to
/// reset.
#[derive(Clone)]
pub struct VmbusServerControl {
    /// Guest memory handed to every offered channel.
    mem: GuestMemory,
    /// Private guest memory, if any. Handed only to channels whose offer
    /// flags request a confidential ring buffer or confidential external
    /// memory.
    private_mem: Option<GuestMemory>,
    /// Request channel to the server task (offers and forced resets).
    send: mesh::Sender<OfferRequest>,
    /// Value reported by `ParentBus::use_event`.
    use_event: bool,
    /// When set, every channel offered through `offer` is forced to use
    /// confidential external memory.
    force_confidential_external_memory: bool,
}
2056
2057impl VmbusServerControl {
2058 pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
2061 let flags = offer_info.params.flags;
2062 self.send
2063 .call_failable(OfferRequest::Offer, offer_info)
2064 .await?;
2065 Ok(OfferResources::new(
2066 self.mem.clone(),
2067 if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
2068 self.private_mem.clone()
2069 } else {
2070 None
2071 },
2072 ))
2073 }
2074
2075 pub async fn force_reset(&self) -> anyhow::Result<()> {
2078 self.send
2079 .call(OfferRequest::ForceReset, ())
2080 .await
2081 .context("vmbus server is gone")
2082 }
2083
2084 async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2085 let mut offer_info = OfferInfo {
2086 params: request.params.into(),
2087 event: request.event,
2088 request_send: request.request_send,
2089 server_request_recv: request.server_request_recv,
2090 };
2091
2092 if self.force_confidential_external_memory {
2093 tracing::warn!(
2094 key = %offer_info.params.key(),
2095 "forcing confidential external memory for channel"
2096 );
2097
2098 offer_info
2099 .params
2100 .flags
2101 .set_confidential_external_memory(true);
2102 }
2103
2104 self.offer_core(offer_info).await
2105 }
2106}
2107
2108fn inspect_rings(
2110 resp: &mut inspect::Response<'_>,
2111 gm: &GuestMemory,
2112 gpadl_map: Arc<GpadlMap>,
2113 open_data: &OpenData,
2114) -> Option<()> {
2115 let gpadl = gpadl_map
2116 .view()
2117 .map(GpadlId(open_data.ring_gpadl_id.0))
2118 .ok()?;
2119
2120 let aligned = AlignedGpadlView::new(gpadl).ok()?;
2121 let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
2122 resp.child("incoming_ring", |req| inspect_ring(req, &in_gpadl, gm));
2123 resp.child("outgoing_ring", |req| inspect_ring(req, &out_gpadl, gm));
2124 Some(())
2125}
2126
2127fn inspect_ring(req: inspect::Request<'_>, gpadl: &AlignedGpadlView, gm: &GuestMemory) {
2129 let mut resp = req.respond();
2130
2131 resp.hex("ring_size", gpadl_ring_size(gpadl));
2132
2133 if let Ok(pages) = lock_gpn_with_subrange(gm, gpadl.gpns()[0]) {
2136 ring::inspect_ring(pages.pages()[0], &mut resp);
2137 }
2138}
2139
2140fn gpadl_ring_size(gpadl: &AlignedGpadlView) -> usize {
2141 (gpadl.gpns().len() - 1) * PAGE_SIZE
2143}
2144
2145fn lock_page_with_subrange(gm: &GuestMemory, offset: u64) -> anyhow::Result<guestmem::LockedPages> {
2150 Ok(gm
2151 .lockable_subrange(offset, PAGE_SIZE as u64)?
2152 .lock_gpns(false, &[0])?)
2153}
2154
2155fn lock_gpn_with_subrange(gm: &GuestMemory, gpn: u64) -> anyhow::Result<guestmem::LockedPages> {
2160 lock_page_with_subrange(gm, gpn * PAGE_SIZE as u64)
2161}
2162
/// A `MessagePort` that forwards each received message as a `SynicMessage`
/// over an mpsc channel (presumably to the server task's `message_recv` —
/// confirm).
pub(crate) struct MessageSender {
    /// Channel the messages are sent on. It is cloned for each message.
    send: mpsc::Sender<SynicMessage>,
    /// Copied into every forwarded `SynicMessage`.
    multiclient: bool,
}
2167
2168impl MessageSender {
2169 fn poll_handle_message(
2170 &self,
2171 cx: &mut std::task::Context<'_>,
2172 msg: &[u8],
2173 trusted: bool,
2174 ) -> Poll<Result<(), SendError>> {
2175 let mut send = self.send.clone();
2176 ready!(send.poll_ready(cx))?;
2177 send.start_send(SynicMessage {
2178 data: msg.to_vec(),
2179 multiclient: self.multiclient,
2180 trusted,
2181 })?;
2182
2183 Poll::Ready(Ok(()))
2184 }
2185}
2186
2187impl MessagePort for MessageSender {
2188 fn poll_handle_message(
2189 &self,
2190 cx: &mut std::task::Context<'_>,
2191 msg: &[u8],
2192 trusted: bool,
2193 ) -> Poll<()> {
2194 if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
2195 tracelimit::error_ratelimited!(
2196 error = &err as &dyn std::error::Error,
2197 "failed to send message"
2198 );
2199 }
2200
2201 Poll::Ready(())
2202 }
2203}
2204
2205#[async_trait]
2206impl ParentBus for VmbusServerControl {
2207 async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2208 self.offer(request).await
2209 }
2210
2211 fn clone_bus(&self) -> Box<dyn ParentBus> {
2212 Box::new(self.clone())
2213 }
2214
2215 fn use_event(&self) -> bool {
2216 self.use_event
2217 }
2218}