1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8pub mod channels;
9pub mod event;
10pub mod hvsock;
11mod monitor;
12mod proxyintegration;
13#[cfg(test)]
14mod tests;
15
/// Alias for [`guid::Guid`], re-exported so users of this crate can name GUIDs without a
/// direct dependency on the `guid` crate.
pub type Guid = guid::Guid;
18
19use anyhow::Context;
20use async_trait::async_trait;
21use channel_bitmap::ChannelBitmap;
22use channels::ConnectionTarget;
23pub use channels::InitiateContactRequest;
24use channels::MessageTarget;
25pub use channels::MnfUsage;
26use channels::ModifyConnectionRequest;
27use channels::ModifyConnectionResponse;
28use channels::Notifier;
29use channels::OfferId;
30pub use channels::OfferParamsInternal;
31use channels::OpenParams;
32use channels::RestoreError;
33pub use channels::Update;
34use futures::FutureExt;
35use futures::StreamExt;
36use futures::channel::mpsc;
37use futures::channel::mpsc::SendError;
38use futures::future::OptionFuture;
39use futures::future::poll_fn;
40use futures::stream::SelectAll;
41use guestmem::GuestMemory;
42use hvdef::Vtl;
43use inspect::Inspect;
44use mesh::payload::Protobuf;
45use mesh::rpc::FailableRpc;
46use mesh::rpc::Rpc;
47use mesh::rpc::RpcError;
48use mesh::rpc::RpcSend;
49use pal_async::driver::Driver;
50use pal_async::driver::SpawnDriver;
51use pal_async::task::Task;
52use pal_async::timer::PolledTimer;
53use pal_event::Event;
54#[cfg(windows)]
55pub use proxyintegration::ProxyIntegration;
56#[cfg(windows)]
57pub use proxyintegration::ProxyServerInfo;
58use ring::PAGE_SIZE;
59use std::collections::HashMap;
60use std::future;
61use std::future::Future;
62use std::pin::Pin;
63use std::sync::Arc;
64use std::task::Poll;
65use std::task::ready;
66use std::time::Duration;
67use unicycle::FuturesUnordered;
68use vmbus_channel::bus::ChannelRequest;
69use vmbus_channel::bus::ChannelServerRequest;
70use vmbus_channel::bus::GpadlRequest;
71use vmbus_channel::bus::ModifyRequest;
72use vmbus_channel::bus::OfferInput;
73use vmbus_channel::bus::OfferKey;
74use vmbus_channel::bus::OfferResources;
75use vmbus_channel::bus::OpenData;
76use vmbus_channel::bus::OpenRequest;
77use vmbus_channel::bus::ParentBus;
78use vmbus_channel::bus::RestoreResult;
79use vmbus_channel::gpadl::GpadlMap;
80use vmbus_channel::gpadl_ring::AlignedGpadlView;
81use vmbus_core::HvsockConnectRequest;
82use vmbus_core::HvsockConnectResult;
83use vmbus_core::MaxVersionInfo;
84use vmbus_core::OutgoingMessage;
85use vmbus_core::TaggedStream;
86use vmbus_core::VersionInfo;
87use vmbus_core::protocol;
88pub use vmbus_core::protocol::GpadlId;
89#[cfg(windows)]
90use vmbus_proxy::ProxyHandle;
91use vmbus_ring as ring;
92use vmbus_ring::gparange::MultiPagedRangeBuf;
93use vmcore::interrupt::Interrupt;
94use vmcore::save_restore::SavedStateRoot;
95use vmcore::synic::EventPort;
96use vmcore::synic::GuestEventPort;
97use vmcore::synic::GuestMessagePort;
98use vmcore::synic::MessagePort;
99use vmcore::synic::MonitorPageGpas;
100use vmcore::synic::SynicPortAccess;
101
/// SINT on which VMBus messages are received when message redirection is disabled.
const SINT: u8 = 2;
/// SINT on which VMBus messages are received when message redirection is enabled.
pub const REDIRECT_SINT: u8 = 7;
/// VTL whose synic receives VMBus messages when message redirection is enabled.
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
/// Connection ID of the shared event port. NOTE(review): not referenced in this chunk; its use
/// is presumably elsewhere in the file.
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
/// Low bits of per-channel child event port IDs (see `get_child_event_port_id`).
const EVENT_PORT_ID: u32 = 2;
/// Synic message type passed to `poll_post_message` for outgoing VMBus messages.
const VMBUS_MESSAGE_TYPE: u32 = 1;

/// Maximum number of outstanding hvsock connect requests. While this many are in flight, the
/// server task stops reading incoming synic messages.
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
110
/// Handle to a VMBus server whose state lives in a spawned background task.
///
/// Control operations (save, restore, start, stop, reset, inspect) are sent to the task over
/// `task_send`. Created with [`VmbusServer::builder`].
pub struct VmbusServer {
    /// Sends control requests to the server task; dropping it ends the task's main loop.
    task_send: mesh::Sender<VmbusRequest>,
    /// Shared control object returned by [`VmbusServer::control`].
    control: Arc<VmbusServerControl>,
    /// Registration of the primary synic message port. It is held only so it is not dropped
    /// (presumably dropping it unregisters the port — confirm against `SynicPortAccess`).
    _message_port: Box<dyn Sync + Send>,
    /// Registration of the multiclient message port, created only for a non-redirected VTL0
    /// server.
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    /// The spawned server task; awaited by [`VmbusServer::shutdown`].
    task: Task<ServerTask>,
}
118
/// Builder for [`VmbusServer`]; see the individual setter methods for each option.
pub struct VmbusServerBuilder<T: SpawnDriver> {
    /// Driver used to spawn the server task and to create its timers.
    spawner: T,
    /// Synic used to register message ports and to create guest message ports.
    synic: Arc<dyn SynicPortAccess>,
    /// Guest memory.
    gm: GuestMemory,
    /// Optional private guest memory.
    private_gm: Option<GuestMemory>,
    /// VTL served by this server.
    vtl: Vtl,
    /// Optional external handler for hvsock connect requests.
    hvsock_notify: Option<HvsockServerChannelHalf>,
    /// Optional relay for modify-connection requests.
    server_relay: Option<VmbusServerChannelHalf>,
    /// Optional proxy notified about saved state.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Optional sender for forwarding contact requests to an external server.
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    /// Optional source of contact requests that originate outside the guest.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Whether guest messages are redirected to `REDIRECT_VTL` / `REDIRECT_SINT`.
    use_message_redirect: bool,
    /// Offset applied to channel IDs (0 or 1024).
    channel_id_offset: u16,
    /// Optional maximum protocol version.
    max_version: Option<MaxVersionInfo>,
    /// Passed together with `max_version` to the channel server.
    delay_max_version: bool,
    /// Whether the server handles MNF itself.
    enable_mnf: bool,
    /// Forwarded to `VmbusServerControl`.
    force_confidential_external_memory: bool,
    /// Forwarded to the server task.
    send_messages_while_stopped: bool,
    /// Delay before unsticking a newly opened channel; `None` disables this.
    channel_unstick_delay: Option<Duration>,
}
139
/// Notification sent to a proxy about the server's saved state.
#[derive(mesh::MeshPayload)]
pub enum SavedStateRequest {
    /// Sent during restore with the channel server's saved state; a failure aborts the restore.
    Set(FailableRpc<Box<channels::SavedState>, ()>),
    /// Sent when the server is started, telling the proxy to discard any saved state.
    Clear(Rpc<(), ()>),
}
146
/// The server's end of a [`RelayChannel`]: it sends requests and receives responses.
pub struct ServerChannelHalf<Request, Response> {
    request_send: mesh::Sender<Request>,
    response_receive: mesh::Receiver<Response>,
}

/// The relay's end of a [`RelayChannel`]: it receives requests and sends responses.
pub struct RelayChannelHalf<Request, Response> {
    pub request_receive: mesh::Receiver<Request>,
    pub response_send: mesh::Sender<Response>,
}

