1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8pub mod channels;
9pub mod hvsock;
10mod monitor;
11mod proxyintegration;
12#[cfg(test)]
13mod tests;
14
/// GUID type used by this crate's public API (an alias of [`guid::Guid`]).
pub type Guid = guid::Guid;
17
18use anyhow::Context;
19use async_trait::async_trait;
20use channel_bitmap::ChannelBitmap;
21use channels::ConnectionTarget;
22pub use channels::InitiateContactRequest;
23use channels::MessageTarget;
24pub use channels::MnfUsage;
25use channels::ModifyConnectionRequest;
26use channels::ModifyConnectionResponse;
27use channels::Notifier;
28use channels::OfferId;
29pub use channels::OfferParamsInternal;
30use channels::OpenParams;
31use channels::RestoreError;
32pub use channels::Update;
33use futures::FutureExt;
34use futures::StreamExt;
35use futures::channel::mpsc;
36use futures::channel::mpsc::SendError;
37use futures::future::OptionFuture;
38use futures::future::poll_fn;
39use futures::stream::SelectAll;
40use guestmem::GuestMemory;
41use hvdef::Vtl;
42use inspect::Inspect;
43use mesh::payload::Protobuf;
44use mesh::rpc::FailableRpc;
45use mesh::rpc::Rpc;
46use mesh::rpc::RpcError;
47use mesh::rpc::RpcSend;
48use pal_async::driver::Driver;
49use pal_async::driver::SpawnDriver;
50use pal_async::task::Task;
51use pal_async::timer::PolledTimer;
52use pal_event::Event;
53#[cfg(windows)]
54pub use proxyintegration::ProxyIntegration;
55#[cfg(windows)]
56pub use proxyintegration::ProxyServerInfo;
57use ring::PAGE_SIZE;
58use std::collections::HashMap;
59use std::future;
60use std::future::Future;
61use std::pin::Pin;
62use std::sync::Arc;
63use std::task::Poll;
64use std::task::ready;
65use std::time::Duration;
66use unicycle::FuturesUnordered;
67use vmbus_channel::bus::ChannelRequest;
68use vmbus_channel::bus::ChannelServerRequest;
69use vmbus_channel::bus::GpadlRequest;
70use vmbus_channel::bus::ModifyRequest;
71use vmbus_channel::bus::OfferInput;
72use vmbus_channel::bus::OfferKey;
73use vmbus_channel::bus::OfferResources;
74use vmbus_channel::bus::OpenData;
75use vmbus_channel::bus::OpenRequest;
76use vmbus_channel::bus::ParentBus;
77use vmbus_channel::bus::RestoreResult;
78use vmbus_channel::gpadl::GpadlMap;
79use vmbus_channel::gpadl_ring::AlignedGpadlView;
80use vmbus_core::HvsockConnectRequest;
81use vmbus_core::HvsockConnectResult;
82use vmbus_core::MaxVersionInfo;
83use vmbus_core::OutgoingMessage;
84use vmbus_core::TaggedStream;
85use vmbus_core::VMBUS_SINT;
86use vmbus_core::VersionInfo;
87use vmbus_core::protocol;
88pub use vmbus_core::protocol::GpadlId;
89#[cfg(windows)]
90use vmbus_proxy::ProxyHandle;
91use vmbus_ring as ring;
92use vmbus_ring::gparange::MultiPagedRangeBuf;
93use vmcore::interrupt::Interrupt;
94use vmcore::save_restore::SavedStateRoot;
95use vmcore::synic::EventPort;
96use vmcore::synic::GuestEventPort;
97use vmcore::synic::GuestMessagePort;
98use vmcore::synic::MessagePort;
99use vmcore::synic::MonitorPageGpas;
100use vmcore::synic::SynicPortAccess;
101
/// SINT on which guest messages are received when message redirection is enabled.
pub const REDIRECT_SINT: u8 = 7;
/// VTL on which guest messages are received when message redirection is enabled.
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
// NOTE(review): not referenced in the visible part of this file; presumably the connection ID
// of a shared (not per-channel) event port — confirm against its users.
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
/// Low bits OR-ed into the port ID built by `VmbusServer::get_child_event_port_id`.
const EVENT_PORT_ID: u32 = 2;
/// Message type passed to `poll_post_message` when posting VMBus messages to the guest.
const VMBUS_MESSAGE_TYPE: u32 = 1;