/// A request/response channel pair connecting the VMBus server to an external relay.
pub struct RelayChannel<Request, Response> {
    /// Half to hand to the relay.
    pub relay_half: RelayChannelHalf<Request, Response>,
    /// Half to hand to the server builder.
    pub server_half: ServerChannelHalf<Request, Response>,
}
164
165impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
166 pub fn new() -> Self {
168 let (request_send, request_receive) = mesh::channel();
169 let (response_send, response_receive) = mesh::channel();
170 Self {
171 relay_half: RelayChannelHalf {
172 request_receive,
173 response_send,
174 },
175 server_half: ServerChannelHalf {
176 request_send,
177 response_receive,
178 },
179 }
180 }
181}
182
/// Server half of the modify-connection relay channel.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Relay half of the modify-connection relay channel.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Modify-connection relay channel pair.
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyRelayResponse>;
/// Server half of the hvsock connect relay channel.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay half of the hvsock connect relay channel.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Hvsock connect relay channel pair.
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
189
/// A modify-connection request forwarded to the relay (built from a `ModifyConnectionRequest`).
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// Requested protocol version, if this request negotiates one.
    pub version: Option<u32>,
    /// Requested change to the monitor page GPAs.
    pub monitor_page: Update<MonitorPageGpas>,
    /// `Some(true)` when an interrupt page is set, `Some(false)` when it is reset, and `None`
    /// when it is unchanged.
    pub use_interrupt_page: Option<bool>,
}
204
/// The relay's reply to a [`ModifyRelayRequest`].
#[derive(Debug, Copy, Clone)]
pub enum ModifyRelayResponse {
    /// The requested version is supported, with the given state and feature flags. The local
    /// stand-in relay returns this for requests that carry a version.
    Supported(protocol::ConnectionState, protocol::FeatureFlags),
    /// The requested version is not supported.
    Unsupported,
    /// A request without a version was applied, with the given state.
    Modified(protocol::ConnectionState),
}
219
220impl From<ModifyConnectionRequest> for ModifyRelayRequest {
221 fn from(value: ModifyConnectionRequest) -> Self {
222 Self {
223 version: value.version.map(|v| v.version as u32),
224 monitor_page: value.monitor_page,
225 use_interrupt_page: match value.interrupt_page {
226 Update::Unchanged => None,
227 Update::Reset => Some(false),
228 Update::Set(_) => Some(true),
229 },
230 }
231 }
232}
233
/// Control requests sent from [`VmbusServer`] to the server task.
#[derive(Debug)]
enum VmbusRequest {
    /// Reset channel state; completes when the reset finishes.
    Reset(Rpc<(), ()>),
    /// Inspect the server task's state.
    Inspect(inspect::Deferred),
    /// Save server state.
    Save(Rpc<(), SavedState>),
    /// Restore server state.
    Restore(Rpc<Box<SavedState>, Result<(), RestoreError>>),
    /// Start processing guest messages (fire-and-forget).
    Start,
    /// Stop processing guest messages.
    Stop(Rpc<(), ()>),
}
243
/// Everything the server needs to offer a channel.
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    /// Offer parameters passed to the channel server.
    pub params: OfferParamsInternal,
    /// Interrupt delivered when the guest signals the channel.
    pub event: Interrupt,
    /// Sender for requests from the server to the channel's device.
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receiver for server requests (restore/revoke) from the device; its closure revokes the
    /// channel.
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
251
/// Requests received on the offer channel (from `VmbusServerControl`).
#[expect(clippy::large_enum_variant)]
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offer a new channel.
    Offer(FailableRpc<OfferInfo, ()>),
    /// Reset the server; handled the same way as `VmbusRequest::Reset`.
    ForceReset(Rpc<(), ()>),
}
258
259impl Inspect for VmbusServer {
260 fn inspect(&self, req: inspect::Request<'_>) {
261 self.task_send.send(VmbusRequest::Inspect(req.defer()));
262 }
263}
264
/// Guest-to-host event port for a channel; signaling it delivers the wrapped interrupt.
struct ChannelEvent(Interrupt);
266
267impl EventPort for ChannelEvent {
268 fn handle_event(&self, _flag: u16) {
269 self.0.deliver();
270 }
271
272 fn os_event(&self) -> Option<&Event> {
273 self.0.event()
274 }
275}
276
/// Saved state of the VMBus server.
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    /// Channel server state.
    #[mesh(1)]
    server: channels::SavedState,
    /// Always `true` when saved by this code. When a restored state has `false`, the server
    /// unsticks all channels on the next start to mitigate the lost-synic issue.
    #[mesh(2)]
    lost_synic_bug_fixed: bool,
}
288
/// Message connection ID used by a VTL0 server without message redirection.
const MESSAGE_CONNECTION_ID: u32 = 1;
/// Connection ID of the multiclient message port. It also forms the low bits of child message
/// connection IDs.
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
291
292impl<T: SpawnDriver + Clone> VmbusServerBuilder<T> {
293 pub fn new(spawner: T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
295 Self {
296 spawner,
297 synic,
298 gm,
299 private_gm: None,
300 vtl: Vtl::Vtl0,
301 hvsock_notify: None,
302 server_relay: None,
303 saved_state_notify: None,
304 external_server: None,
305 external_requests: None,
306 use_message_redirect: false,
307 channel_id_offset: 0,
308 max_version: None,
309 delay_max_version: false,
310 enable_mnf: false,
311 force_confidential_external_memory: false,
312 send_messages_while_stopped: false,
313 channel_unstick_delay: Some(Duration::from_millis(100)),
314 }
315 }
316
317 pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
321 self.private_gm = private_gm;
322 self
323 }
324
325 pub fn vtl(mut self, vtl: Vtl) -> Self {
327 self.vtl = vtl;
328 self
329 }
330
331 pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
333 self.hvsock_notify = hvsock_notify;
334 self
335 }
336
337 pub fn saved_state_notify(
339 mut self,
340 saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
341 ) -> Self {
342 self.saved_state_notify = saved_state_notify;
343 self
344 }
345
346 pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
349 self.server_relay = server_relay;
350 self
351 }
352
353 pub fn external_requests(
355 mut self,
356 external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
357 ) -> Self {
358 self.external_requests = external_requests;
359 self
360 }
361
362 pub fn external_server(
365 mut self,
366 external_server: Option<mesh::Sender<InitiateContactRequest>>,
367 ) -> Self {
368 self.external_server = external_server;
369 self
370 }
371
372 pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
374 self.use_message_redirect = use_message_redirect;
375 self
376 }
377
378 pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
383 self.channel_id_offset = if enable { 1024 } else { 0 };
384 self
385 }
386
387 pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
391 self.max_version = max_version;
392 self
393 }
394
395 pub fn delay_max_version(mut self, delay: bool) -> Self {
400 self.delay_max_version = delay;
401 self
402 }
403
404 pub fn enable_mnf(mut self, enable: bool) -> Self {
408 self.enable_mnf = enable;
409 self
410 }
411
412 pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
415 self.force_confidential_external_memory = force;
416 self
417 }
418
419 pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
426 self.send_messages_while_stopped = send;
427 self
428 }
429
430 pub fn channel_unstick_delay(mut self, delay: Option<Duration>) -> Self {
437 self.channel_unstick_delay = delay;
438 self
439 }
440
441 pub fn build(self) -> anyhow::Result<VmbusServer> {
446 #[expect(clippy::disallowed_methods)] let (message_send, message_recv) = mpsc::channel(64);
448 let message_sender = Arc::new(MessageSender {
449 send: message_send.clone(),
450 multiclient: self.use_message_redirect,
451 });
452
453 let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
454 (REDIRECT_VTL, REDIRECT_SINT)
455 } else {
456 (self.vtl, SINT)
457 };
458
459 let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
462 MESSAGE_CONNECTION_ID
463 } else {
464 VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
467 };
468
469 let _message_port = self
470 .synic
471 .add_message_port(connection_id, redirect_vtl, message_sender)
472 .context("failed to create vmbus synic ports")?;
473
474 let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
478 let multiclient_message_sender = Arc::new(MessageSender {
479 send: message_send,
480 multiclient: true,
481 });
482
483 Some(
484 self.synic
485 .add_message_port(
486 MULTICLIENT_MESSAGE_CONNECTION_ID,
487 self.vtl,
488 multiclient_message_sender,
489 )
490 .context("failed to create vmbus synic ports")?,
491 )
492 } else {
493 None
494 };
495
496 let (offer_send, offer_recv) = mesh::mpsc_channel();
497 let control = Arc::new(VmbusServerControl {
498 mem: self.gm.clone(),
499 private_mem: self.private_gm.clone(),
500 send: offer_send,
501 use_event: self.synic.prefer_os_events(),
502 force_confidential_external_memory: self.force_confidential_external_memory,
503 });
504
505 let mut server = channels::Server::new(self.vtl, connection_id, self.channel_id_offset);
506
507 server.set_require_server_allocated_mnf(self.enable_mnf && self.private_gm.is_some());
512
513 if let Some(version) = self.max_version {
515 server.set_compatibility_version(version, self.delay_max_version);
516 }
517 let (relay_request_send, relay_response_recv) =
518 if let Some(server_relay) = self.server_relay {
519 let r = server_relay.response_receive.boxed().fuse();
520 (server_relay.request_send, r)
521 } else {
522 let (req_send, req_recv) = mesh::channel();
523 let resp_recv = req_recv
524 .map(|req: ModifyRelayRequest| {
525 if req.version.is_some() {
527 ModifyRelayResponse::Supported(
528 protocol::ConnectionState::SUCCESSFUL,
529 protocol::FeatureFlags::from_bits(u32::MAX),
530 )
531 } else {
532 ModifyRelayResponse::Modified(protocol::ConnectionState::SUCCESSFUL)
533 }
534 })
535 .boxed()
536 .fuse();
537 (req_send, resp_recv)
538 };
539
540 let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
542 let r = hvsock_notify.response_receive.boxed().fuse();
543 (hvsock_notify.request_send, r)
544 } else {
545 let (req_send, req_recv) = mesh::channel();
546 let resp_recv = req_recv
547 .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
548 .boxed()
549 .fuse();
550 (req_send, resp_recv)
551 };
552
553 let inner = ServerTaskInner {
554 running: false,
555 send_messages_while_stopped: self.send_messages_while_stopped,
556 gm: self.gm,
557 private_gm: self.private_gm,
558 vtl: self.vtl,
559 redirect_vtl,
560 redirect_sint,
561 message_port: self
562 .synic
563 .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
564 synic: self.synic,
565 hvsock_requests: 0,
566 hvsock_send,
567 saved_state_notify: self.saved_state_notify,
568 channels: HashMap::new(),
569 channel_responses: FuturesUnordered::new(),
570 relay_send: relay_request_send,
571 external_server_send: self.external_server,
572 channel_bitmap: None,
573 shared_event_port: None,
574 reset_done: Vec::new(),
575 mnf_support: self.enable_mnf.then(MnfSupport::default),
576 };
577
578 let (task_send, task_recv) = mesh::channel();
579 let mut server_task = ServerTask {
580 driver: Box::new(self.spawner.clone()),
581 server,
582 task_recv,
583 offer_recv,
584 message_recv,
585 server_request_recv: SelectAll::new(),
586 inner,
587 external_requests: self.external_requests,
588 next_seq: 0,
589 unstick_on_start: false,
590 channel_unstickers: FuturesUnordered::new(),
591 channel_unstick_delay: self.channel_unstick_delay,
592 };
593
594 let task = self.spawner.spawn("vmbus server", async move {
595 server_task.run(relay_response_recv, hvsock_recv).await;
596 server_task
597 });
598
599 Ok(VmbusServer {
600 task_send,
601 control,
602 _message_port,
603 _multiclient_message_port,
604 task,
605 })
606 }
607}
608
609impl VmbusServer {
610 pub fn builder<T: SpawnDriver + Clone>(
612 spawner: T,
613 synic: Arc<dyn SynicPortAccess>,
614 gm: GuestMemory,
615 ) -> VmbusServerBuilder<T> {
616 VmbusServerBuilder::new(spawner, synic, gm)
617 }
618
619 pub async fn save(&self) -> SavedState {
620 self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
621 }
622
623 pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
624 self.task_send
625 .call(VmbusRequest::Restore, Box::new(state))
626 .await
627 .unwrap()
628 }
629
630 pub async fn stop(&self) {
632 self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
633 }
634
635 pub fn start(&self) {
637 self.task_send.send(VmbusRequest::Start);
638 }
639
640 pub async fn reset(&self) {
642 tracing::debug!("resetting channel state");
643 self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
644 }
645
646 pub async fn shutdown(self) {
648 drop(self.task_send);
649 let _ = self.task.await;
650 }
651
652 pub fn control(&self) -> Arc<VmbusServerControl> {
654 self.control.clone()
655 }
656
657 fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
660 MULTICLIENT_MESSAGE_CONNECTION_ID
661 | (vtl as u32) << 22
662 | vp_index << 8
663 | (sint_index as u32) << 4
664 }
665
666 fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
667 EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
668 }
669}
670
/// Restore information for a channel.
///
/// NOTE(review): not constructed in the visible code; the meaning of the fields below is
/// inferred from their types — confirm.
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    open_data: Option<OpenData>,
    /// Presumably (GPADL ID, range count, GPADL buffer) — confirm.
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    interrupt: Option<Interrupt>,
}

/// A raw synic message received on one of the server's message ports.
#[derive(Default)]
pub struct SynicMessage {
    /// Message payload.
    data: Vec<u8>,
    /// Whether the message came through a multiclient port (set from `MessageSender`).
    multiclient: bool,
    /// NOTE(review): set outside this chunk; presumably marks messages from a trusted
    /// source — confirm.
    trusted: bool,
}

/// Server-side MNF state; present only when MNF was enabled on the builder.
#[derive(Default)]
struct MnfSupport {
    /// Monitor page allocated by the server. It is reported to the guest in `Supported`
    /// modify-connection responses.
    allocated_monitor_page: Option<MonitorPageGpas>,
}

/// Identifies a specific offer of a channel.
///
/// `seq` distinguishes a re-offer that reuses an `OfferId`, so stale revokes, responses and
/// unstick timers can be ignored.
#[derive(Debug, Clone, Copy)]
struct OfferInstanceId {
    offer_id: OfferId,
    seq: u64,
}
697
/// State owned by the background task that runs the VMBus server.
struct ServerTask {
    /// Driver used to create channel unstick timers.
    driver: Box<dyn Driver>,
    /// The channel server state machine.
    server: channels::Server,
    /// Control requests from [`VmbusServer`]; closing it ends the run loop.
    task_recv: mesh::Receiver<VmbusRequest>,
    /// Offer and force-reset requests from `VmbusServerControl`.
    offer_recv: mesh::Receiver<OfferRequest>,
    /// Incoming synic messages from the registered message ports.
    message_recv: mpsc::Receiver<SynicMessage>,
    /// Per-channel server request streams, tagged with the offer instance they belong to.
    server_request_recv:
        SelectAll<TaggedStream<OfferInstanceId, mesh::Receiver<ChannelServerRequest>>>,
    /// Everything the channel server's notifier needs access to.
    inner: ServerTaskInner,
    /// Optional source of contact requests from outside the guest.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Sequence number assigned to the next offered channel.
    next_seq: u64,
    /// Set on restore from state saved without the lost-synic fix; consumed on the next start.
    unstick_on_start: bool,
    /// Pending per-channel unstick timers, queued after a channel opens.
    channel_unstickers: FuturesUnordered<Pin<Box<dyn Send + Future<Output = OfferInstanceId>>>>,
    /// Delay for those timers; `None` disables them.
    channel_unstick_delay: Option<Duration>,
}
714
/// Part of the server task state that is handed to the channel server as its notifier.
struct ServerTaskInner {
    /// Whether the server is started; most event sources are only polled while running.
    running: bool,
    /// NOTE(review): used outside this chunk; presumably allows sending messages while stopped.
    send_messages_while_stopped: bool,
    gm: GuestMemory,
    private_gm: Option<GuestMemory>,
    synic: Arc<dyn SynicPortAccess>,
    vtl: Vtl,
    /// VTL and SINT that messages to the guest are posted on.
    redirect_vtl: Vtl,
    redirect_sint: u8,
    /// Port used to flush pending outgoing messages to the guest.
    message_port: Box<dyn GuestMessagePort>,
    /// Number of hvsock connect requests awaiting a result.
    hvsock_requests: usize,
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Offered channels, keyed by offer ID.
    channels: HashMap<OfferId, Channel>,
    /// Pending responses from channel devices, tagged with the offer ID and sequence number.
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    /// Sender for modify-connection requests to the relay (or its local stand-in).
    relay_send: mesh::Sender<ModifyRelayRequest>,
    /// NOTE(review): managed outside this chunk.
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    /// NOTE(review): managed outside this chunk.
    shared_event_port: Option<Box<dyn Send>>,
    /// RPCs waiting for the current reset to finish; non-empty means a reset is in progress.
    reset_done: Vec<Rpc<(), ()>>,
    /// Present when the server handles MNF itself.
    mnf_support: Option<MnfSupport>,
}
741
/// A completed request from a channel device, dispatched by `handle_response`.
#[derive(Debug)]
enum ChannelResponse {
    /// Open finished; `true` on success.
    Open(bool),
    /// Close finished.
    Close,
    /// GPADL creation finished; `true` on success.
    Gpadl(GpadlId, bool),
    /// GPADL teardown finished.
    TeardownGpadl(GpadlId),
    /// Modify finished with the given status.
    Modify(i32),
}

/// Tracks whether a channel has an unstick timer pending.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ChannelUnstickState {
    /// No timer pending.
    None,
    /// A timer was queued after the channel opened.
    Queued,
    /// The channel opened again while a timer was already queued. NOTE(review): presumably the
    /// timer handler requeues in this case — confirm in `unstick_channel_by_id`.
    NeedsRequeue,
}
757
/// Server-side bookkeeping for an offered channel.
struct Channel {
    key: OfferKey,
    /// Sender for requests to the channel's device.
    send: mesh::Sender<ChannelRequest>,
    /// Sequence number of this offer; see `OfferInstanceId`.
    seq: u64,
    state: ChannelState,
    /// GPADLs registered for this channel (repopulated on restore).
    gpadls: Arc<GpadlMap>,
    /// Event port that delivers the channel's guest-to-host interrupt.
    guest_to_host_event: Arc<ChannelEvent>,
    flags: protocol::OfferFlags,
    /// NOTE(review): used outside this chunk; presumably the state of reserved channels.
    reserved_state: ReservedState,
    unstick_state: ChannelUnstickState,
}