/// Maximum number of outstanding hvsock connect requests. While this many are pending, the
/// server task stops pulling new synic messages.
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
109
/// Handle to a VMBus server whose state lives in a spawned background task.
///
/// Operations (`save`, `restore`, `start`, `stop`, `reset`, inspection) are sent to the task
/// as [`VmbusRequest`]s.
#[derive(Inspect)]
pub struct VmbusServer {
    // Inspection is forwarded to the server task as `VmbusRequest::Inspect`.
    #[inspect(flatten, send = "VmbusRequest::Inspect")]
    task_send: mesh::Sender<VmbusRequest>,
    #[inspect(skip)]
    control: Arc<VmbusServerControl>,
    // Object returned by `add_message_port`; held for the server's lifetime (presumably
    // dropping it unregisters the port — confirm).
    #[inspect(skip)]
    _message_port: Box<dyn Sync + Send>,
    // Only `Some` when built for VTL0 without message redirection.
    #[inspect(skip)]
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    // The spawned server task; it yields the `ServerTask` back when it exits.
    #[inspect(skip)]
    task: Task<ServerTask>,
}
123
/// Builder for a [`VmbusServer`]. Create one with [`VmbusServer::builder`] and finish it with
/// `build`.
pub struct VmbusServerBuilder<T: SpawnDriver> {
    // Spawns the server task; a clone also serves as the task's driver (used for timers).
    spawner: T,
    synic: Arc<dyn SynicPortAccess>,
    gm: GuestMemory,
    private_gm: Option<GuestMemory>,
    vtl: Vtl,
    // `None` means hvsock connect requests are answered with a failure result.
    hvsock_notify: Option<HvsockServerChannelHalf>,
    // `None` means connection modifications are answered locally (see `build`).
    server_relay: Option<VmbusServerChannelHalf>,
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    use_message_redirect: bool,
    channel_id_offset: u16,
    max_version: Option<MaxVersionInfo>,
    delay_max_version: bool,
    enable_mnf: bool,
    force_confidential_external_memory: bool,
    send_messages_while_stopped: bool,
    // `None` disables the delayed unstick timer queued after a successful channel open.
    channel_unstick_delay: Option<Duration>,
    use_absolute_channel_order: bool,
}
145
/// Notification sent on the optional saved-state channel (the server logs the receiver as a
/// "proxy").
#[derive(mesh::MeshPayload)]
pub enum SavedStateRequest {
    /// Carries a copy of the channel saved state. It is sent during restore, before the server
    /// itself is restored; a failure aborts the restore.
    Set(FailableRpc<Box<channels::SavedState>, ()>),
    /// Sent on the first start after a restore.
    Clear(Rpc<(), ()>),
}
152
/// The server's end of a [`RelayChannel`]: sends requests to the relay and receives its
/// responses.
pub struct ServerChannelHalf<Request, Response> {
    request_send: mesh::Sender<Request>,
    response_receive: mesh::Receiver<Response>,
}
158
/// The relay's end of a [`RelayChannel`]: receives the server's requests and sends back
/// responses.
pub struct RelayChannelHalf<Request, Response> {
    /// Requests sent by the server.
    pub request_receive: mesh::Receiver<Request>,
    /// Responses returned to the server.
    pub response_send: mesh::Sender<Response>,
}
164
/// A connected request/response channel between the VMBus server and a relay.
pub struct RelayChannel<Request, Response> {
    /// The half given to the relay.
    pub relay_half: RelayChannelHalf<Request, Response>,
    /// The half given to the server (e.g. via the builder's `server_relay`/`hvsock_notify`).
    pub server_half: ServerChannelHalf<Request, Response>,
}
170
171impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
172 pub fn new() -> Self {
174 let (request_send, request_receive) = mesh::channel();
175 let (response_send, response_receive) = mesh::channel();
176 Self {
177 relay_half: RelayChannelHalf {
178 request_receive,
179 response_send,
180 },
181 server_half: ServerChannelHalf {
182 request_send,
183 response_receive,
184 },
185 }
186 }
187}
188
/// Server half of the channel that forwards connection modifications to a relay.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Relay half of the channel that receives connection modifications from the server.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyRelayResponse>;
/// Channel pair for forwarding connection modifications to a relay.
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyRelayResponse>;
/// Server half of the channel that forwards hvsock connect requests.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay half of the channel that receives hvsock connect requests.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Channel pair for forwarding hvsock connect requests.
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
195
/// A connection modification forwarded to the relay; built from a `ModifyConnectionRequest`
/// by the `From` impl in this file.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// Numeric protocol version, when the request negotiates a version.
    pub version: Option<u32>,
    /// Requested change to the monitor page GPAs.
    pub monitor_page: Update<MonitorPageGpas>,
    /// `None`: unchanged. `Some(true)`: an interrupt page was set. `Some(false)`: it was reset.
    pub use_interrupt_page: Option<bool>,
}
210
/// The relay's answer to a [`ModifyRelayRequest`].
#[derive(Debug, Copy, Clone)]
pub enum ModifyRelayResponse {
    /// Answer to a request carrying a version: the connection state and the feature flags.
    Supported(protocol::ConnectionState, protocol::FeatureFlags),
    /// NOTE(review): presumably the requested version is not supported by the relay — confirm.
    Unsupported,
    /// Answer to a request without a version: the resulting connection state.
    Modified(protocol::ConnectionState),
}
225
226impl From<ModifyConnectionRequest> for ModifyRelayRequest {
227 fn from(value: ModifyConnectionRequest) -> Self {
228 Self {
229 version: value.version.map(|v| v.version as u32),
230 monitor_page: value.monitor_page,
231 use_interrupt_page: match value.interrupt_page {
232 Update::Unchanged => None,
233 Update::Reset => Some(false),
234 Update::Set(_) => Some(true),
235 },
236 }
237 }
238}
239
/// Requests sent from the [`VmbusServer`] handle to the server task.
#[derive(Debug)]
enum VmbusRequest {
    /// Reset channel state; the RPC is queued in `reset_done` until the reset completes.
    Reset(Rpc<(), ()>),
    /// Inspect the server task's state.
    Inspect(inspect::Deferred),
    /// Produce a [`SavedState`].
    Save(Rpc<(), SavedState>),
    /// Restore from a [`SavedState`].
    Restore(Rpc<Box<SavedState>, Result<(), RestoreError>>),
    /// Mark the task as running (fire-and-forget).
    Start,
    /// Mark the task as not running.
    Stop(Rpc<(), ()>),
}
249
/// Everything needed to offer a channel to the guest (consumed by `ServerTask::handle_offer`).
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    /// Offer parameters passed to the channel state machine.
    pub params: OfferParamsInternal,
    /// Interrupt delivered when the guest signals the channel.
    pub event: Interrupt,
    /// Sender for requests from the server to the channel's device.
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receiver for server requests (restore/revoke) from the channel's device.
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
257
/// Requests arriving on the server task's offer channel.
#[expect(clippy::large_enum_variant)]
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offer a new channel.
    Offer(FailableRpc<OfferInfo, ()>),
    /// Reset the server (handled the same way as `VmbusRequest::Reset`).
    ForceReset(Rpc<(), ()>),
}
264
/// Guest-to-host event port for a channel that delivers the wrapped interrupt when signaled.
struct ChannelEvent(Interrupt);
266
267impl EventPort for ChannelEvent {
268 fn handle_event(&self, _flag: u16) {
269 self.0.deliver();
270 }
271
272 fn os_event(&self) -> Option<&Event> {
273 self.0.event()
274 }
275}
276
/// Saved state of the VMBus server.
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    /// Saved state of the channel state machine.
    #[mesh(1)]
    pub server: channels::SavedState,
    /// Always `true` when saved by this code. If a restored state has `false`, channels are
    /// unstuck on the next start. NOTE(review): presumably states saved before this field
    /// existed decode as `false`.
    #[mesh(2)]
    pub lost_synic_bug_fixed: bool,
}
288
/// Message connection ID used for VTL0 when message redirection is disabled.
const MESSAGE_CONNECTION_ID: u32 = 1;
/// Connection ID of the additional multiclient port. It is also the low bits of child
/// connection IDs (see `VmbusServer::get_child_message_connection_id`).
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
291
292impl<T: SpawnDriver + Clone> VmbusServerBuilder<T> {
293 pub fn new(spawner: T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
295 Self {
296 spawner,
297 synic,
298 gm,
299 private_gm: None,
300 vtl: Vtl::Vtl0,
301 hvsock_notify: None,
302 server_relay: None,
303 saved_state_notify: None,
304 external_server: None,
305 external_requests: None,
306 use_message_redirect: false,
307 channel_id_offset: 0,
308 max_version: None,
309 delay_max_version: false,
310 enable_mnf: false,
311 force_confidential_external_memory: false,
312 send_messages_while_stopped: false,
313 channel_unstick_delay: Some(Duration::from_millis(100)),
314 use_absolute_channel_order: false,
315 }
316 }
317
318 pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
322 self.private_gm = private_gm;
323 self
324 }
325
326 pub fn vtl(mut self, vtl: Vtl) -> Self {
328 self.vtl = vtl;
329 self
330 }
331
332 pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
334 self.hvsock_notify = hvsock_notify;
335 self
336 }
337
338 pub fn saved_state_notify(
340 mut self,
341 saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
342 ) -> Self {
343 self.saved_state_notify = saved_state_notify;
344 self
345 }
346
347 pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
350 self.server_relay = server_relay;
351 self
352 }
353
354 pub fn external_requests(
356 mut self,
357 external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
358 ) -> Self {
359 self.external_requests = external_requests;
360 self
361 }
362
363 pub fn external_server(
366 mut self,
367 external_server: Option<mesh::Sender<InitiateContactRequest>>,
368 ) -> Self {
369 self.external_server = external_server;
370 self
371 }
372
373 pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
375 self.use_message_redirect = use_message_redirect;
376 self
377 }
378
379 pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
384 self.channel_id_offset = if enable { 1024 } else { 0 };
385 self
386 }
387
388 pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
392 self.max_version = max_version;
393 self
394 }
395
396 pub fn delay_max_version(mut self, delay: bool) -> Self {
401 self.delay_max_version = delay;
402 self
403 }
404
405 pub fn enable_mnf(mut self, enable: bool) -> Self {
409 self.enable_mnf = enable;
410 self
411 }
412
413 pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
416 self.force_confidential_external_memory = force;
417 self
418 }
419
420 pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
427 self.send_messages_while_stopped = send;
428 self
429 }
430
431 pub fn channel_unstick_delay(mut self, delay: Option<Duration>) -> Self {
438 self.channel_unstick_delay = delay;
439 self
440 }
441
442 pub fn use_absolute_channel_order(mut self, assign: bool) -> Self {
446 self.use_absolute_channel_order = assign;
447 self
448 }
449
450 pub fn build(self) -> anyhow::Result<VmbusServer> {
455 #[expect(clippy::disallowed_methods)] let (message_send, message_recv) = mpsc::channel(64);
457 let message_sender = Arc::new(MessageSender {
458 send: message_send.clone(),
459 multiclient: self.use_message_redirect,
460 });
461
462 let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
463 (REDIRECT_VTL, REDIRECT_SINT)
464 } else {
465 (self.vtl, VMBUS_SINT)
466 };
467
468 let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
471 MESSAGE_CONNECTION_ID
472 } else {
473 VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
476 };
477
478 let _message_port = self
479 .synic
480 .add_message_port(connection_id, redirect_vtl, message_sender)
481 .context("failed to create vmbus synic ports")?;
482
483 let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
487 let multiclient_message_sender = Arc::new(MessageSender {
488 send: message_send,
489 multiclient: true,
490 });
491
492 Some(
493 self.synic
494 .add_message_port(
495 MULTICLIENT_MESSAGE_CONNECTION_ID,
496 self.vtl,
497 multiclient_message_sender,
498 )
499 .context("failed to create vmbus synic ports")?,
500 )
501 } else {
502 None
503 };
504
505 let (offer_send, offer_recv) = mesh::mpsc_channel();
506 let control = Arc::new(VmbusServerControl {
507 mem: self.gm.clone(),
508 private_mem: self.private_gm.clone(),
509 send: offer_send,
510 use_event: self.synic.prefer_os_events(),
511 force_confidential_external_memory: self.force_confidential_external_memory,
512 });
513
514 let mut server = channels::Server::new(
515 self.vtl,
516 connection_id,
517 self.channel_id_offset,
518 self.use_absolute_channel_order,
519 );
520
521 server.set_require_server_allocated_mnf(self.enable_mnf && self.private_gm.is_some());
526
527 if let Some(version) = self.max_version {
529 server.set_compatibility_version(version, self.delay_max_version);
530 }
531 let (relay_request_send, relay_response_recv) =
532 if let Some(server_relay) = self.server_relay {
533 let r = server_relay.response_receive.boxed().fuse();
534 (server_relay.request_send, r)
535 } else {
536 let (req_send, req_recv) = mesh::channel();
537 let resp_recv = req_recv
538 .map(|req: ModifyRelayRequest| {
539 if req.version.is_some() {
541 ModifyRelayResponse::Supported(
542 protocol::ConnectionState::SUCCESSFUL,
543 protocol::FeatureFlags::from_bits(u32::MAX),
544 )
545 } else {
546 ModifyRelayResponse::Modified(protocol::ConnectionState::SUCCESSFUL)
547 }
548 })
549 .boxed()
550 .fuse();
551 (req_send, resp_recv)
552 };
553
554 let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
556 let r = hvsock_notify.response_receive.boxed().fuse();
557 (hvsock_notify.request_send, r)
558 } else {
559 let (req_send, req_recv) = mesh::channel();
560 let resp_recv = req_recv
561 .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
562 .boxed()
563 .fuse();
564 (req_send, resp_recv)
565 };
566
567 let inner = ServerTaskInner {
568 running: false,
569 send_messages_while_stopped: self.send_messages_while_stopped,
570 gm: self.gm,
571 private_gm: self.private_gm,
572 vtl: self.vtl,
573 redirect_vtl,
574 redirect_sint,
575 message_port: self
576 .synic
577 .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
578 synic: self.synic,
579 hvsock_requests: 0,
580 hvsock_send,
581 saved_state_notify: self.saved_state_notify,
582 channels: HashMap::new(),
583 channel_responses: FuturesUnordered::new(),
584 relay_send: relay_request_send,
585 external_server_send: self.external_server,
586 channel_bitmap: None,
587 shared_event_port: None,
588 reset_done: Vec::new(),
589 mnf_support: self.enable_mnf.then(MnfSupport::default),
590 };
591
592 let (task_send, task_recv) = mesh::channel();
593 let mut server_task = ServerTask {
594 driver: Box::new(self.spawner.clone()),
595 server,
596 task_recv,
597 offer_recv,
598 message_recv,
599 server_request_recv: SelectAll::new(),
600 inner,
601 external_requests: self.external_requests,
602 next_seq: 0,
603 perform_post_restore_on_start: false,
604 unstick_on_start: false,
605 channel_unstickers: FuturesUnordered::new(),
606 channel_unstick_delay: self.channel_unstick_delay,
607 };
608
609 let task = self.spawner.spawn("vmbus server", async move {
610 server_task.run(relay_response_recv, hvsock_recv).await;
611 server_task
612 });
613
614 Ok(VmbusServer {
615 task_send,
616 control,
617 _message_port,
618 _multiclient_message_port,
619 task,
620 })
621 }
622}
623
624impl VmbusServer {
625 pub fn builder<T: SpawnDriver + Clone>(
627 spawner: T,
628 synic: Arc<dyn SynicPortAccess>,
629 gm: GuestMemory,
630 ) -> VmbusServerBuilder<T> {
631 VmbusServerBuilder::new(spawner, synic, gm)
632 }
633
634 pub async fn save(&self) -> SavedState {
635 self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
636 }
637
638 pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
639 self.task_send
640 .call(VmbusRequest::Restore, Box::new(state))
641 .await
642 .unwrap()
643 }
644
645 pub async fn stop(&self) {
647 self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
648 }
649
650 pub fn start(&self) {
652 self.task_send.send(VmbusRequest::Start);
653 }
654
655 pub async fn reset(&self) {
657 tracing::debug!("resetting channel state");
658 self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
659 }
660
661 pub async fn shutdown(self) {
663 drop(self.task_send);
664 let _ = self.task.await;
665 }
666
667 pub fn control(&self) -> Arc<VmbusServerControl> {
669 self.control.clone()
670 }
671
672 fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
675 MULTICLIENT_MESSAGE_CONNECTION_ID
676 | (vtl as u32) << 22
677 | vp_index << 8
678 | (sint_index as u32) << 4
679 }
680
681 fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
682 EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
683 }
684}
685
// NOTE(review): not referenced in the visible part of this file. It appears to carry
// per-channel restore data (open data, GPADLs, interrupt); confirm against its users.
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    open_data: Option<OpenData>,
    // Presumably (GPADL ID, range count, range buffer) — confirm.
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    interrupt: Option<Interrupt>,
}
691}
692
693#[derive(Default)]
694pub struct SynicMessage {
695 data: Vec<u8>,
696 multiclient: bool,
697 trusted: bool,
698}
699
700#[derive(Default)]
702struct MnfSupport {
703 allocated_monitor_page: Option<MonitorPageGpas>,
704}
705
/// Identifies one specific offer of a channel.
///
/// `seq` distinguishes successive offers that reuse the same `OfferId`, so stale revokes,
/// responses and timers can be ignored.
#[derive(Debug, Clone, Copy)]
struct OfferInstanceId {
    offer_id: OfferId,
    seq: u64,
}
712
/// State owned by the spawned VMBus server task.
struct ServerTask {
    // Used to create timers for the channel unstickers.
    driver: Box<dyn Driver>,
    // The channel/connection state machine.
    server: channels::Server,
    // Requests from the `VmbusServer` handle; the run loop exits when this closes.
    task_recv: mesh::Receiver<VmbusRequest>,
    // Offer and force-reset requests (the sending side lives in `VmbusServerControl`).
    offer_recv: mesh::Receiver<OfferRequest>,
    // Messages received on the synic message port(s).
    message_recv: mpsc::Receiver<SynicMessage>,
    // Per-channel server requests, tagged with the offer instance they belong to.
    server_request_recv:
        SelectAll<TaggedStream<OfferInstanceId, mesh::Receiver<ChannelServerRequest>>>,
    inner: ServerTaskInner,
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    // Next `seq` value handed to a newly offered channel.
    next_seq: u64,
    // Set by restore. On the next start: clear the proxy saved state and revoke unclaimed channels.
    perform_post_restore_on_start: bool,
    // Set by restore when the saved state predates the lost-synic fix.
    unstick_on_start: bool,
    // Pending delayed-unstick timers queued by `handle_open`.
    channel_unstickers: FuturesUnordered<Pin<Box<dyn Send + Future<Output = OfferInstanceId>>>>,
    channel_unstick_delay: Option<Duration>,
}
730
/// The part of the server task's state that is passed to the channel state machine via
/// `with_notifier`.
struct ServerTaskInner {
    // Toggled by `VmbusRequest::Start`/`Stop`.
    running: bool,
    send_messages_while_stopped: bool,
    gm: GuestMemory,
    private_gm: Option<GuestMemory>,
    synic: Arc<dyn SynicPortAccess>,
    vtl: Vtl,
    redirect_vtl: Vtl,
    redirect_sint: u8,
    // Port used to post VMBus messages to the guest.
    message_port: Box<dyn GuestMessagePort>,
    // Outstanding hvsock connect requests, capped by `MAX_CONCURRENT_HVSOCK_REQUESTS`.
    hvsock_requests: usize,
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    channels: HashMap<OfferId, Channel>,
    // Pending device responses, tagged with the offer ID and seq they were issued for.
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    relay_send: mesh::Sender<ModifyRelayRequest>,
    // NOTE(review): set and used outside the visible code; starts as `None`.
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    // NOTE(review): set and used outside the visible code; starts as `None`.
    shared_event_port: Option<Box<dyn Send>>,
    // Callers waiting for a reset. Non-empty means a reset is in progress.
    reset_done: Vec<Rpc<(), ()>>,
    // `Some` only when MNF was enabled on the builder.
    mnf_support: Option<MnfSupport>,
}
757
/// A response from a channel's device, dispatched by `ServerTask::handle_response`.
#[derive(Debug)]
enum ChannelResponse {
    /// Result of an open request (`true` on success).
    Open(bool),
    /// The channel finished closing.
    Close,
    /// Result of creating the given GPADL (`true` on success).
    Gpadl(GpadlId, bool),
    /// Teardown of the given GPADL completed.
    TeardownGpadl(GpadlId),
    /// Status code of a modify-channel request.
    Modify(i32),
}
766
/// Whether a delayed unstick timer is pending for a channel (managed by `handle_open`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ChannelUnstickState {
    /// No timer is pending.
    None,
    /// A timer has been queued.
    Queued,
    /// Another successful open happened while a timer was pending. NOTE(review): presumably
    /// the timer is re-queued when it fires.
    NeedsRequeue,
}
773
/// Server-task bookkeeping for one offered channel.
struct Channel {
    key: OfferKey,
    // Requests to the channel's device.
    send: mesh::Sender<ChannelRequest>,
    // Offer instance sequence number (see `OfferInstanceId`).
    seq: u64,
    state: ChannelState,
    gpadls: Arc<GpadlMap>,
    guest_to_host_event: Arc<ChannelEvent>,
    flags: protocol::OfferFlags,
    reserved_state: ReservedState,
    unstick_state: ChannelUnstickState,
}
789
// NOTE(review): not used in the visible code beyond initialization. Presumably the message
// port and target used for reserved channels — confirm.
struct ReservedState {
    message_port: Option<Box<dyn GuestMessagePort>>,
    target: ConnectionTarget,
}
794
/// Resources belonging to an open channel.
struct ChannelOpenState {
    open_params: OpenParams,
    // Kept only so the registered event port lives as long as the open channel.
    _event_port: Box<dyn Send>,
    // `None` when the channel interrupt is disabled.
    guest_event_port: Option<Box<dyn GuestEventPort>>,
    host_to_guest_interrupt: Interrupt,
}
801
802impl ChannelOpenState {
803 fn set_event_port_target_vp(&mut self, vp: u32) -> anyhow::Result<()> {
804 let Some(guest_event_port) = self.guest_event_port.as_mut() else {
805 anyhow::bail!("cannot set target VP if the channel interrupt is disabled");
806 };
807
808 guest_event_port.set_target_vp(vp)?;
809 Ok(())
810 }
811}
812
/// Lifecycle state of a channel as tracked by the server task.
enum ChannelState {
    Closed,
    Open(Box<ChannelOpenState>),
    // A close was requested. `handle_close` moves the channel to `Closed` when the device's
    // close response arrives.
    Closing,
}
818
819impl ServerTask {
820 fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
821 let key = info.params.key();
822 let flags = info.params.flags;
823
824 if self.inner.mnf_support.is_some() && self.inner.synic.monitor_support().is_some() {
825 if info.params.use_mnf.is_relayed() {
830 info.params.use_mnf = MnfUsage::Enabled {
831 latency: Duration::ZERO,
832 }
833 }
834 } else if info.params.use_mnf.is_enabled() {
835 info.params.use_mnf = MnfUsage::Disabled;
838 }
839
840 let offer_id = self
841 .server
842 .with_notifier(&mut self.inner)
843 .offer_channel(info.params)
844 .context("channel offer failed")?;
845
846 tracing::debug!(?offer_id, %key, "offered channel");
847
848 let seq = self.next_seq;
849 self.next_seq += 1;
850 self.inner.channels.insert(
851 offer_id,
852 Channel {
853 key,
854 send: info.request_send,
855 state: ChannelState::Closed,
856 gpadls: GpadlMap::new(),
857 guest_to_host_event: Arc::new(ChannelEvent(info.event)),
858 seq,
859 flags,
860 reserved_state: ReservedState {
861 message_port: None,
862 target: ConnectionTarget { vp: 0, sint: 0 },
863 },
864 unstick_state: ChannelUnstickState::None,
865 },
866 );
867
868 self.server_request_recv.push(TaggedStream::new(
869 OfferInstanceId { offer_id, seq },
870 info.server_request_recv,
871 ));
872
873 Ok(())
874 }
875
876 fn handle_revoke(&mut self, id: OfferInstanceId) {
877 if let Some(channel) = self.inner.channels.get(&id.offer_id) {
880 if channel.seq == id.seq {
881 tracing::info!(?id.offer_id, key = %channel.key, "revoking channel");
882 self.inner.channels.remove(&id.offer_id);
883 self.server
884 .with_notifier(&mut self.inner)
885 .revoke_channel(id.offer_id);
886 }
887 }
888 }
889
890 fn handle_response(
891 &mut self,
892 offer_id: OfferId,
893 seq: u64,
894 response: Result<ChannelResponse, RpcError>,
895 ) {
896 let channel = self
898 .inner
899 .channels
900 .get(&offer_id)
901 .filter(|channel| channel.seq == seq);
902
903 if let Some(channel) = channel {
904 match response {
905 Ok(response) => match response {
906 ChannelResponse::Open(result) => self.handle_open(offer_id, result),
907 ChannelResponse::Close => self.handle_close(offer_id),
908 ChannelResponse::Gpadl(gpadl_id, ok) => {
909 self.handle_gpadl_create(offer_id, gpadl_id, ok)
910 }
911 ChannelResponse::TeardownGpadl(gpadl_id) => {
912 self.handle_gpadl_teardown(offer_id, gpadl_id)
913 }
914 ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
915 },
916 Err(err) => {
917 tracing::error!(
918 key = %channel.key,
919 error = &err as &dyn std::error::Error,
920 "channel response failure, channel is in inconsistent state until revoked"
921 );
922 }
923 }
924 } else {
925 tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
926 }
927 }
928
929 fn handle_open(&mut self, offer_id: OfferId, success: bool) {
930 let status = if success {
931 let channel = self
932 .inner
933 .channels
934 .get_mut(&offer_id)
935 .expect("channel exists");
936
937 if let Some(delay) = self.channel_unstick_delay {
940 if channel.unstick_state == ChannelUnstickState::None {
941 channel.unstick_state = ChannelUnstickState::Queued;
942 let seq = channel.seq;
943 let mut timer = PolledTimer::new(&self.driver);
944 self.channel_unstickers.push(Box::pin(async move {
945 timer.sleep(delay).await;
946 OfferInstanceId { offer_id, seq }
947 }));
948 } else {
949 channel.unstick_state = ChannelUnstickState::NeedsRequeue;
950 }
951 }
952
953 0
954 } else {
955 protocol::STATUS_UNSUCCESSFUL
956 };
957
958 self.server
959 .with_notifier(&mut self.inner)
960 .open_complete(offer_id, status);
961 }
962
963 fn handle_close(&mut self, offer_id: OfferId) {
964 let channel = self
965 .inner
966 .channels
967 .get_mut(&offer_id)
968 .expect("channel still exists");
969
970 match &mut channel.state {
971 ChannelState::Closing => {
972 tracing::debug!(?offer_id, key = %channel.key, "closing channel");
973 channel.state = ChannelState::Closed;
974 self.server
975 .with_notifier(&mut self.inner)
976 .close_complete(offer_id);
977 }
978 _ => {
979 tracing::error!(?offer_id, key = %channel.key, "invalid close channel response");
980 }
981 };
982 }
983
984 fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
985 let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
986 self.server
987 .with_notifier(&mut self.inner)
988 .gpadl_create_complete(offer_id, gpadl_id, status);
989 }
990
991 fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
992 self.server
993 .with_notifier(&mut self.inner)
994 .gpadl_teardown_complete(offer_id, gpadl_id);
995 }
996
997 fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
998 self.server
999 .with_notifier(&mut self.inner)
1000 .modify_channel_complete(offer_id, status);
1001 }
1002
    /// Handles a `ChannelServerRequest::Restore` for `offer_id`.
    ///
    /// If `open` is true, the channel is re-opened with the open parameters from the restored
    /// server state, and the resulting `OpenRequest` is returned. The channel is then marked
    /// restored in the state machine, and its saved GPADLs are re-registered in the channel's
    /// `GpadlMap`.
    ///
    /// # Errors
    ///
    /// Fails if the open parameters are unavailable, the channel cannot be opened, or the
    /// state machine rejects the restore.
    ///
    /// # Panics
    ///
    /// Panics if re-opening while not connected, or if the channel is not tracked.
    fn handle_restore_channel(
        &mut self,
        offer_id: OfferId,
        open: bool,
    ) -> anyhow::Result<RestoreResult> {
        let gpadls = self.server.channel_gpadls(offer_id);

        // Only build an open request when the channel was open at save time.
        let open_request = open
            .then(|| -> anyhow::Result<_> {
                let params = self.server.get_restore_open_params(offer_id)?;
                let (channel, interrupt) = self.inner.open_channel(offer_id, &params)?;
                Ok(OpenRequest::new(
                    params.open_data,
                    interrupt,
                    self.server
                        .get_version()
                        .expect("must be connected")
                        .feature_flags,
                    channel.flags,
                ))
            })
            .transpose()?;

        self.server
            .with_notifier(&mut self.inner)
            .restore_channel(offer_id, open_request.is_some())?;

        let channel = self.inner.channels.get_mut(&offer_id).unwrap();
        for gpadl in &gpadls {
            // NOTE(review): GPADLs whose buffers fail to parse are skipped silently here, but
            // are still included in the returned `gpadls` — confirm this is intended.
            if let Ok(buf) = MultiPagedRangeBuf::from_range_buffer(
                gpadl.request.count.into(),
                gpadl.request.buf.clone(),
            ) {
                channel.gpadls.add(gpadl.request.id, buf);
            }
        }

        let result = RestoreResult {
            open_request,
            gpadls,
        };
        Ok(result)
    }
1048
    /// Handles one request from the [`VmbusServer`] handle.
    async fn handle_request(&mut self, request: VmbusRequest) {
        tracing::debug!(?request, "handle_request");
        match request {
            VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
            VmbusRequest::Inspect(deferred) => {
                // `unstick_channels` is writable: "force" calls `unstick_channels(true)`,
                // "true" calls `unstick_channels(false)`, and "false" does nothing. It always
                // reads as `false`.
                deferred.respond(|resp| {
                    resp.field("message_port", &self.inner.message_port)
                        .field("running", self.inner.running)
                        .field("hvsock_requests", self.inner.hvsock_requests)
                        .field("channel_unstick_delay", self.channel_unstick_delay)
                        .field_mut_with("unstick_channels", |v| {
                            let v: inspect::ValueKind = if let Some(v) = v {
                                if v == "force" {
                                    self.unstick_channels(true);
                                    v.into()
                                } else {
                                    let v =
                                        v.parse().ok().context("expected false, true, or force")?;
                                    if v {
                                        self.unstick_channels(false);
                                    }
                                    v.into()
                                }
                            } else {
                                false.into()
                            };
                            anyhow::Ok(v)
                        })
                        .merge(&self.server.with_notifier(&mut self.inner));
                });
            }
            VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
                server: self.server.save(),
                lost_synic_bug_fixed: true,
            }),
            VmbusRequest::Restore(rpc) => {
                rpc.handle(async |state| {
                    // Older saved states need the channel-unstick mitigation on the next start.
                    self.unstick_on_start = !state.lost_synic_bug_fixed;
                    self.perform_post_restore_on_start = true;
                    // Hand a copy of the saved state to the proxy before restoring; failure
                    // aborts the restore.
                    if let Some(sender) = &self.inner.saved_state_notify {
                        tracing::trace!("sending saved state to proxy");
                        if let Err(err) = sender
                            .call_failable(SavedStateRequest::Set, Box::new(state.server.clone()))
                            .await
                        {
                            tracing::error!(
                                err = &err as &dyn std::error::Error,
                                "failed to restore proxy saved state"
                            );
                            return Err(RestoreError::ServerError(err.into()));
                        }
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .restore(state.server)
                })
                .await
            }
            // Stopping only clears the running flag; the run loop then stops servicing most
            // event sources.
            VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
                if self.inner.running {
                    self.inner.running = false;
                }
            }),
            VmbusRequest::Start => {
                if !self.inner.running {
                    self.inner.running = true;
                    // First start after a restore: clear the proxy's saved state, then revoke
                    // any channels that no device claimed during restore.
                    if self.perform_post_restore_on_start {
                        if let Some(sender) = self.inner.saved_state_notify.as_ref() {
                            tracing::trace!("sending clear saved state message to proxy");
                            sender
                                .call(SavedStateRequest::Clear, ())
                                .await
                                .expect("failed to clear proxy saved state");
                        }

                        self.server
                            .with_notifier(&mut self.inner)
                            .revoke_unclaimed_channels();

                        self.perform_post_restore_on_start = false;
                    }

                    if self.unstick_on_start {
                        tracing::info!(
                            "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
                        );
                        self.unstick_channels(false);
                        self.unstick_on_start = false;
                    }
                }
            }
        }
    }