/// Message port and target used by a reserved channel (initialized to none / VP 0, SINT 0).
struct ReservedState {
    message_port: Option<Box<dyn GuestMessagePort>>,
    target: ConnectionTarget,
}

/// Resources held while a channel is open.
struct ChannelOpenState {
    open_params: OpenParams,
    /// Held only to keep the event port registered.
    _event_port: Box<dyn Send>,
    guest_event_port: Box<dyn GuestEventPort>,
    host_to_guest_interrupt: Interrupt,
}

/// Open/close lifecycle of a channel.
enum ChannelState {
    Closed,
    Open(Box<ChannelOpenState>),
    /// A close was requested and the device's response is pending.
    Closing,
}
791
792impl ServerTask {
793 fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
794 let key = info.params.key();
795 let flags = info.params.flags;
796
797 if self.inner.mnf_support.is_some() && self.inner.synic.monitor_support().is_some() {
798 if info.params.use_mnf.is_relayed() {
803 info.params.use_mnf = MnfUsage::Enabled {
804 latency: Duration::ZERO,
805 }
806 }
807 } else if info.params.use_mnf.is_enabled() {
808 info.params.use_mnf = MnfUsage::Disabled;
811 }
812
813 let offer_id = self
814 .server
815 .with_notifier(&mut self.inner)
816 .offer_channel(info.params)
817 .context("channel offer failed")?;
818
819 tracing::debug!(?offer_id, %key, "offered channel");
820
821 let seq = self.next_seq;
822 self.next_seq += 1;
823 self.inner.channels.insert(
824 offer_id,
825 Channel {
826 key,
827 send: info.request_send,
828 state: ChannelState::Closed,
829 gpadls: GpadlMap::new(),
830 guest_to_host_event: Arc::new(ChannelEvent(info.event)),
831 seq,
832 flags,
833 reserved_state: ReservedState {
834 message_port: None,
835 target: ConnectionTarget { vp: 0, sint: 0 },
836 },
837 unstick_state: ChannelUnstickState::None,
838 },
839 );
840
841 self.server_request_recv.push(TaggedStream::new(
842 OfferInstanceId { offer_id, seq },
843 info.server_request_recv,
844 ));
845
846 Ok(())
847 }
848
849 fn handle_revoke(&mut self, id: OfferInstanceId) {
850 if let Some(channel) = self.inner.channels.get(&id.offer_id) {
853 if channel.seq == id.seq {
854 tracing::info!(?id.offer_id, "revoking channel");
855 self.inner.channels.remove(&id.offer_id);
856 self.server
857 .with_notifier(&mut self.inner)
858 .revoke_channel(id.offer_id);
859 }
860 }
861 }
862
863 fn handle_response(
864 &mut self,
865 offer_id: OfferId,
866 seq: u64,
867 response: Result<ChannelResponse, RpcError>,
868 ) {
869 let channel = self
871 .inner
872 .channels
873 .get(&offer_id)
874 .filter(|channel| channel.seq == seq);
875
876 if let Some(channel) = channel {
877 match response {
878 Ok(response) => match response {
879 ChannelResponse::Open(result) => self.handle_open(offer_id, result),
880 ChannelResponse::Close => self.handle_close(offer_id),
881 ChannelResponse::Gpadl(gpadl_id, ok) => {
882 self.handle_gpadl_create(offer_id, gpadl_id, ok)
883 }
884 ChannelResponse::TeardownGpadl(gpadl_id) => {
885 self.handle_gpadl_teardown(offer_id, gpadl_id)
886 }
887 ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
888 },
889 Err(err) => {
890 tracing::error!(
891 key = %channel.key,
892 error = &err as &dyn std::error::Error,
893 "channel response failure, channel is in inconsistent state until revoked"
894 );
895 }
896 }
897 } else {
898 tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
899 }
900 }
901
902 fn handle_open(&mut self, offer_id: OfferId, success: bool) {
903 let status = if success {
904 let channel = self
905 .inner
906 .channels
907 .get_mut(&offer_id)
908 .expect("channel exists");
909
910 if let Some(delay) = self.channel_unstick_delay {
913 if channel.unstick_state == ChannelUnstickState::None {
914 channel.unstick_state = ChannelUnstickState::Queued;
915 let seq = channel.seq;
916 let mut timer = PolledTimer::new(&self.driver);
917 self.channel_unstickers.push(Box::pin(async move {
918 timer.sleep(delay).await;
919 OfferInstanceId { offer_id, seq }
920 }));
921 } else {
922 channel.unstick_state = ChannelUnstickState::NeedsRequeue;
923 }
924 }
925
926 0
927 } else {
928 protocol::STATUS_UNSUCCESSFUL
929 };
930
931 self.server
932 .with_notifier(&mut self.inner)
933 .open_complete(offer_id, status);
934 }
935
936 fn handle_close(&mut self, offer_id: OfferId) {
937 let channel = self
938 .inner
939 .channels
940 .get_mut(&offer_id)
941 .expect("channel still exists");
942
943 match &mut channel.state {
944 ChannelState::Closing => {
945 channel.state = ChannelState::Closed;
946 self.server
947 .with_notifier(&mut self.inner)
948 .close_complete(offer_id);
949 }
950 _ => {
951 tracing::error!(?offer_id, "invalid close channel response");
952 }
953 };
954 }
955
956 fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
957 let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
958 self.server
959 .with_notifier(&mut self.inner)
960 .gpadl_create_complete(offer_id, gpadl_id, status);
961 }
962
963 fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
964 self.server
965 .with_notifier(&mut self.inner)
966 .gpadl_teardown_complete(offer_id, gpadl_id);
967 }
968
969 fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
970 self.server
971 .with_notifier(&mut self.inner)
972 .modify_channel_complete(offer_id, status);
973 }
974
975 fn handle_restore_channel(
976 &mut self,
977 offer_id: OfferId,
978 open: bool,
979 ) -> anyhow::Result<RestoreResult> {
980 let gpadls = self.server.channel_gpadls(offer_id);
981
982 let open_request = open
985 .then(|| -> anyhow::Result<_> {
986 let params = self.server.get_restore_open_params(offer_id)?;
987 let (channel, interrupt) = self.inner.open_channel(offer_id, ¶ms)?;
988 Ok(OpenRequest::new(
989 params.open_data,
990 interrupt,
991 self.server
992 .get_version()
993 .expect("must be connected")
994 .feature_flags,
995 channel.flags,
996 ))
997 })
998 .transpose()?;
999
1000 self.server
1001 .with_notifier(&mut self.inner)
1002 .restore_channel(offer_id, open_request.is_some())?;
1003
1004 let channel = self.inner.channels.get_mut(&offer_id).unwrap();
1005 for gpadl in &gpadls {
1006 if let Ok(buf) =
1007 MultiPagedRangeBuf::new(gpadl.request.count.into(), gpadl.request.buf.clone())
1008 {
1009 channel.gpadls.add(gpadl.request.id, buf);
1010 }
1011 }
1012
1013 let result = RestoreResult {
1014 open_request,
1015 gpadls,
1016 };
1017 Ok(result)
1018 }
1019
    /// Handles a control request from [`VmbusServer`] (or an inspect request).
    async fn handle_request(&mut self, request: VmbusRequest) {
        tracing::debug!(?request, "handle_request");
        match request {
            VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
            VmbusRequest::Inspect(deferred) => {
                deferred.respond(|resp| {
                    resp.field("message_port", &self.inner.message_port)
                        .field("running", self.inner.running)
                        .field("hvsock_requests", self.inner.hvsock_requests)
                        .field("channel_unstick_delay", self.channel_unstick_delay)
                        // Writable field:
                        // - "force" unsticks all channels with force = true.
                        // - Any other value must parse as a bool; true unsticks with
                        //   force = false.
                        // - Reads report false.
                        .field_mut_with("unstick_channels", |v| {
                            let v: inspect::ValueKind = if let Some(v) = v {
                                if v == "force" {
                                    self.unstick_channels(true);
                                    v.into()
                                } else {
                                    let v =
                                        v.parse().ok().context("expected false, true, or force")?;
                                    if v {
                                        self.unstick_channels(false);
                                    }
                                    v.into()
                                }
                            } else {
                                false.into()
                            };
                            anyhow::Ok(v)
                        })
                        .merge(&self.server.with_notifier(&mut self.inner));
                });
            }
            // New saved states always record that the lost-synic fix is present.
            VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
                server: self.server.save(),
                lost_synic_bug_fixed: true,
            }),
            VmbusRequest::Restore(rpc) => {
                rpc.handle(async |state| {
                    // State from a server without the fix gets its channels unstuck on the
                    // next start.
                    self.unstick_on_start = !state.lost_synic_bug_fixed;
                    // The proxy (if any) receives the saved state before the channel server is
                    // restored; a proxy failure aborts the restore.
                    if let Some(sender) = &self.inner.saved_state_notify {
                        tracing::trace!("sending saved state to proxy");
                        if let Err(err) = sender
                            .call_failable(SavedStateRequest::Set, Box::new(state.server.clone()))
                            .await
                        {
                            tracing::error!(
                                err = &err as &dyn std::error::Error,
                                "failed to restore proxy saved state"
                            );
                            return Err(RestoreError::ServerError(err.into()));
                        }
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .restore(state.server)
                })
                .await
            }
            // Stopping only clears the flag; the run loop then stops polling most sources.
            VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
                if self.inner.running {
                    self.inner.running = false;
                }
            }),
            VmbusRequest::Start => {
                if !self.inner.running {
                    self.inner.running = true;
                    // Tell the proxy to drop its saved state; a failure here panics.
                    if let Some(sender) = self.inner.saved_state_notify.as_ref() {
                        tracing::trace!("sending clear saved state message to proxy");
                        sender
                            .call(SavedStateRequest::Clear, ())
                            .await
                            .expect("failed to clear proxy saved state");
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .revoke_unclaimed_channels();
                    // Deferred mitigation requested by a restore from pre-fix saved state.
                    if self.unstick_on_start {
                        tracing::info!(
                            "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
                        );
                        self.unstick_channels(false);
                        self.unstick_on_start = false;
                    }
                }
            }
        }
    }