1145
1146 fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
1147 let needs_reset = self.inner.reset_done.is_empty();
1148 self.inner.reset_done.push(rpc);
1149 if needs_reset {
1150 self.server.with_notifier(&mut self.inner).reset();
1151 }
1152 }
1153
1154 fn handle_relay_response(&mut self, response: ModifyRelayResponse) {
1155 let response = match response {
1157 ModifyRelayResponse::Supported(state, features) => {
1158 let allocated_monitor_gpas = self
1161 .inner
1162 .mnf_support
1163 .as_ref()
1164 .and_then(|mnf| mnf.allocated_monitor_page);
1165
1166 ModifyConnectionResponse::Supported(state, features, allocated_monitor_gpas)
1167 }
1168 ModifyRelayResponse::Unsupported => ModifyConnectionResponse::Unsupported,
1169 ModifyRelayResponse::Modified(state) => ModifyConnectionResponse::Modified(state),
1170 };
1171
1172 self.server
1173 .with_notifier(&mut self.inner)
1174 .complete_modify_connection(response);
1175 }
1176
1177 fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
1178 tracing::debug!(?result, "hvsock connect result");
1179 assert_ne!(self.inner.hvsock_requests, 0);
1180 self.inner.hvsock_requests -= 1;
1181
1182 self.server
1183 .with_notifier(&mut self.inner)
1184 .send_tl_connect_result(result);
1185 }
1186
1187 fn handle_synic_message(&mut self, message: SynicMessage) {
1188 match self
1189 .server
1190 .with_notifier(&mut self.inner)
1191 .handle_synic_message(message)
1192 {
1193 Ok(()) => {}
1194 Err(err) => {
1195 tracing::warn!(
1196 error = &err as &dyn std::error::Error,
1197 "synic message error"
1198 );
1199 }
1200 }
1201 }
1202
1203 fn handle_external_request(&mut self, request: InitiateContactRequest) {
1210 self.server
1211 .with_notifier(&mut self.inner)
1212 .initiate_contact(request);
1213 }
1214
    /// Main loop of the server task.
    ///
    /// Each iteration builds the set of event sources that are currently allowed to make
    /// progress, waits for the first one to become ready, and dispatches it to its handler.
    /// Several sources are only polled while the server is running and no reset is in
    /// flight. The loop exits when the task-control channel (`task_recv`) closes or when
    /// every source is exhausted (`complete`).
    async fn run(
        &mut self,
        mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyRelayResponse> + Unpin,
        mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
    ) {
        loop {
            // `reset_done` holds the completions that `reset_complete` drains, so a non-empty
            // list means a reset has been requested but has not finished yet.
            let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
            // External InitiateContact requests, only when an external stream is configured.
            let mut external_requests = OptionFuture::from(
                running_not_resetting
                    .then(|| {
                        self.external_requests
                            .as_mut()
                            .map(|r| r.select_next_some())
                    })
                    .flatten(),
            );

            let has_pending_messages = self.server.has_pending_messages();
            let message_port = self.inner.message_port.as_mut();
            // Retry posting outgoing messages that the server previously could not send.
            let mut flush_pending_messages =
                OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
                    poll_fn(|cx| {
                        self.server.poll_flush_pending_messages(|msg| {
                            message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
                        })
                    })
                    .fuse()
                }));

            // New guest messages are only accepted once the outgoing backlog is flushed and
            // fewer than MAX_CONCURRENT_HVSOCK_REQUESTS hvsock requests are outstanding
            // (presumably to apply backpressure to the guest).
            let mut message_recv = OptionFuture::from(
                (running_not_resetting
                    && !has_pending_messages
                    && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
                    .then(|| self.message_recv.select_next_some()),
            );

            // Device responses are also drained while a reset is pending, even when not
            // running (presumably so outstanding device work can finish and the reset can
            // complete).
            let mut channel_response = OptionFuture::from(
                (self.inner.running || !self.inner.reset_done.is_empty())
                    .then(|| self.inner.channel_responses.select_next_some()),
            );

            let mut hvsock_response =
                OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));

            // Expired unstick timers queued by `unstick_channel_by_id`.
            let mut channel_unstickers = OptionFuture::from(
                running_not_resetting.then(|| self.channel_unstickers.select_next_some()),
            );

            // A closed task-control channel ends the task.
            futures::select! { r = self.task_recv.recv().fuse() => {
                    if let Ok(request) = r {
                        self.handle_request(request).await;
                    } else {
                        break;
                    }
                }
                r = self.offer_recv.select_next_some() => {
                    match r {
                        OfferRequest::Offer(rpc) => {
                            rpc.handle_failable_sync(|request| { self.handle_offer(request) })
                        },
                        OfferRequest::ForceReset(rpc) => {
                            self.handle_reset(rpc);
                        }
                    }
                }
                r = self.server_request_recv.select_next_some() => {
                    match r {
                        (id, Some(request)) => match request {
                            ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
                                self.handle_restore_channel(id.offer_id, open)
                            }),
                            ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
                                self.handle_revoke(id);
                            })
                        },
                        // The device's request channel closed; treat it as a revoke.
                        (id, None) => self.handle_revoke(id),
                    }
                }
                // The OptionFuture arms below only complete when their source was enabled
                // above, so `unwrap` on their output does not fail.
                r = channel_response => {
                    let (id, seq, response) = r.unwrap();
                    self.handle_response(id, seq, response);
                }
                r = relay_response_recv.select_next_some() => {
                    self.handle_relay_response(r);
                },
                r = hvsock_response => {
                    self.handle_tl_connect_result(r.unwrap());
                }
                data = message_recv => {
                    let data = data.unwrap();
                    self.handle_synic_message(data);
                }
                r = external_requests => {
                    let r = r.unwrap();
                    self.handle_external_request(r);
                }
                r = channel_unstickers => {
                    self.unstick_channel_by_id(r.unwrap());
                }
                _r = flush_pending_messages => {}
                complete => break,
            }
        }
    }