1110
1111 fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
1112 let needs_reset = self.inner.reset_done.is_empty();
1113 self.inner.reset_done.push(rpc);
1114 if needs_reset {
1115 self.server.with_notifier(&mut self.inner).reset();
1116 }
1117 }
1118
1119 fn handle_relay_response(&mut self, response: ModifyRelayResponse) {
1120 let response = match response {
1122 ModifyRelayResponse::Supported(state, features) => {
1123 let allocated_monitor_gpas = self
1126 .inner
1127 .mnf_support
1128 .as_ref()
1129 .and_then(|mnf| mnf.allocated_monitor_page);
1130
1131 ModifyConnectionResponse::Supported(state, features, allocated_monitor_gpas)
1132 }
1133 ModifyRelayResponse::Unsupported => ModifyConnectionResponse::Unsupported,
1134 ModifyRelayResponse::Modified(state) => ModifyConnectionResponse::Modified(state),
1135 };
1136
1137 self.server
1138 .with_notifier(&mut self.inner)
1139 .complete_modify_connection(response);
1140 }
1141
1142 fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
1143 assert_ne!(self.inner.hvsock_requests, 0);
1144 self.inner.hvsock_requests -= 1;
1145
1146 self.server
1147 .with_notifier(&mut self.inner)
1148 .send_tl_connect_result(result);
1149 }
1150
1151 fn handle_synic_message(&mut self, message: SynicMessage) {
1152 match self
1153 .server
1154 .with_notifier(&mut self.inner)
1155 .handle_synic_message(message)
1156 {
1157 Ok(()) => {}
1158 Err(err) => {
1159 tracing::warn!(
1160 error = &err as &dyn std::error::Error,
1161 "synic message error"
1162 );
1163 }
1164 }
1165 }
1166
1167 fn handle_external_request(&mut self, request: InitiateContactRequest) {
1174 self.server
1175 .with_notifier(&mut self.inner)
1176 .initiate_contact(request);
1177 }
1178
    /// Main loop of the server task. It returns when the control channel closes or every event
    /// source is exhausted.
    ///
    /// Each iteration rebuilds the set of armed event sources from the current state:
    /// - Guest-facing sources (external requests, message flushing, incoming synic messages,
    ///   hvsock results, unstick timers) are armed only while running and not resetting.
    /// - Channel responses are also processed during a reset so the reset can complete.
    /// - Control requests, offers, channel server requests and relay responses are always
    ///   polled.
    ///
    /// Gated sources are wrapped in `OptionFuture`. A `None` `OptionFuture` reports itself as
    /// terminated, so `select!` skips it. That is why the `unwrap()` calls in the handlers
    /// below only ever see `Some`.
    async fn run(
        &mut self,
        mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyRelayResponse> + Unpin,
        mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
    ) {
        loop {
            let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
            let mut external_requests = OptionFuture::from(
                running_not_resetting
                    .then(|| {
                        self.external_requests
                            .as_mut()
                            .map(|r| r.select_next_some())
                    })
                    .flatten(),
            );

            // While outgoing messages are pending, flush them instead of reading new guest
            // messages (backpressure).
            let has_pending_messages = self.server.has_pending_messages();
            let message_port = self.inner.message_port.as_mut();
            let mut flush_pending_messages =
                OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
                    poll_fn(|cx| {
                        self.server.poll_flush_pending_messages(|msg| {
                            message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
                        })
                    })
                    .fuse()
                }));

            // Incoming guest messages are also held back while the hvsock request limit is
            // reached.
            let mut message_recv = OptionFuture::from(
                (running_not_resetting
                    && !has_pending_messages
                    && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
                    .then(|| self.message_recv.select_next_some()),
            );

            // Channel responses are needed to finish a reset, so they are polled while
            // resetting even if the server is stopped.
            let mut channel_response = OptionFuture::from(
                (self.inner.running || !self.inner.reset_done.is_empty())
                    .then(|| self.inner.channel_responses.select_next_some()),
            );

            let mut hvsock_response =
                OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));

            let mut channel_unstickers = OptionFuture::from(
                running_not_resetting.then(|| self.channel_unstickers.select_next_some()),
            );

            futures::select! {
                // A closed control channel (VmbusServer::shutdown) ends the loop.
                r = self.task_recv.recv().fuse() => {
                    if let Ok(request) = r {
                        self.handle_request(request).await;
                    } else {
                        break;
                    }
                }
                r = self.offer_recv.select_next_some() => {
                    match r {
                        OfferRequest::Offer(rpc) => {
                            rpc.handle_failable_sync(|request| { self.handle_offer(request) })
                        },
                        OfferRequest::ForceReset(rpc) => {
                            self.handle_reset(rpc);
                        }
                    }
                }
                // A closed per-channel request stream (None) also revokes that channel.
                r = self.server_request_recv.select_next_some() => {
                    match r {
                        (id, Some(request)) => match request {
                            ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
                                self.handle_restore_channel(id.offer_id, open)
                            }),
                            ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
                                self.handle_revoke(id);
                            })
                        },
                        (id, None) => self.handle_revoke(id),
                    }
                }
                r = channel_response => {
                    let (id, seq, response) = r.unwrap();
                    self.handle_response(id, seq, response);
                }
                r = relay_response_recv.select_next_some() => {
                    self.handle_relay_response(r);
                },
                r = hvsock_response => {
                    self.handle_tl_connect_result(r.unwrap());
                }
                data = message_recv => {
                    let data = data.unwrap();
                    self.handle_synic_message(data);
                }
                r = external_requests => {
                    let r = r.unwrap();
                    self.handle_external_request(r);
                }
                r = channel_unstickers => {
                    self.unstick_channel_by_id(r.unwrap());
                }
                // Flushing has no result to handle; the next iteration re-evaluates pending
                // messages.
                _r = flush_pending_messages => {}
                complete => break,
            }
        }
    }
1294
1295 fn unstick_channels(&self, force: bool) {
1299 let Some(version) = self.server.get_version() else {
1300 tracing::warn!("cannot unstick when not connected");
1301 return;
1302 };
1303
1304 for channel in self.inner.channels.values() {
1305 let gm = self.inner.get_gm_for_channel(version, channel);
1306 if let Err(err) = Self::unstick_channel(gm, channel, force, true) {
1307 tracing::warn!(
1308 channel = %channel.key,
1309 error = err.as_ref() as &dyn std::error::Error,
1310 "could not unstick channel"
1311 );
1312 }
1313 }
1314 }
1315
1316 fn unstick_channel_by_id(&mut self, id: OfferInstanceId) {
1319 let Some(version) = self.server.get_version() else {
1320 tracelimit::warn_ratelimited!("cannot unstick when not connected");
1321 return;
1322 };
1323
1324 if let Some(channel) = self.inner.channels.get_mut(&id.offer_id) {
1325 if channel.seq != id.seq {
1326 return;
1328 }
1329
1330 if channel.unstick_state == ChannelUnstickState::NeedsRequeue {
1333 channel.unstick_state = ChannelUnstickState::Queued;
1334 let mut timer = PolledTimer::new(&self.driver);
1335 let delay = self.channel_unstick_delay.unwrap();
1336 self.channel_unstickers.push(Box::pin(async move {
1337 timer.sleep(delay).await;
1338 id
1339 }));
1340
1341 return;
1342 }
1343
1344 channel.unstick_state = ChannelUnstickState::None;
1345 let gm = select_gm_for_channel(
1346 &self.inner.gm,
1347 self.inner.private_gm.as_ref(),
1348 version,
1349 channel,
1350 );
1351 if let Err(err) = Self::unstick_channel(gm, channel, false, false) {
1352 tracelimit::warn_ratelimited!(
1353 channel = %channel.key,
1354 error = err.as_ref() as &dyn std::error::Error,
1355 "could not unstick channel"
1356 );
1357 }
1358 }
1359 }
1360
1361 fn unstick_channel(
1362 gm: &GuestMemory,
1363 channel: &Channel,
1364 force: bool,
1365 unstick_host: bool,
1366 ) -> anyhow::Result<()> {
1367 if let ChannelState::Open(state) = &channel.state {
1368 if force {
1369 tracing::info!(channel = %channel.key, "waking host and guest");
1370 if unstick_host {
1371 channel.guest_to_host_event.0.deliver();
1372 }
1373 state.host_to_guest_interrupt.deliver();
1374 return Ok(());
1375 }
1376
1377 let gpadl = channel
1378 .gpadls
1379 .clone()
1380 .view()
1381 .map(state.open_params.open_data.ring_gpadl_id)
1382 .context("couldn't find ring gpadl")?;
1383
1384 let aligned = AlignedGpadlView::new(gpadl)
1385 .ok()
1386 .context("ring not aligned")?;
1387 let (in_gpadl, out_gpadl) = aligned
1388 .split(state.open_params.open_data.ring_offset)
1389 .ok()
1390 .context("couldn't split ring")?;
1391
1392 if let Err(err) = Self::unstick_incoming_ring(
1393 gm,
1394 channel,
1395 in_gpadl,
1396 unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1397 &state.host_to_guest_interrupt,
1398 ) {
1399 tracelimit::warn_ratelimited!(
1400 channel = %channel.key,
1401 error = err.as_ref() as &dyn std::error::Error,
1402 "could not unstick incoming ring"
1403 );
1404 }
1405 if let Err(err) = Self::unstick_outgoing_ring(
1406 gm,
1407 channel,
1408 out_gpadl,
1409 unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1410 &state.host_to_guest_interrupt,
1411 ) {
1412 tracelimit::warn_ratelimited!(
1413 channel = %channel.key,
1414 error = err.as_ref() as &dyn std::error::Error,
1415 "could not unstick outgoing ring"
1416 );
1417 }
1418 }
1419 Ok(())
1420 }
1421
1422 fn unstick_incoming_ring(
1423 gm: &GuestMemory,
1424 channel: &Channel,
1425 in_gpadl: AlignedGpadlView,
1426 guest_to_host_event: Option<&ChannelEvent>,
1427 host_to_guest_interrupt: &Interrupt,
1428 ) -> anyhow::Result<()> {
1429 let control_page = lock_gpn_with_subrange(gm, in_gpadl.gpns()[0])?;
1430 if let Some(guest_to_host_event) = guest_to_host_event {
1431 if ring::reader_needs_signal(control_page.pages()[0]) {
1432 tracelimit::info_ratelimited!(channel = %channel.key, "waking host for incoming ring");
1433 guest_to_host_event.0.deliver();
1434 }
1435 }
1436
1437 let ring_size = gpadl_ring_size(&in_gpadl).try_into()?;
1438 if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1439 tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for incoming ring");
1440 host_to_guest_interrupt.deliver();
1441 }
1442 Ok(())
1443 }
1444
1445 fn unstick_outgoing_ring(
1446 gm: &GuestMemory,
1447 channel: &Channel,
1448 out_gpadl: AlignedGpadlView,
1449 guest_to_host_event: Option<&ChannelEvent>,
1450 host_to_guest_interrupt: &Interrupt,
1451 ) -> anyhow::Result<()> {
1452 let control_page = lock_gpn_with_subrange(gm, out_gpadl.gpns()[0])?;
1453 if ring::reader_needs_signal(control_page.pages()[0]) {
1454 tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for outgoing ring");
1455 host_to_guest_interrupt.deliver();
1456 }
1457
1458 if let Some(guest_to_host_event) = guest_to_host_event {
1459 let ring_size = gpadl_ring_size(&out_gpadl).try_into()?;
1460 if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1461 tracelimit::info_ratelimited!(channel = %channel.key, "waking host for outgoing ring");
1462 guest_to_host_event.0.deliver();
1463 }
1464 }
1465 Ok(())
1466 }
1467}
1468
// Side effects requested through the `channels::Notifier` trait:
// - talking to channel devices;
// - managing event and message ports;
// - posting messages to the guest.
impl Notifier for ServerTaskInner {
    /// Carries out `action` for the channel identified by `offer_id`.
    ///
    /// Most actions are forwarded to the device as a `ChannelRequest`. The
    /// resulting future is pushed onto `channel_responses`. It resolves to
    /// `(offer_id, seq, response)`, where `seq` is the channel's sequence
    /// number when the request was issued.
    ///
    /// # Panics
    ///
    /// Panics if `offer_id` is not a known channel, or if a `Modify` action
    /// arrives for a channel that is not open.
    fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        // Sends `input` to the device through the `req` variant and returns a
        // boxed future. The future yields the device's reply mapped through
        // `f`, tagged with `offer_id` and the channel's current `seq`.
        // Presumably the tag lets stale responses be recognized — confirm.
        fn handle<I: 'static + Send, R: 'static + Send>(
            offer_id: OfferId,
            channel: &Channel,
            req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
            input: I,
            f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
        ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
        {
            let recv = channel.send.call(req, input);
            let seq = channel.seq;
            Box::pin(async move {
                let r = recv.await.map(f);
                (offer_id, seq, r)
            })
        }

        let response = match action {
            channels::Action::Open(open_params, version) => {
                // Read `seq` now: `open_channel` reborrows `self`, and the
                // failure path below still needs it.
                let seq = channel.seq;
                match self.open_channel(offer_id, &open_params) {
                    Ok((channel, interrupt)) => handle(
                        offer_id,
                        channel,
                        ChannelRequest::Open,
                        OpenRequest::new(
                            open_params.open_data,
                            interrupt,
                            version.feature_flags,
                            channel.flags,
                        ),
                        ChannelResponse::Open,
                    ),
                    Err(err) => {
                        tracelimit::error_ratelimited!(
                            err = err.as_ref() as &dyn std::error::Error,
                            ?offer_id,
                            "could not open channel",
                        );

                        // Host-side setup failed, so the device is not
                        // contacted. An unsuccessful open is reported
                        // immediately instead.
                        Box::pin(future::ready((
                            offer_id,
                            seq,
                            Ok(ChannelResponse::Open(false)),
                        )))
                    }
                }
            }
            channels::Action::Close => {
                // Stop routing shared-interrupt (bitmap) signals for this
                // channel's event flag.
                if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
                    if let ChannelState::Open(ref state) = channel.state {
                        channel_bitmap.unregister_channel(state.open_params.event_flag);
                    }
                }

                channel.state = ChannelState::Closing;
                handle(offer_id, channel, ChannelRequest::Close, (), |()| {
                    ChannelResponse::Close
                })
            }
            channels::Action::Gpadl(gpadl_id, count, buf) => {
                // Record the GPADL locally (used for ring inspection and
                // unsticking) before forwarding it to the device.
                // NOTE(review): `unwrap` assumes `count`/`buf` were validated
                // upstream — confirm.
                channel.gpadls.add(
                    gpadl_id,
                    MultiPagedRangeBuf::new(count.into(), buf.clone()).unwrap(),
                );
                handle(
                    offer_id,
                    channel,
                    ChannelRequest::Gpadl,
                    GpadlRequest {
                        id: gpadl_id,
                        count,
                        buf,
                    },
                    move |r| ChannelResponse::Gpadl(gpadl_id, r),
                )
            }
            channels::Action::TeardownGpadl {
                gpadl_id,
                post_restore,
            } => {
                // The local entry is removed (with a no-op callback) except
                // when `post_restore` is set. NOTE(review): presumably no
                // local entry exists in that case — confirm.
                if !post_restore {
                    channel.gpadls.remove(gpadl_id, Box::new(|| ()));
                }

                handle(
                    offer_id,
                    channel,
                    ChannelRequest::TeardownGpadl,
                    gpadl_id,
                    move |()| ChannelResponse::TeardownGpadl(gpadl_id),
                )
            }
            channels::Action::Modify { target_vp } => {
                if let ChannelState::Open(state) = &mut channel.state {
                    // Retarget the host-to-guest event port first. The device
                    // is only notified if that succeeds.
                    if let Err(err) = state.guest_event_port.set_target_vp(target_vp) {
                        tracelimit::error_ratelimited!(
                            error = &err as &dyn std::error::Error,
                            channel = %channel.key,
                            "could not modify channel",
                        );
                        let seq = channel.seq;
                        Box::pin(async move {
                            (
                                offer_id,
                                seq,
                                Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
                            )
                        })
                    } else {
                        handle(
                            offer_id,
                            channel,
                            ChannelRequest::Modify,
                            ModifyRequest::TargetVp { target_vp },
                            ChannelResponse::Modify,
                        )
                    }
                } else {
                    // Modify is only expected for open channels.
                    unreachable!();
                }
            }
        };
        self.channel_responses.push(response);
    }

    /// Applies the connection-level parts of `request`:
    /// - interrupt page;
    /// - monitor page;
    /// - message target VP.
    ///
    /// If `request.notify_relay` is set, the request is then forwarded to the
    /// relay.
    ///
    /// `set_monitor_page` may reset `request.monitor_page` to
    /// `Update::Unchanged` before the request is forwarded.
    fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
        self.map_interrupt_page(request.interrupt_page)
            .context("Failed to map interrupt page.")?;

        self.set_monitor_page(&mut request)
            .context("Failed to map monitor page.")?;

        if let Some(vp) = request.target_message_vp {
            self.message_port.set_target_vp(vp)?;
        }

        if request.notify_relay {
            self.relay_send.send(request.into());
        }

        Ok(())
    }

    /// Forwards an `InitiateContactRequest` that was not handled here to the
    /// external server, if one is configured. Otherwise it is logged and
    /// dropped.
    fn forward_unhandled(&mut self, request: InitiateContactRequest) {
        if let Some(external_server) = &self.external_server_send {
            external_server.send(request);
        } else {
            tracing::warn!(?request, "nowhere to forward unhandled request")
        }
    }

    /// Responds to `req` with ring-buffer details for an open channel. The
    /// response is empty for channels that are not open.
    ///
    /// # Panics
    ///
    /// Panics if `offer_id` is unknown, or if the channel is open while
    /// `version` is `None`.
    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
        let channel = self.channels.get(&offer_id).expect("should exist");
        let mut resp = req.respond();
        if let ChannelState::Open(state) = &channel.state {
            let mem = self.get_gm_for_channel(version.expect("must be connected"), channel);
            inspect_rings(
                &mut resp,
                mem,
                channel.gpadls.clone(),
                &state.open_params.open_data,
            );
        }
    }

    /// Tries to post `message` to `target` without waiting.
    ///
    /// Returns `false` in two cases:
    /// - the server is stopped and `send_messages_while_stopped` is not set;
    /// - the port did not accept the message immediately (`Pending`).
    ///
    /// Returns `true` when the message was posted. It also returns `true`
    /// when no port could be obtained for the target; the message is then
    /// dropped. NOTE(review): presumably a `false` result leaves the message
    /// pending for a later flush (see `poll_flush_pending_messages` in the
    /// run loop) — confirm.
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
        if !self.running && !self.send_messages_while_stopped {
            if !matches!(target, MessageTarget::Default) {
                tracelimit::error_ratelimited!(?target, "dropping message while paused");
            }
            return false;
        }

        // Holds a port created for a one-off `Custom` target, so the borrow
        // taken from it outlives the `match`.
        let mut port_storage;
        let port = match target {
            MessageTarget::Default => self.message_port.as_mut(),
            MessageTarget::ReservedChannel(offer_id, target) => {
                if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
                    port.as_mut()
                } else {
                    return true;
                }
            }
            MessageTarget::Custom(target) => {
                port_storage = match self.synic.new_guest_message_port(
                    self.redirect_vtl,
                    target.vp,
                    target.sint,
                ) {
                    Ok(port) => port,
                    Err(err) => {
                        tracing::error!(
                            ?err,
                            ?self.redirect_vtl,
                            ?target,
                            "could not create message port"
                        );

                        return true;
                    }
                };
                port_storage.as_mut()
            }
        };

        // Poll exactly once with a no-op waker. A `Pending` result is not
        // retried here and is reported as `false`.
        matches!(
            port.poll_post_message(
                &mut std::task::Context::from_waker(std::task::Waker::noop()),
                VMBUS_MESSAGE_TYPE,
                message.data()
            ),
            Poll::Ready(())
        )
    }

    /// Counts a new outstanding hvsock request and forwards it on
    /// `hvsock_send`. The run loop stops receiving synic messages while
    /// `hvsock_requests` is at `MAX_CONCURRENT_HVSOCK_REQUESTS`.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
        self.hvsock_requests += 1;
        self.hvsock_send.send(*request);
    }

    /// Finishes a reset:
    /// - clears the monitor page (a failure is only logged);
    /// - releases reserved message ports of closed channels;
    /// - completes every pending reset waiter.
    fn reset_complete(&mut self) {
        if let Some(monitor) = self.synic.monitor_support() {
            if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
                tracing::warn!(?err, "resetting monitor page failed")
            }
        }

        self.unreserve_channels();
        for done in self.reset_done.drain(..) {
            done.complete(());
        }
    }

    /// Releases reserved message ports of closed channels after an unload.
    fn unload_complete(&mut self) {
        self.unreserve_channels();
    }
}
1731
1732impl ServerTaskInner {
1733 fn open_channel(
1734 &mut self,
1735 offer_id: OfferId,
1736 open_params: &OpenParams,
1737 ) -> anyhow::Result<(&mut Channel, Interrupt)> {
1738 let channel = self
1739 .channels
1740 .get_mut(&offer_id)
1741 .expect("channel does not exist");
1742
1743 if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1745 channel_bitmap.register_channel(
1746 open_params.event_flag,
1747 channel.guest_to_host_event.0.clone(),
1748 );
1749 }
1750 let event_port = self
1755 .synic
1756 .add_event_port(
1757 open_params.connection_id,
1758 self.vtl,
1759 channel.guest_to_host_event.clone(),
1760 open_params.monitor_info,
1761 )
1762 .context("failed to create guest-to-host event port")?;
1763
1764 let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
1767 (0, 0)
1768 } else {
1769 (open_params.open_data.target_vp, open_params.event_flag)
1770 };
1771 let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
1772 (self.redirect_vtl, self.redirect_sint)
1773 } else {
1774 (self.vtl, SINT)
1775 };
1776
1777 let guest_event_port = self.synic.new_guest_event_port(
1778 VmbusServer::get_child_event_port_id(open_params.channel_id, SINT, self.vtl),
1779 target_vtl,
1780 target_vp,
1781 target_sint,
1782 event_flag,
1783 open_params.monitor_info,
1784 )?;
1785
1786 let interrupt = ChannelBitmap::create_interrupt(
1787 &self.channel_bitmap,
1788 guest_event_port.interrupt(),
1789 open_params.event_flag,
1790 );
1791
1792 channel.reserved_state.message_port = None;
1794
1795 if let Some(target) = open_params.reserved_target {
1797 channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
1798 self.redirect_vtl,
1799 target.vp,
1800 target.sint,
1801 )?);
1802
1803 channel.reserved_state.target = target;
1804 }
1805
1806 channel.state = ChannelState::Open(Box::new(ChannelOpenState {
1807 open_params: *open_params,
1808 _event_port: event_port,
1809 guest_event_port,
1810 host_to_guest_interrupt: interrupt.clone(),
1811 }));
1812 Ok((channel, interrupt))
1813 }
1814
1815 fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
1818 let interrupt_page = match interrupt_page {
1819 Update::Unchanged => return Ok(()),
1820 Update::Reset => {
1821 self.channel_bitmap = None;
1822 self.shared_event_port = None;
1823 return Ok(());
1824 }
1825 Update::Set(interrupt_page) => interrupt_page,
1826 };
1827
1828 assert_ne!(interrupt_page, 0);
1829
1830 if interrupt_page % PAGE_SIZE as u64 != 0 {
1831 anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
1832 }
1833
1834 let interrupt_page = lock_page_with_subrange(&self.gm, interrupt_page)?;
1837 let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
1838 self.channel_bitmap = Some(channel_bitmap.clone());
1839
1840 let interrupt = Interrupt::from_fn(move || {
1842 channel_bitmap.handle_shared_interrupt();
1843 });
1844
1845 self.shared_event_port = Some(self.synic.add_event_port(
1846 SHARED_EVENT_CONNECTION_ID,
1847 self.vtl,
1848 Arc::new(ChannelEvent(interrupt)),
1849 None,
1850 )?);
1851
1852 Ok(())
1853 }
1854
1855 fn set_monitor_page(&mut self, request: &mut ModifyConnectionRequest) -> anyhow::Result<()> {
1856 let monitor_page = match request.monitor_page {
1857 Update::Unchanged => return Ok(()),
1858 Update::Reset => None,
1859 Update::Set(value) => Some(value),
1860 };
1861
1862 if self.channels.iter().any(|(_, c)| {
1864 matches!(
1865 &c.state,
1866 ChannelState::Open(state) if state.open_params.monitor_info.is_some()
1867 )
1868 }) {
1869 anyhow::bail!("attempt to change monitor page while open channels using mnf");
1870 }
1871
1872 if let Some(mnf_support) = self.mnf_support.as_mut() {
1876 if let Some(monitor) = self.synic.monitor_support() {
1877 mnf_support.allocated_monitor_page = None;
1878
1879 if let Some(version) = request.version {
1880 if version.feature_flags.server_specified_monitor_pages() {
1881 if let Some(monitor_page) = monitor.allocate_monitor_page(self.vtl)? {
1882 tracelimit::info_ratelimited!(
1883 ?monitor_page,
1884 "using server-allocated monitor pages"
1885 );
1886 mnf_support.allocated_monitor_page = Some(monitor_page);
1887 }
1888 }
1889 }
1890
1891 if mnf_support.allocated_monitor_page.is_none() {
1893 if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
1894 anyhow::bail!(
1895 "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
1896 );
1897 }
1898 }
1899 }
1900
1901 request.monitor_page = Update::Unchanged;
1904 }
1905
1906 Ok(())
1907 }
1908
1909 fn get_reserved_channel_message_port(
1910 &mut self,
1911 offer_id: OfferId,
1912 new_target: ConnectionTarget,
1913 ) -> Option<&mut Box<dyn GuestMessagePort>> {
1914 let channel = self
1915 .channels
1916 .get_mut(&offer_id)
1917 .expect("channel does not exist");
1918
1919 assert!(
1920 channel.reserved_state.message_port.is_some(),
1921 "channel is not reserved"
1922 );
1923
1924 if channel.reserved_state.target.sint != new_target.sint {
1927 channel.reserved_state.message_port = None;
1929 let message_port = self
1930 .synic
1931 .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
1932 .inspect_err(|err| {
1933 tracing::error!(
1934 ?err,
1935 ?self.redirect_vtl,
1936 ?new_target,
1937 "could not create reserved channel message port"
1938 )
1939 })
1940 .ok()?;
1941
1942 channel.reserved_state.message_port = Some(message_port);
1943 channel.reserved_state.target = new_target;
1944 } else if channel.reserved_state.target.vp != new_target.vp {
1945 let message_port = channel.reserved_state.message_port.as_mut().unwrap();
1946
1947 if let Err(err) = message_port.set_target_vp(new_target.vp) {
1950 tracing::error!(
1951 ?err,
1952 ?self.redirect_vtl,
1953 ?new_target,
1954 "could not update reserved channel message port"
1955 );
1956 }
1957
1958 channel.reserved_state.target = new_target;
1959 return Some(message_port);
1960 }
1961
1962 Some(channel.reserved_state.message_port.as_mut().unwrap())
1963 }
1964
1965 fn unreserve_channels(&mut self) {
1966 for channel in self.channels.values_mut() {
1968 if let ChannelState::Closed = channel.state {
1969 channel.reserved_state.message_port = None;
1970 }
1971 }
1972 }
1973
1974 fn get_gm_for_channel(&self, version: VersionInfo, channel: &Channel) -> &GuestMemory {
1975 select_gm_for_channel(&self.gm, self.private_gm.as_ref(), version, channel)
1976 }
1977}
1978
1979fn select_gm_for_channel<'a>(
1980 gm: &'a GuestMemory,
1981 private_gm: Option<&'a GuestMemory>,
1982 version: VersionInfo,
1983 channel: &Channel,
1984) -> &'a GuestMemory {
1985 if channel.flags.confidential_ring_buffer() && version.feature_flags.confidential_channels() {
1986 if let Some(private_gm) = private_gm {
1987 return private_gm;
1988 }
1989 }
1990
1991 gm
1992}
1993
/// A cloneable handle for offering channels to the VMBus server and forcing
/// resets. It also serves as the `ParentBus` implementation for devices.
#[derive(Clone)]
pub struct VmbusServerControl {
    /// Guest memory handed to every successfully offered channel.
    mem: GuestMemory,
    /// Private guest memory. It is handed out only to offers whose flags
    /// request a confidential ring buffer or confidential external memory.
    private_mem: Option<GuestMemory>,
    /// Carries `OfferRequest`s (offers and forced resets) to the server.
    send: mesh::Sender<OfferRequest>,
    /// Value reported by `ParentBus::use_event`.
    use_event: bool,
    /// When set, offers made through `offer` (and so `ParentBus::add_child`)
    /// are forced to use confidential external memory.
    force_confidential_external_memory: bool,
}
2003
2004impl VmbusServerControl {
2005 pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
2008 let flags = offer_info.params.flags;
2009 self.send
2010 .call_failable(OfferRequest::Offer, offer_info)
2011 .await?;
2012 Ok(OfferResources::new(
2013 self.mem.clone(),
2014 if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
2015 self.private_mem.clone()
2016 } else {
2017 None
2018 },
2019 ))
2020 }
2021
2022 pub async fn force_reset(&self) -> anyhow::Result<()> {
2025 self.send
2026 .call(OfferRequest::ForceReset, ())
2027 .await
2028 .context("vmbus server is gone")
2029 }
2030
2031 async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2032 let mut offer_info = OfferInfo {
2033 params: request.params.into(),
2034 event: request.event,
2035 request_send: request.request_send,
2036 server_request_recv: request.server_request_recv,
2037 };
2038
2039 if self.force_confidential_external_memory {
2040 tracing::warn!(
2041 key = %offer_info.params.key(),
2042 "forcing confidential external memory for channel"
2043 );
2044
2045 offer_info
2046 .params
2047 .flags
2048 .set_confidential_external_memory(true);
2049 }
2050
2051 self.offer_core(offer_info).await
2052 }
2053}
2054
2055fn inspect_rings(
2057 resp: &mut inspect::Response<'_>,
2058 gm: &GuestMemory,
2059 gpadl_map: Arc<GpadlMap>,
2060 open_data: &OpenData,
2061) -> Option<()> {
2062 let gpadl = gpadl_map
2063 .view()
2064 .map(GpadlId(open_data.ring_gpadl_id.0))
2065 .ok()?;
2066
2067 let aligned = AlignedGpadlView::new(gpadl).ok()?;
2068 let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
2069 resp.child("incoming_ring", |req| inspect_ring(req, &in_gpadl, gm));
2070 resp.child("outgoing_ring", |req| inspect_ring(req, &out_gpadl, gm));
2071 Some(())
2072}
2073
2074fn inspect_ring(req: inspect::Request<'_>, gpadl: &AlignedGpadlView, gm: &GuestMemory) {
2076 let mut resp = req.respond();
2077
2078 resp.hex("ring_size", gpadl_ring_size(gpadl));
2079
2080 if let Ok(pages) = lock_gpn_with_subrange(gm, gpadl.gpns()[0]) {
2083 ring::inspect_ring(pages.pages()[0], &mut resp);
2084 }
2085}
2086
2087fn gpadl_ring_size(gpadl: &AlignedGpadlView) -> usize {
2088 (gpadl.gpns().len() - 1) * PAGE_SIZE
2090}
2091
2092fn lock_page_with_subrange(gm: &GuestMemory, offset: u64) -> anyhow::Result<guestmem::LockedPages> {
2097 Ok(gm
2098 .lockable_subrange(offset, PAGE_SIZE as u64)?
2099 .lock_gpns(false, &[0])?)
2100}
2101
2102fn lock_gpn_with_subrange(gm: &GuestMemory, gpn: u64) -> anyhow::Result<guestmem::LockedPages> {
2107 lock_page_with_subrange(gm, gpn * PAGE_SIZE as u64)
2108}
2109
/// A `MessagePort` that forwards received messages as `SynicMessage`s over
/// a bounded `futures` mpsc channel.
pub(crate) struct MessageSender {
    /// Sending half of the channel that carries forwarded messages.
    send: mpsc::Sender<SynicMessage>,
    /// Copied into the `multiclient` field of every forwarded message.
    multiclient: bool,
}
2114
2115impl MessageSender {
2116 fn poll_handle_message(
2117 &self,
2118 cx: &mut std::task::Context<'_>,
2119 msg: &[u8],
2120 trusted: bool,
2121 ) -> Poll<Result<(), SendError>> {
2122 let mut send = self.send.clone();
2123 ready!(send.poll_ready(cx))?;
2124 send.start_send(SynicMessage {
2125 data: msg.to_vec(),
2126 multiclient: self.multiclient,
2127 trusted,
2128 })?;
2129
2130 Poll::Ready(Ok(()))
2131 }
2132}
2133
2134impl MessagePort for MessageSender {
2135 fn poll_handle_message(
2136 &self,
2137 cx: &mut std::task::Context<'_>,
2138 msg: &[u8],
2139 trusted: bool,
2140 ) -> Poll<()> {
2141 if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
2142 tracelimit::error_ratelimited!(
2143 error = &err as &dyn std::error::Error,
2144 "failed to send message"
2145 );
2146 }
2147
2148 Poll::Ready(())
2149 }
2150}
2151
2152#[async_trait]
2153impl ParentBus for VmbusServerControl {
2154 async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2155 self.offer(request).await
2156 }
2157
2158 fn clone_bus(&self) -> Box<dyn ParentBus> {
2159 Box::new(self.clone())
2160 }
2161
2162 fn use_event(&self) -> bool {
2163 self.use_event
2164 }
2165}