1330
1331 fn unstick_channels(&self, force: bool) {
1335 let Some(version) = self.server.get_version() else {
1336 tracing::warn!("cannot unstick when not connected");
1337 return;
1338 };
1339
1340 for channel in self.inner.channels.values() {
1341 let gm = self.inner.get_gm_for_channel(version, channel);
1342 if let Err(err) = Self::unstick_channel(gm, channel, force, true) {
1343 tracing::warn!(
1344 channel = %channel.key,
1345 error = err.as_ref() as &dyn std::error::Error,
1346 "could not unstick channel"
1347 );
1348 }
1349 }
1350 }
1351
1352 fn unstick_channel_by_id(&mut self, id: OfferInstanceId) {
1355 let Some(version) = self.server.get_version() else {
1356 tracelimit::warn_ratelimited!("cannot unstick when not connected");
1357 return;
1358 };
1359
1360 if let Some(channel) = self.inner.channels.get_mut(&id.offer_id) {
1361 if channel.seq != id.seq {
1362 return;
1364 }
1365
1366 if channel.unstick_state == ChannelUnstickState::NeedsRequeue {
1369 channel.unstick_state = ChannelUnstickState::Queued;
1370 let mut timer = PolledTimer::new(&self.driver);
1371 let delay = self.channel_unstick_delay.unwrap();
1372 self.channel_unstickers.push(Box::pin(async move {
1373 timer.sleep(delay).await;
1374 id
1375 }));
1376
1377 return;
1378 }
1379
1380 channel.unstick_state = ChannelUnstickState::None;
1381 let gm = select_gm_for_channel(
1382 &self.inner.gm,
1383 self.inner.private_gm.as_ref(),
1384 version,
1385 channel,
1386 );
1387 if let Err(err) = Self::unstick_channel(gm, channel, false, false) {
1388 tracelimit::warn_ratelimited!(
1389 channel = %channel.key,
1390 error = err.as_ref() as &dyn std::error::Error,
1391 "could not unstick channel"
1392 );
1393 }
1394 }
1395 }
1396
1397 fn unstick_channel(
1398 gm: &GuestMemory,
1399 channel: &Channel,
1400 force: bool,
1401 unstick_host: bool,
1402 ) -> anyhow::Result<()> {
1403 if let ChannelState::Open(state) = &channel.state {
1404 if force {
1405 tracing::info!(channel = %channel.key, "waking host and guest");
1406 if unstick_host {
1407 channel.guest_to_host_event.0.deliver();
1408 }
1409 state.host_to_guest_interrupt.deliver();
1410 return Ok(());
1411 }
1412
1413 let gpadl = channel
1414 .gpadls
1415 .clone()
1416 .view()
1417 .map(state.open_params.open_data.ring_gpadl_id)
1418 .context("couldn't find ring gpadl")?;
1419
1420 let aligned = AlignedGpadlView::new(gpadl)
1421 .ok()
1422 .context("ring not aligned")?;
1423 let (in_gpadl, out_gpadl) = aligned
1424 .split(state.open_params.open_data.ring_offset)
1425 .ok()
1426 .context("couldn't split ring")?;
1427
1428 if let Err(err) = Self::unstick_incoming_ring(
1429 gm,
1430 channel,
1431 in_gpadl,
1432 unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1433 &state.host_to_guest_interrupt,
1434 ) {
1435 tracelimit::warn_ratelimited!(
1436 channel = %channel.key,
1437 error = err.as_ref() as &dyn std::error::Error,
1438 "could not unstick incoming ring"
1439 );
1440 }
1441 if let Err(err) = Self::unstick_outgoing_ring(
1442 gm,
1443 channel,
1444 out_gpadl,
1445 unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1446 &state.host_to_guest_interrupt,
1447 ) {
1448 tracelimit::warn_ratelimited!(
1449 channel = %channel.key,
1450 error = err.as_ref() as &dyn std::error::Error,
1451 "could not unstick outgoing ring"
1452 );
1453 }
1454 }
1455 Ok(())
1456 }
1457
1458 fn unstick_incoming_ring(
1459 gm: &GuestMemory,
1460 channel: &Channel,
1461 in_gpadl: AlignedGpadlView,
1462 guest_to_host_event: Option<&ChannelEvent>,
1463 host_to_guest_interrupt: &Interrupt,
1464 ) -> anyhow::Result<()> {
1465 let control_page = lock_gpn_with_subrange(gm, in_gpadl.gpns()[0])?;
1466 if let Some(guest_to_host_event) = guest_to_host_event {
1467 if ring::reader_needs_signal(control_page.pages()[0]) {
1468 tracelimit::info_ratelimited!(channel = %channel.key, "waking host for incoming ring");
1469 guest_to_host_event.0.deliver();
1470 }
1471 }
1472
1473 let ring_size = gpadl_ring_size(&in_gpadl).try_into()?;
1474 if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1475 tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for incoming ring");
1476 host_to_guest_interrupt.deliver();
1477 }
1478 Ok(())
1479 }
1480
1481 fn unstick_outgoing_ring(
1482 gm: &GuestMemory,
1483 channel: &Channel,
1484 out_gpadl: AlignedGpadlView,
1485 guest_to_host_event: Option<&ChannelEvent>,
1486 host_to_guest_interrupt: &Interrupt,
1487 ) -> anyhow::Result<()> {
1488 let control_page = lock_gpn_with_subrange(gm, out_gpadl.gpns()[0])?;
1489 if ring::reader_needs_signal(control_page.pages()[0]) {
1490 tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for outgoing ring");
1491 host_to_guest_interrupt.deliver();
1492 }
1493
1494 if let Some(guest_to_host_event) = guest_to_host_event {
1495 let ring_size = gpadl_ring_size(&out_gpadl).try_into()?;
1496 if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1497 tracelimit::info_ratelimited!(channel = %channel.key, "waking host for outgoing ring");
1498 guest_to_host_event.0.deliver();
1499 }
1500 }
1501 Ok(())
1502 }
1503}
1504
1505impl Notifier for ServerTaskInner {
1506 fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
1507 let channel = self
1508 .channels
1509 .get_mut(&offer_id)
1510 .expect("channel does not exist");
1511
1512 fn handle<I: 'static + Send, R: 'static + Send>(
1513 offer_id: OfferId,
1514 channel: &Channel,
1515 req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
1516 input: I,
1517 f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
1518 ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
1519 {
1520 let recv = channel.send.call(req, input);
1521 let seq = channel.seq;
1522 Box::pin(async move {
1523 let r = recv.await.map(f);
1524 (offer_id, seq, r)
1525 })
1526 }
1527
1528 let response = match action {
1529 channels::Action::Open(open_params, version) => {
1530 let seq = channel.seq;
1531 let key = channel.key;
1532 match self.open_channel(offer_id, &open_params) {
1533 Ok((channel, interrupt)) => handle(
1534 offer_id,
1535 channel,
1536 ChannelRequest::Open,
1537 OpenRequest::new(
1538 open_params.open_data,
1539 interrupt,
1540 version.feature_flags,
1541 channel.flags,
1542 ),
1543 ChannelResponse::Open,
1544 ),
1545 Err(err) => {
1546 tracelimit::error_ratelimited!(
1547 err = err.as_ref() as &dyn std::error::Error,
1548 ?offer_id,
1549 %key,
1550 "could not open channel",
1551 );
1552
1553 Box::pin(future::ready((
1556 offer_id,
1557 seq,
1558 Ok(ChannelResponse::Open(false)),
1559 )))
1560 }
1561 }
1562 }
1563 channels::Action::Close => {
1564 if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1565 if let ChannelState::Open(ref state) = channel.state {
1566 channel_bitmap.unregister_channel(state.open_params.event_flag);
1567 }
1568 }
1569
1570 channel.state = ChannelState::Closing;
1571 handle(offer_id, channel, ChannelRequest::Close, (), |()| {
1572 ChannelResponse::Close
1573 })
1574 }
1575 channels::Action::Gpadl(gpadl_id, count, buf) => {
1576 channel.gpadls.add(
1577 gpadl_id,
1578 MultiPagedRangeBuf::from_range_buffer(count.into(), buf.clone()).unwrap(),
1579 );
1580 handle(
1581 offer_id,
1582 channel,
1583 ChannelRequest::Gpadl,
1584 GpadlRequest {
1585 id: gpadl_id,
1586 count,
1587 buf,
1588 },
1589 move |r| ChannelResponse::Gpadl(gpadl_id, r),
1590 )
1591 }
1592 channels::Action::TeardownGpadl {
1593 gpadl_id,
1594 post_restore,
1595 } => {
1596 if !post_restore {
1597 channel.gpadls.remove(gpadl_id, Box::new(|| ()));
1598 }
1599
1600 handle(
1601 offer_id,
1602 channel,
1603 ChannelRequest::TeardownGpadl,
1604 gpadl_id,
1605 move |()| ChannelResponse::TeardownGpadl(gpadl_id),
1606 )
1607 }
1608 channels::Action::Modify { target_vp } => {
1609 let ChannelState::Open(state) = &mut channel.state else {
1610 unreachable!();
1611 };
1612
1613 if let Err(err) = state.set_event_port_target_vp(target_vp) {
1614 tracelimit::error_ratelimited!(
1615 error = err.as_ref() as &dyn std::error::Error,
1616 channel = %channel.key,
1617 "could not modify channel",
1618 );
1619
1620 let seq = channel.seq;
1622 Box::pin(async move {
1623 (
1624 offer_id,
1625 seq,
1626 Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
1627 )
1628 })
1629 } else {
1630 handle(
1631 offer_id,
1632 channel,
1633 ChannelRequest::Modify,
1634 ModifyRequest::TargetVp { target_vp },
1635 ChannelResponse::Modify,
1636 )
1637 }
1638 }
1639 };
1640 self.channel_responses.push(response);
1641 }
1642
1643 fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
1644 self.map_interrupt_page(request.interrupt_page)
1645 .context("Failed to map interrupt page.")?;
1646
1647 self.set_monitor_page(&mut request)
1648 .context("Failed to map monitor page.")?;
1649
1650 if let Some(vp) = request.target_message_vp {
1651 self.message_port.set_target_vp(vp)?;
1652 }
1653
1654 if request.notify_relay {
1655 self.relay_send.send(request.into());
1656 }
1657
1658 Ok(())
1659 }
1660
1661 fn forward_unhandled(&mut self, request: InitiateContactRequest) {
1662 if let Some(external_server) = &self.external_server_send {
1663 external_server.send(request);
1664 } else {
1665 tracing::warn!(?request, "nowhere to forward unhandled request")
1666 }
1667 }
1668
1669 fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
1670 let channel = self.channels.get(&offer_id).expect("should exist");
1671 let mut resp = req.respond();
1672 if let ChannelState::Open(state) = &channel.state {
1673 let mem = self.get_gm_for_channel(version.expect("must be connected"), channel);
1674 inspect_rings(
1675 &mut resp,
1676 mem,
1677 channel.gpadls.clone(),
1678 &state.open_params.open_data,
1679 );
1680 }
1681 }
1682
1683 fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
1684 if !self.running && !self.send_messages_while_stopped {
1697 if !matches!(target, MessageTarget::Default) {
1698 tracelimit::error_ratelimited!(?target, "dropping message while paused");
1699 }
1700 return false;
1701 }
1702
1703 let mut port_storage;
1704 let port = match target {
1705 MessageTarget::Default => self.message_port.as_mut(),
1706 MessageTarget::ReservedChannel(offer_id, target) => {
1707 if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
1708 port.as_mut()
1709 } else {
1710 return true;
1712 }
1713 }
1714 MessageTarget::Custom(target) => {
1715 port_storage = match self.synic.new_guest_message_port(
1716 self.redirect_vtl,
1717 target.vp,
1718 target.sint,
1719 ) {
1720 Ok(port) => port,
1721 Err(err) => {
1722 tracing::error!(
1723 ?err,
1724 ?self.redirect_vtl,
1725 ?target,
1726 "could not create message port"
1727 );
1728
1729 return true;
1731 }
1732 };
1733 port_storage.as_mut()
1734 }
1735 };
1736
1737 matches!(
1740 port.poll_post_message(
1741 &mut std::task::Context::from_waker(std::task::Waker::noop()),
1742 VMBUS_MESSAGE_TYPE,
1743 message.data()
1744 ),
1745 Poll::Ready(())
1746 )
1747 }
1748
1749 fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
1750 tracing::debug!(?request, "received hvsock connect request");
1751 self.hvsock_requests += 1;
1752 self.hvsock_send.send(*request);
1753 }
1754
1755 fn reset_complete(&mut self) {
1756 if let Some(monitor) = self.synic.monitor_support() {
1757 if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
1758 tracing::warn!(?err, "resetting monitor page failed")
1759 }
1760 }
1761
1762 self.unreserve_channels();
1763 for done in self.reset_done.drain(..) {
1764 done.complete(());
1765 }
1766 }
1767
1768 fn unload_complete(&mut self) {
1769 self.unreserve_channels();
1770 }
1771}
1772
1773impl ServerTaskInner {
1774 fn open_channel(
1775 &mut self,
1776 offer_id: OfferId,
1777 open_params: &OpenParams,
1778 ) -> anyhow::Result<(&mut Channel, Interrupt)> {
1779 let channel = self
1780 .channels
1781 .get_mut(&offer_id)
1782 .expect("channel does not exist");
1783
1784 if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1786 channel_bitmap.register_channel(
1787 open_params.event_flag,
1788 channel.guest_to_host_event.0.clone(),
1789 );
1790 }
1791 let event_port = self
1796 .synic
1797 .add_event_port(
1798 open_params.connection_id,
1799 self.vtl,
1800 channel.guest_to_host_event.clone(),
1801 open_params.monitor_info,
1802 )
1803 .context("failed to create guest-to-host event port")?;
1804
1805 let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
1808 (Some(0), 0)
1809 } else {
1810 (open_params.open_data.target_vp, open_params.event_flag)
1811 };
1812
1813 let (guest_event_port, interrupt) = if let Some(target_vp) = target_vp {
1814 let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
1815 (self.redirect_vtl, self.redirect_sint)
1816 } else {
1817 (self.vtl, VMBUS_SINT)
1818 };
1819
1820 let guest_event_port = self.synic.new_guest_event_port(
1821 VmbusServer::get_child_event_port_id(open_params.channel_id, VMBUS_SINT, self.vtl),
1822 target_vtl,
1823 target_vp,
1824 target_sint,
1825 event_flag,
1826 open_params.monitor_info,
1827 )?;
1828
1829 let interrupt = ChannelBitmap::create_interrupt(
1830 &self.channel_bitmap,
1831 guest_event_port.interrupt(),
1832 open_params.event_flag,
1833 );
1834
1835 (Some(guest_event_port), interrupt)
1836 } else {
1837 (None, Interrupt::null_event())
1840 };
1841
1842 channel.reserved_state.message_port = None;
1844
1845 if let Some(target) = open_params.reserved_target {
1847 channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
1848 self.redirect_vtl,
1849 target.vp,
1850 target.sint,
1851 )?);
1852
1853 channel.reserved_state.target = target;
1854 }
1855
1856 channel.state = ChannelState::Open(Box::new(ChannelOpenState {
1857 open_params: *open_params,
1858 _event_port: event_port,
1859 guest_event_port,
1860 host_to_guest_interrupt: interrupt.clone(),
1861 }));
1862 Ok((channel, interrupt))
1863 }
1864
1865 fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
1868 let interrupt_page = match interrupt_page {
1869 Update::Unchanged => return Ok(()),
1870 Update::Reset => {
1871 self.channel_bitmap = None;
1872 self.shared_event_port = None;
1873 return Ok(());
1874 }
1875 Update::Set(interrupt_page) => interrupt_page,
1876 };
1877
1878 assert_ne!(interrupt_page, 0);
1879
1880 if interrupt_page % PAGE_SIZE as u64 != 0 {
1881 anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
1882 }
1883
1884 let interrupt_page = lock_page_with_subrange(&self.gm, interrupt_page)?;
1887 let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
1888 self.channel_bitmap = Some(channel_bitmap.clone());
1889
1890 let interrupt = Interrupt::from_fn(move || {
1892 channel_bitmap.handle_shared_interrupt();
1893 });
1894
1895 self.shared_event_port = Some(self.synic.add_event_port(
1896 SHARED_EVENT_CONNECTION_ID,
1897 self.vtl,
1898 Arc::new(ChannelEvent(interrupt)),
1899 None,
1900 )?);
1901
1902 Ok(())
1903 }
1904
1905 fn set_monitor_page(&mut self, request: &mut ModifyConnectionRequest) -> anyhow::Result<()> {
1906 let monitor_page = match request.monitor_page {
1907 Update::Unchanged => return Ok(()),
1908 Update::Reset => None,
1909 Update::Set(value) => Some(value),
1910 };
1911
1912 if self.channels.iter().any(|(_, c)| {
1914 matches!(
1915 &c.state,
1916 ChannelState::Open(state) if state.open_params.monitor_info.is_some()
1917 )
1918 }) {
1919 anyhow::bail!("attempt to change monitor page while open channels using mnf");
1920 }
1921
1922 if let Some(mnf_support) = self.mnf_support.as_mut() {
1926 if let Some(monitor) = self.synic.monitor_support() {
1927 mnf_support.allocated_monitor_page = None;
1928
1929 if let Some(version) = request.version {
1930 if version.feature_flags.server_specified_monitor_pages() {
1931 if let Some(monitor_page) = monitor.allocate_monitor_page(self.vtl)? {
1932 tracelimit::info_ratelimited!(
1933 ?monitor_page,
1934 "using server-allocated monitor pages"
1935 );
1936 mnf_support.allocated_monitor_page = Some(monitor_page);
1937 }
1938 }
1939 }
1940
1941 if mnf_support.allocated_monitor_page.is_none() {
1943 if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
1944 anyhow::bail!(
1945 "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
1946 );
1947 }
1948 }
1949 }
1950
1951 request.monitor_page = Update::Unchanged;
1954 }
1955
1956 Ok(())
1957 }
1958
1959 fn get_reserved_channel_message_port(
1960 &mut self,
1961 offer_id: OfferId,
1962 new_target: ConnectionTarget,
1963 ) -> Option<&mut Box<dyn GuestMessagePort>> {
1964 let channel = self
1965 .channels
1966 .get_mut(&offer_id)
1967 .expect("channel does not exist");
1968
1969 assert!(
1970 channel.reserved_state.message_port.is_some(),
1971 "channel is not reserved"
1972 );
1973
1974 if channel.reserved_state.target.sint != new_target.sint {
1977 channel.reserved_state.message_port = None;
1979 let message_port = self
1980 .synic
1981 .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
1982 .inspect_err(|err| {
1983 tracing::error!(
1984 key = %channel.key,
1985 ?err,
1986 ?self.redirect_vtl,
1987 ?new_target,
1988 "could not create reserved channel message port"
1989 )
1990 })
1991 .ok()?;
1992
1993 channel.reserved_state.message_port = Some(message_port);
1994 channel.reserved_state.target = new_target;
1995 } else if channel.reserved_state.target.vp != new_target.vp {
1996 let message_port = channel.reserved_state.message_port.as_mut().unwrap();
1997
1998 if let Err(err) = message_port.set_target_vp(new_target.vp) {
2001 tracing::error!(
2002 key = %channel.key,
2003 ?err,
2004 ?self.redirect_vtl,
2005 ?new_target,
2006 "could not update reserved channel message port"
2007 );
2008 }
2009
2010 channel.reserved_state.target = new_target;
2011 return Some(message_port);
2012 }
2013
2014 Some(channel.reserved_state.message_port.as_mut().unwrap())
2015 }
2016
2017 fn unreserve_channels(&mut self) {
2018 for channel in self.channels.values_mut() {
2020 if let ChannelState::Closed = channel.state {
2021 channel.reserved_state.message_port = None;
2022 }
2023 }
2024 }
2025
2026 fn get_gm_for_channel(&self, version: VersionInfo, channel: &Channel) -> &GuestMemory {
2027 select_gm_for_channel(&self.gm, self.private_gm.as_ref(), version, channel)
2028 }
2029}
2030
2031fn select_gm_for_channel<'a>(
2032 gm: &'a GuestMemory,
2033 private_gm: Option<&'a GuestMemory>,
2034 version: VersionInfo,
2035 channel: &Channel,
2036) -> &'a GuestMemory {
2037 if channel.flags.confidential_ring_buffer() && version.feature_flags.confidential_channels() {
2038 if let Some(private_gm) = private_gm {
2039 return private_gm;
2040 }
2041 }
2042
2043 gm
2044}
2045
/// Cloneable handle for offering channels to the vmbus server task and for forcing a reset.
#[derive(Clone)]
pub struct VmbusServerControl {
    // Guest memory handed to every offered channel through `OfferResources`.
    mem: GuestMemory,
    // Private guest memory, handed only to channels whose flags request a confidential ring
    // buffer or confidential external memory.
    private_mem: Option<GuestMemory>,
    // Carries offer and force-reset requests to the server task.
    send: mesh::Sender<OfferRequest>,
    // Value reported by `ParentBus::use_event`.
    use_event: bool,
    // When set, every offer made through `offer` is forced to use confidential external memory.
    force_confidential_external_memory: bool,
}
2055
2056impl VmbusServerControl {
2057 pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
2060 let flags = offer_info.params.flags;
2061 self.send
2062 .call_failable(OfferRequest::Offer, offer_info)
2063 .await?;
2064 Ok(OfferResources::new(
2065 self.mem.clone(),
2066 if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
2067 self.private_mem.clone()
2068 } else {
2069 None
2070 },
2071 ))
2072 }
2073
2074 pub async fn force_reset(&self) -> anyhow::Result<()> {
2077 self.send
2078 .call(OfferRequest::ForceReset, ())
2079 .await
2080 .context("vmbus server is gone")
2081 }
2082
2083 async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2084 let mut offer_info = OfferInfo {
2085 params: request.params.into(),
2086 event: request.event,
2087 request_send: request.request_send,
2088 server_request_recv: request.server_request_recv,
2089 };
2090
2091 if self.force_confidential_external_memory {
2092 tracing::warn!(
2093 key = %offer_info.params.key(),
2094 "forcing confidential external memory for channel"
2095 );
2096
2097 offer_info
2098 .params
2099 .flags
2100 .set_confidential_external_memory(true);
2101 }
2102
2103 self.offer_core(offer_info).await
2104 }
2105}
2106
2107fn inspect_rings(
2109 resp: &mut inspect::Response<'_>,
2110 gm: &GuestMemory,
2111 gpadl_map: Arc<GpadlMap>,
2112 open_data: &OpenData,
2113) -> Option<()> {
2114 let gpadl = gpadl_map
2115 .view()
2116 .map(GpadlId(open_data.ring_gpadl_id.0))
2117 .ok()?;
2118
2119 let aligned = AlignedGpadlView::new(gpadl).ok()?;
2120 let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
2121 resp.child("incoming_ring", |req| inspect_ring(req, &in_gpadl, gm));
2122 resp.child("outgoing_ring", |req| inspect_ring(req, &out_gpadl, gm));
2123 Some(())
2124}
2125
2126fn inspect_ring(req: inspect::Request<'_>, gpadl: &AlignedGpadlView, gm: &GuestMemory) {
2128 let mut resp = req.respond();
2129
2130 resp.hex("ring_size", gpadl_ring_size(gpadl));
2131
2132 if let Ok(pages) = lock_gpn_with_subrange(gm, gpadl.gpns()[0]) {
2135 ring::inspect_ring(pages.pages()[0], &mut resp);
2136 }
2137}
2138
2139fn gpadl_ring_size(gpadl: &AlignedGpadlView) -> usize {
2140 (gpadl.gpns().len() - 1) * PAGE_SIZE
2142}
2143
2144fn lock_page_with_subrange(gm: &GuestMemory, offset: u64) -> anyhow::Result<guestmem::LockedPages> {
2149 Ok(gm
2150 .lockable_subrange(offset, PAGE_SIZE as u64)?
2151 .lock_gpns(false, &[0])?)
2152}
2153
2154fn lock_gpn_with_subrange(gm: &GuestMemory, gpn: u64) -> anyhow::Result<guestmem::LockedPages> {
2159 lock_page_with_subrange(gm, gpn * PAGE_SIZE as u64)
2160}
2161
/// Message port that queues incoming synic messages as `SynicMessage`s on an mpsc channel.
pub(crate) struct MessageSender {
    // Destination queue; presumably drained by the server task's `message_recv` — confirm.
    send: mpsc::Sender<SynicMessage>,
    // Copied into every `SynicMessage` queued through this sender.
    multiclient: bool,
}
2166
2167impl MessageSender {
2168 fn poll_handle_message(
2169 &self,
2170 cx: &mut std::task::Context<'_>,
2171 msg: &[u8],
2172 trusted: bool,
2173 ) -> Poll<Result<(), SendError>> {
2174 let mut send = self.send.clone();
2175 ready!(send.poll_ready(cx))?;
2176 send.start_send(SynicMessage {
2177 data: msg.to_vec(),
2178 multiclient: self.multiclient,
2179 trusted,
2180 })?;
2181
2182 Poll::Ready(Ok(()))
2183 }
2184}
2185
2186impl MessagePort for MessageSender {
2187 fn poll_handle_message(
2188 &self,
2189 cx: &mut std::task::Context<'_>,
2190 msg: &[u8],
2191 trusted: bool,
2192 ) -> Poll<()> {
2193 if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
2194 tracelimit::error_ratelimited!(
2195 error = &err as &dyn std::error::Error,
2196 "failed to send message"
2197 );
2198 }
2199
2200 Poll::Ready(())
2201 }
2202}
2203
2204#[async_trait]
2205impl ParentBus for VmbusServerControl {
2206 async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2207 self.offer(request).await
2208 }
2209
2210 fn clone_bus(&self) -> Box<dyn ParentBus> {
2211 Box::new(self.clone())
2212 }
2213
2214 fn use_event(&self) -> bool {
2215 self.use_event
2216 }
2217}