1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8pub mod channels;
9pub mod event;
10pub mod hvsock;
11mod monitor;
12mod proxyintegration;
13
/// Alias for [`guid::Guid`], exposed so users of this crate can name the GUID
/// type without depending on the `guid` crate directly.
pub type Guid = guid::Guid;
16
17use anyhow::Context;
18use async_trait::async_trait;
19use channel_bitmap::ChannelBitmap;
20use channels::ConnectionTarget;
21pub use channels::InitiateContactRequest;
22use channels::MessageTarget;
23pub use channels::MnfUsage;
24use channels::ModifyConnectionRequest;
25pub use channels::ModifyConnectionResponse;
26use channels::Notifier;
27use channels::OfferId;
28pub use channels::OfferParamsInternal;
29use channels::OpenParams;
30use channels::RestoreError;
31pub use channels::Update;
32use futures::FutureExt;
33use futures::StreamExt;
34use futures::channel::mpsc;
35use futures::channel::mpsc::SendError;
36use futures::future::OptionFuture;
37use futures::future::poll_fn;
38use futures::stream::SelectAll;
39use guestmem::GuestMemory;
40use hvdef::Vtl;
41use inspect::Inspect;
42use mesh::payload::Protobuf;
43use mesh::rpc::FailableRpc;
44use mesh::rpc::Rpc;
45use mesh::rpc::RpcError;
46use mesh::rpc::RpcSend;
47use pal_async::driver::Driver;
48use pal_async::driver::SpawnDriver;
49use pal_async::task::Task;
50use pal_async::timer::PolledTimer;
51use pal_event::Event;
52#[cfg(windows)]
53pub use proxyintegration::ProxyIntegration;
54#[cfg(windows)]
55pub use proxyintegration::ProxyServerInfo;
56use ring::PAGE_SIZE;
57use std::collections::HashMap;
58use std::future;
59use std::future::Future;
60use std::pin::Pin;
61use std::sync::Arc;
62use std::task::Poll;
63use std::task::ready;
64use std::time::Duration;
65use unicycle::FuturesUnordered;
66use vmbus_channel::bus::ChannelRequest;
67use vmbus_channel::bus::ChannelServerRequest;
68use vmbus_channel::bus::GpadlRequest;
69use vmbus_channel::bus::ModifyRequest;
70use vmbus_channel::bus::OfferInput;
71use vmbus_channel::bus::OfferKey;
72use vmbus_channel::bus::OfferResources;
73use vmbus_channel::bus::OpenData;
74use vmbus_channel::bus::OpenRequest;
75use vmbus_channel::bus::ParentBus;
76use vmbus_channel::bus::RestoreResult;
77use vmbus_channel::gpadl::GpadlMap;
78use vmbus_channel::gpadl_ring::AlignedGpadlView;
79use vmbus_core::HvsockConnectRequest;
80use vmbus_core::HvsockConnectResult;
81use vmbus_core::MaxVersionInfo;
82use vmbus_core::OutgoingMessage;
83use vmbus_core::TaggedStream;
84use vmbus_core::VersionInfo;
85use vmbus_core::protocol;
86pub use vmbus_core::protocol::GpadlId;
87#[cfg(windows)]
88use vmbus_proxy::ProxyHandle;
89use vmbus_ring as ring;
90use vmbus_ring::gparange::MultiPagedRangeBuf;
91use vmcore::interrupt::Interrupt;
92use vmcore::save_restore::SavedStateRoot;
93use vmcore::synic::EventPort;
94use vmcore::synic::GuestEventPort;
95use vmcore::synic::GuestMessagePort;
96use vmcore::synic::MessagePort;
97use vmcore::synic::MonitorPageGpas;
98use vmcore::synic::SynicPortAccess;
99
/// SINT on which VMBus messages are received when message redirection is not
/// in use (see `VmbusServerBuilder::build`).
const SINT: u8 = 2;
/// SINT on which VMBus messages are received when message redirection is enabled.
pub const REDIRECT_SINT: u8 = 7;
/// VTL that receives VMBus messages when message redirection is enabled.
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
// NOTE(review): not referenced in this part of the file; presumably the
// connection ID of the shared event port — confirm at its use site.
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
/// Base value of the event port IDs built by `get_child_event_port_id`.
const EVENT_PORT_ID: u32 = 2;
/// Message type passed to `poll_post_message` when posting VMBus messages.
const VMBUS_MESSAGE_TYPE: u32 = 1;

/// Maximum number of outstanding hvsocket connect requests. While this many
/// are pending, the server task stops reading incoming synic messages.
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
108
/// A VMBus server whose state machine runs on a spawned background task.
///
/// Control operations (save, restore, start, stop, reset, inspect) are sent to
/// that task as [`VmbusRequest`]s.
pub struct VmbusServer {
    /// Sends control requests to the server task.
    task_send: mesh::Sender<VmbusRequest>,
    /// Control object handed out by [`VmbusServer::control`].
    control: Arc<VmbusServerControl>,
    /// Primary synic message port; held (never read) for the server's lifetime.
    _message_port: Box<dyn Sync + Send>,
    /// Multiclient synic message port, when one was created; held (never read).
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    /// The spawned server task; awaited by [`VmbusServer::shutdown`].
    task: Task<ServerTask>,
}
116
/// Builder for a [`VmbusServer`].
///
/// Create with `VmbusServerBuilder::new` (or `VmbusServer::builder`), adjust
/// settings with the setter methods, then call `build`.
pub struct VmbusServerBuilder<T: SpawnDriver> {
    /// Spawns the server task; also used as the driver for unstick timers.
    spawner: T,
    /// Synic used to register message ports and create the guest message port.
    synic: Arc<dyn SynicPortAccess>,
    /// Guest memory.
    gm: GuestMemory,
    /// Optional second guest memory instance (see the `private_gm` setter).
    private_gm: Option<GuestMemory>,
    /// VTL the server runs for; defaults to VTL0.
    vtl: Vtl,
    /// Optional relay for hvsocket connect requests. Without one, every request
    /// is answered locally with `HvsockConnectResult::from_request(.., false)`.
    hvsock_notify: Option<HvsockServerChannelHalf>,
    /// Optional relay for modify-connection requests. Without one, every request
    /// is answered locally as supported.
    server_relay: Option<VmbusServerChannelHalf>,
    /// Optional channel that receives saved state on restore and a clear request
    /// on start.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Optional sender of `InitiateContactRequest`s; stored in the task as
    /// `external_server_send` (its use is outside this part of the file).
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    /// Optional source of `InitiateContactRequest`s handled by this server.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Receive messages on `REDIRECT_VTL`/`REDIRECT_SINT` instead of the
    /// server's own VTL on `SINT`.
    use_message_redirect: bool,
    /// Offset applied when generating channel IDs (0 or 1024).
    channel_id_offset: u16,
    /// Optional cap on the negotiated protocol version.
    max_version: Option<MaxVersionInfo>,
    /// Passed along with `max_version` to `set_compatibility_version`.
    delay_max_version: bool,
    /// Whether MNF may be used for offered channels (also requires synic
    /// monitor support).
    enable_mnf: bool,
    /// Forwarded to `VmbusServerControl`.
    force_confidential_external_memory: bool,
    /// Stored in the server task; presumably permits sending messages while the
    /// server is stopped (used outside this part of the file — confirm).
    send_messages_while_stopped: bool,
    /// Delay after a successful open before the channel is unstuck; `None`
    /// disables the timer.
    channel_unstick_delay: Option<Duration>,
}
137
/// Requests sent on the saved-state notification channel (logged as the
/// "proxy" by the server task).
#[derive(mesh::MeshPayload)]
pub enum SavedStateRequest {
    /// Provides the channel-server saved state; sent while restoring.
    Set(FailableRpc<Box<channels::SavedState>, ()>),
    /// Clears previously provided saved state; sent when the server starts.
    Clear(Rpc<(), ()>),
}
144
/// The server's end of a relay channel: sends requests and receives responses.
pub struct ServerChannelHalf<Request, Response> {
    request_send: mesh::Sender<Request>,
    response_receive: mesh::Receiver<Response>,
}
150
/// The relay's end of a relay channel: receives requests and sends responses.
pub struct RelayChannelHalf<Request, Response> {
    /// Requests sent by the server half.
    pub request_receive: mesh::Receiver<Request>,
    /// Responses delivered to the server half.
    pub response_send: mesh::Sender<Response>,
}
156
/// A connected pair of relay and server channel halves.
pub struct RelayChannel<Request, Response> {
    /// Half given to the relay.
    pub relay_half: RelayChannelHalf<Request, Response>,
    /// Half given to the server.
    pub server_half: ServerChannelHalf<Request, Response>,
}
162
163impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
164 pub fn new() -> Self {
166 let (request_send, request_receive) = mesh::channel();
167 let (response_send, response_receive) = mesh::channel();
168 Self {
169 relay_half: RelayChannelHalf {
170 request_receive,
171 response_send,
172 },
173 server_half: ServerChannelHalf {
174 request_send,
175 response_receive,
176 },
177 }
178 }
179}
180
/// Server half of the modify-connection relay channel.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
/// Relay half of the modify-connection relay channel.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
/// Modify-connection relay channel pair.
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyConnectionResponse>;
/// Server half of the hvsocket connect relay channel.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay half of the hvsocket connect relay channel.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Hvsocket connect relay channel pair.
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
187
/// A modify-connection request in the form forwarded to the relay.
///
/// Compared with `ModifyConnectionRequest`, the interrupt page update is reduced
/// to whether an interrupt page is used.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// New protocol version, if it changed.
    pub version: Option<u32>,
    /// Monitor page update.
    pub monitor_page: Update<MonitorPageGpas>,
    /// `None` if unchanged, `Some(true)` if set, `Some(false)` if reset.
    pub use_interrupt_page: Option<bool>,
}
198
199impl From<ModifyConnectionRequest> for ModifyRelayRequest {
200 fn from(value: ModifyConnectionRequest) -> Self {
201 Self {
202 version: value.version,
203 monitor_page: value.monitor_page,
204 use_interrupt_page: match value.interrupt_page {
205 Update::Unchanged => None,
206 Update::Reset => Some(false),
207 Update::Set(_) => Some(true),
208 },
209 }
210 }
211}
212
/// Control requests sent from [`VmbusServer`] to its server task.
#[derive(Debug)]
enum VmbusRequest {
    /// Resets the server. The RPC is queued in `reset_done` while the reset is
    /// in progress.
    Reset(Rpc<(), ()>),
    /// Responds to a deferred inspect request.
    Inspect(inspect::Deferred),
    /// Produces the server's saved state.
    Save(Rpc<(), SavedState>),
    /// Restores the server from saved state.
    Restore(Rpc<Box<SavedState>, Result<(), RestoreError>>),
    /// Starts the server if it is not running.
    Start,
    /// Stops the server.
    Stop(Rpc<(), ()>),
}
222
/// Everything the server needs to offer a channel.
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    /// Offer parameters passed to the channel state machine.
    pub params: OfferParamsInternal,
    /// Interrupt delivered when the guest signals the channel (guest to host).
    pub event: Interrupt,
    /// Sends requests from the server to the offering side.
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receives restore/revoke requests from the offering side. The channel is
    /// revoked when this stream ends.
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
230
/// Requests received by the server task from `VmbusServerControl`.
#[expect(clippy::large_enum_variant)]
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offers a new channel.
    Offer(FailableRpc<OfferInfo, ()>),
    /// Resets the server, handled the same way as `VmbusRequest::Reset`.
    ForceReset(Rpc<(), ()>),
}
237
238impl Inspect for VmbusServer {
239 fn inspect(&self, req: inspect::Request<'_>) {
240 self.task_send.send(VmbusRequest::Inspect(req.defer()));
241 }
242}
243
/// Adapts a channel's guest-to-host [`Interrupt`] to the synic [`EventPort`] trait.
struct ChannelEvent(Interrupt);
245
246impl EventPort for ChannelEvent {
247 fn handle_event(&self, _flag: u16) {
248 self.0.deliver();
249 }
250
251 fn os_event(&self) -> Option<&Event> {
252 self.0.event()
253 }
254}
255
/// Saved state of the VMBus server.
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    /// State of the channel state machine.
    #[mesh(1)]
    server: channels::SavedState,
    /// Always `true` when saved by this code. When restoring state where it is
    /// `false`, channels are unstuck on the next start.
    #[mesh(2)]
    lost_synic_bug_fixed: bool,
}
267
/// Message connection ID used by a VTL0 server without message redirection.
const MESSAGE_CONNECTION_ID: u32 = 1;
/// Connection ID of the multiclient message port. Also the base value of the
/// IDs built by `get_child_message_connection_id`.
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
270
271impl<T: SpawnDriver + Clone> VmbusServerBuilder<T> {
272 pub fn new(spawner: T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
274 Self {
275 spawner,
276 synic,
277 gm,
278 private_gm: None,
279 vtl: Vtl::Vtl0,
280 hvsock_notify: None,
281 server_relay: None,
282 saved_state_notify: None,
283 external_server: None,
284 external_requests: None,
285 use_message_redirect: false,
286 channel_id_offset: 0,
287 max_version: None,
288 delay_max_version: false,
289 enable_mnf: false,
290 force_confidential_external_memory: false,
291 send_messages_while_stopped: false,
292 channel_unstick_delay: Some(Duration::from_millis(100)),
293 }
294 }
295
296 pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
300 self.private_gm = private_gm;
301 self
302 }
303
304 pub fn vtl(mut self, vtl: Vtl) -> Self {
306 self.vtl = vtl;
307 self
308 }
309
310 pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
312 self.hvsock_notify = hvsock_notify;
313 self
314 }
315
316 pub fn saved_state_notify(
318 mut self,
319 saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
320 ) -> Self {
321 self.saved_state_notify = saved_state_notify;
322 self
323 }
324
325 pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
328 self.server_relay = server_relay;
329 self
330 }
331
332 pub fn external_requests(
334 mut self,
335 external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
336 ) -> Self {
337 self.external_requests = external_requests;
338 self
339 }
340
341 pub fn external_server(
344 mut self,
345 external_server: Option<mesh::Sender<InitiateContactRequest>>,
346 ) -> Self {
347 self.external_server = external_server;
348 self
349 }
350
351 pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
353 self.use_message_redirect = use_message_redirect;
354 self
355 }
356
357 pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
362 self.channel_id_offset = if enable { 1024 } else { 0 };
363 self
364 }
365
366 pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
370 self.max_version = max_version;
371 self
372 }
373
374 pub fn delay_max_version(mut self, delay: bool) -> Self {
379 self.delay_max_version = delay;
380 self
381 }
382
383 pub fn enable_mnf(mut self, enable: bool) -> Self {
387 self.enable_mnf = enable;
388 self
389 }
390
391 pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
394 self.force_confidential_external_memory = force;
395 self
396 }
397
398 pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
405 self.send_messages_while_stopped = send;
406 self
407 }
408
409 pub fn channel_unstick_delay(mut self, delay: Option<Duration>) -> Self {
416 self.channel_unstick_delay = delay;
417 self
418 }
419
420 pub fn build(self) -> anyhow::Result<VmbusServer> {
425 #[expect(clippy::disallowed_methods)] let (message_send, message_recv) = mpsc::channel(64);
427 let message_sender = Arc::new(MessageSender {
428 send: message_send.clone(),
429 multiclient: self.use_message_redirect,
430 });
431
432 let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
433 (REDIRECT_VTL, REDIRECT_SINT)
434 } else {
435 (self.vtl, SINT)
436 };
437
438 let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
441 MESSAGE_CONNECTION_ID
442 } else {
443 VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
446 };
447
448 let _message_port = self
449 .synic
450 .add_message_port(connection_id, redirect_vtl, message_sender)
451 .context("failed to create vmbus synic ports")?;
452
453 let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
457 let multiclient_message_sender = Arc::new(MessageSender {
458 send: message_send,
459 multiclient: true,
460 });
461
462 Some(
463 self.synic
464 .add_message_port(
465 MULTICLIENT_MESSAGE_CONNECTION_ID,
466 self.vtl,
467 multiclient_message_sender,
468 )
469 .context("failed to create vmbus synic ports")?,
470 )
471 } else {
472 None
473 };
474
475 let (offer_send, offer_recv) = mesh::mpsc_channel();
476 let control = Arc::new(VmbusServerControl {
477 mem: self.gm.clone(),
478 private_mem: self.private_gm.clone(),
479 send: offer_send,
480 use_event: self.synic.prefer_os_events(),
481 force_confidential_external_memory: self.force_confidential_external_memory,
482 });
483
484 let mut server = channels::Server::new(self.vtl, connection_id, self.channel_id_offset);
485
486 if let Some(version) = self.max_version {
488 server.set_compatibility_version(version, self.delay_max_version);
489 }
490 let (relay_request_send, relay_response_recv) =
491 if let Some(server_relay) = self.server_relay {
492 let r = server_relay.response_receive.boxed().fuse();
493 (server_relay.request_send, r)
494 } else {
495 let (req_send, req_recv) = mesh::channel();
496 let resp_recv = req_recv
497 .map(|_| {
498 ModifyConnectionResponse::Supported(
499 protocol::ConnectionState::SUCCESSFUL,
500 protocol::FeatureFlags::from_bits(u32::MAX),
501 )
502 })
503 .boxed()
504 .fuse();
505 (req_send, resp_recv)
506 };
507
508 let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
510 let r = hvsock_notify.response_receive.boxed().fuse();
511 (hvsock_notify.request_send, r)
512 } else {
513 let (req_send, req_recv) = mesh::channel();
514 let resp_recv = req_recv
515 .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
516 .boxed()
517 .fuse();
518 (req_send, resp_recv)
519 };
520
521 let inner = ServerTaskInner {
522 running: false,
523 send_messages_while_stopped: self.send_messages_while_stopped,
524 gm: self.gm,
525 private_gm: self.private_gm,
526 vtl: self.vtl,
527 redirect_vtl,
528 redirect_sint,
529 message_port: self
530 .synic
531 .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
532 synic: self.synic,
533 hvsock_requests: 0,
534 hvsock_send,
535 saved_state_notify: self.saved_state_notify,
536 channels: HashMap::new(),
537 channel_responses: FuturesUnordered::new(),
538 relay_send: relay_request_send,
539 external_server_send: self.external_server,
540 channel_bitmap: None,
541 shared_event_port: None,
542 reset_done: Vec::new(),
543 enable_mnf: self.enable_mnf,
544 };
545
546 let (task_send, task_recv) = mesh::channel();
547 let mut server_task = ServerTask {
548 driver: Box::new(self.spawner.clone()),
549 server,
550 task_recv,
551 offer_recv,
552 message_recv,
553 server_request_recv: SelectAll::new(),
554 inner,
555 external_requests: self.external_requests,
556 next_seq: 0,
557 unstick_on_start: false,
558 channel_unstickers: FuturesUnordered::new(),
559 channel_unstick_delay: self.channel_unstick_delay,
560 };
561
562 let task = self.spawner.spawn("vmbus server", async move {
563 server_task.run(relay_response_recv, hvsock_recv).await;
564 server_task
565 });
566
567 Ok(VmbusServer {
568 task_send,
569 control,
570 _message_port,
571 _multiclient_message_port,
572 task,
573 })
574 }
575}
576
577impl VmbusServer {
578 pub fn builder<T: SpawnDriver + Clone>(
580 spawner: T,
581 synic: Arc<dyn SynicPortAccess>,
582 gm: GuestMemory,
583 ) -> VmbusServerBuilder<T> {
584 VmbusServerBuilder::new(spawner, synic, gm)
585 }
586
587 pub async fn save(&self) -> SavedState {
588 self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
589 }
590
591 pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
592 self.task_send
593 .call(VmbusRequest::Restore, Box::new(state))
594 .await
595 .unwrap()
596 }
597
598 pub async fn stop(&self) {
600 self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
601 }
602
603 pub fn start(&self) {
605 self.task_send.send(VmbusRequest::Start);
606 }
607
608 pub async fn reset(&self) {
610 tracing::debug!("resetting channel state");
611 self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
612 }
613
614 pub async fn shutdown(self) {
616 drop(self.task_send);
617 let _ = self.task.await;
618 }
619
620 pub fn control(&self) -> Arc<VmbusServerControl> {
622 self.control.clone()
623 }
624
625 fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
628 MULTICLIENT_MESSAGE_CONNECTION_ID
629 | (vtl as u32) << 22
630 | vp_index << 8
631 | (sint_index as u32) << 4
632 }
633
634 fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
635 EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
636 }
637}
638
/// Restore information for a channel.
// NOTE(review): not constructed or read in this part of the file; the field
// meanings below are inferred from their types — confirm at the use site.
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    /// Open data, presumably present only if the channel was open.
    open_data: Option<OpenData>,
    /// GPADLs as (id, count, buffer) — presumably mirrors the GPADL request fields.
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    /// Optional interrupt for the restored channel.
    interrupt: Option<Interrupt>,
}
645
/// A message received from the guest through a synic message port and queued
/// for the server task.
#[derive(Default)]
pub struct SynicMessage {
    /// Raw message payload.
    data: Vec<u8>,
    /// Whether the message arrived through a port configured as multiclient
    /// (presumably copied from `MessageSender::multiclient` — confirm).
    multiclient: bool,
    /// Presumably whether the message came from a trusted source; set outside
    /// this part of the file.
    trusted: bool,
}
652
/// Identifies one specific offer of a channel.
///
/// An `OfferId` can be reused after a revoke. `seq` distinguishes instances so
/// that stale responses, revokes, and timers are ignored.
#[derive(Debug, Clone, Copy)]
struct OfferInstanceId {
    offer_id: OfferId,
    /// Sequence number assigned from `ServerTask::next_seq` when offered.
    seq: u64,
}
659
/// State owned by the spawned VMBus server task.
struct ServerTask {
    /// Driver used to create unstick timers.
    driver: Box<dyn Driver>,
    /// Channel/connection state machine.
    server: channels::Server,
    /// Control requests from [`VmbusServer`]; the task exits when this closes.
    task_recv: mesh::Receiver<VmbusRequest>,
    /// Offer/force-reset requests from `VmbusServerControl`.
    offer_recv: mesh::Receiver<OfferRequest>,
    /// Messages received on the synic message ports.
    message_recv: mpsc::Receiver<SynicMessage>,
    /// Per-channel restore/revoke request streams, tagged with their offer instance.
    server_request_recv:
        SelectAll<TaggedStream<OfferInstanceId, mesh::Receiver<ChannelServerRequest>>>,
    /// State shared with the channel state machine's notifier.
    inner: ServerTaskInner,
    /// Optional external `InitiateContactRequest` source.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Sequence number for the next offered channel.
    next_seq: u64,
    /// Set on restore when the saved state predates the lost-synic fix; channels
    /// are unstuck on the next start.
    unstick_on_start: bool,
    /// Pending unstick timers; each yields the offer instance to unstick.
    channel_unstickers: FuturesUnordered<Pin<Box<dyn Send + Future<Output = OfferInstanceId>>>>,
    /// Delay before unsticking a newly opened channel; `None` disables it.
    channel_unstick_delay: Option<Duration>,
}
676
/// Server task state handed to the channel state machine via `with_notifier`.
struct ServerTaskInner {
    /// Whether the server has been started (and not stopped since).
    running: bool,
    /// Set from the builder; used outside this part of the file.
    send_messages_while_stopped: bool,
    /// Guest memory.
    gm: GuestMemory,
    /// Optional second guest memory instance.
    private_gm: Option<GuestMemory>,
    /// Synic used to create ports (and queried for monitor support).
    synic: Arc<dyn SynicPortAccess>,
    /// VTL the server runs for.
    vtl: Vtl,
    /// VTL that receives VMBus messages.
    redirect_vtl: Vtl,
    /// SINT that receives VMBus messages.
    redirect_sint: u8,
    /// Port used to post VMBus messages to the guest.
    message_port: Box<dyn GuestMessagePort>,
    /// Number of hvsocket connect requests awaiting a result.
    hvsock_requests: usize,
    /// Sends hvsocket connect requests to the relay (or the local responder).
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    /// Optional saved-state notification channel.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Channels currently offered, keyed by offer ID.
    channels: HashMap<OfferId, Channel>,
    /// Pending channel responses, tagged with offer ID and sequence number.
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    /// Optional sender of external `InitiateContactRequest`s.
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    /// Sends modify-connection requests to the relay (or the local responder).
    relay_send: mesh::Sender<ModifyRelayRequest>,
    /// Starts as `None`; managed outside this part of the file.
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    /// Starts as `None`; presumably the registered shared event port — confirm.
    shared_event_port: Option<Box<dyn Send>>,
    /// RPCs waiting for the current reset to finish. Non-empty means a reset is
    /// in progress.
    reset_done: Vec<Rpc<(), ()>>,
    /// Whether MNF may be used for offered channels.
    enable_mnf: bool,
}
701
/// Completion of a request previously sent to a channel.
#[derive(Debug)]
enum ChannelResponse {
    /// Open completed; `true` on success.
    Open(bool),
    /// Close completed.
    Close,
    /// GPADL creation completed; `true` on success.
    Gpadl(GpadlId, bool),
    /// GPADL teardown completed.
    TeardownGpadl(GpadlId),
    /// Modify completed with the given status.
    Modify(i32),
}
710
/// State of a channel's unstick timer.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ChannelUnstickState {
    /// No timer is pending.
    None,
    /// A timer is pending.
    Queued,
    /// The channel was reopened while a timer was pending; the timer is re-armed
    /// when it fires.
    NeedsRequeue,
}
717
/// Server-task bookkeeping for one offered channel.
struct Channel {
    /// Offer key, used in log messages.
    key: OfferKey,
    /// Sends requests to the offering side.
    send: mesh::Sender<ChannelRequest>,
    /// Sequence number of this offer instance (see `OfferInstanceId`).
    seq: u64,
    /// Open/close state.
    state: ChannelState,
    /// GPADLs registered for the channel.
    gpadls: Arc<GpadlMap>,
    /// Event port that delivers the offer's guest-to-host interrupt.
    guest_to_host_event: Arc<ChannelEvent>,
    /// Offer flags copied from the offer parameters.
    flags: protocol::OfferFlags,
    /// Reserved-channel state (no message port and target VP/SINT 0 at offer time).
    reserved_state: ReservedState,
    /// State of the delayed unstick timer for this channel.
    unstick_state: ChannelUnstickState,
}
733
/// Message port and target used for reserved channels.
// NOTE(review): only initialized (to `None` / VP 0, SINT 0) in this part of the
// file; presumably filled in when a reserved channel is opened — confirm.
struct ReservedState {
    message_port: Option<Box<dyn GuestMessagePort>>,
    target: ConnectionTarget,
}
738
/// Resources held while a channel is open.
struct ChannelOpenState {
    /// Parameters the channel was opened with.
    open_params: OpenParams,
    /// Registered event port; held (never read) while the channel is open.
    _event_port: Box<dyn Send>,
    /// Port used to signal the guest.
    guest_event_port: Box<dyn GuestEventPort>,
    /// Interrupt for host-to-guest signaling.
    host_to_guest_interrupt: Interrupt,
}
745
/// Open/close state of a channel as tracked by the server task.
enum ChannelState {
    /// Not open.
    Closed,
    /// Open, with its resources.
    Open(Box<ChannelOpenState>),
    /// Close requested; `handle_close` moves this to `Closed`.
    Closing,
}
751
752impl ServerTask {
753 fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
754 let key = info.params.key();
755 let flags = info.params.flags;
756
757 if self.inner.enable_mnf && self.inner.synic.monitor_support().is_some() {
758 if info.params.use_mnf.is_relayed() {
763 info.params.use_mnf = MnfUsage::Enabled {
764 latency: Duration::ZERO,
765 }
766 }
767 } else if info.params.use_mnf.is_enabled() {
768 info.params.use_mnf = MnfUsage::Disabled;
771 }
772
773 let offer_id = self
774 .server
775 .with_notifier(&mut self.inner)
776 .offer_channel(info.params)
777 .context("channel offer failed")?;
778
779 tracing::debug!(?offer_id, %key, "offered channel");
780
781 let seq = self.next_seq;
782 self.next_seq += 1;
783 self.inner.channels.insert(
784 offer_id,
785 Channel {
786 key,
787 send: info.request_send,
788 state: ChannelState::Closed,
789 gpadls: GpadlMap::new(),
790 guest_to_host_event: Arc::new(ChannelEvent(info.event)),
791 seq,
792 flags,
793 reserved_state: ReservedState {
794 message_port: None,
795 target: ConnectionTarget { vp: 0, sint: 0 },
796 },
797 unstick_state: ChannelUnstickState::None,
798 },
799 );
800
801 self.server_request_recv.push(TaggedStream::new(
802 OfferInstanceId { offer_id, seq },
803 info.server_request_recv,
804 ));
805
806 Ok(())
807 }
808
809 fn handle_revoke(&mut self, id: OfferInstanceId) {
810 if let Some(channel) = self.inner.channels.get(&id.offer_id) {
813 if channel.seq == id.seq {
814 tracing::info!(?id.offer_id, "revoking channel");
815 self.inner.channels.remove(&id.offer_id);
816 self.server
817 .with_notifier(&mut self.inner)
818 .revoke_channel(id.offer_id);
819 }
820 }
821 }
822
823 fn handle_response(
824 &mut self,
825 offer_id: OfferId,
826 seq: u64,
827 response: Result<ChannelResponse, RpcError>,
828 ) {
829 let channel = self
831 .inner
832 .channels
833 .get(&offer_id)
834 .filter(|channel| channel.seq == seq);
835
836 if let Some(channel) = channel {
837 match response {
838 Ok(response) => match response {
839 ChannelResponse::Open(result) => self.handle_open(offer_id, result),
840 ChannelResponse::Close => self.handle_close(offer_id),
841 ChannelResponse::Gpadl(gpadl_id, ok) => {
842 self.handle_gpadl_create(offer_id, gpadl_id, ok)
843 }
844 ChannelResponse::TeardownGpadl(gpadl_id) => {
845 self.handle_gpadl_teardown(offer_id, gpadl_id)
846 }
847 ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
848 },
849 Err(err) => {
850 tracing::error!(
851 key = %channel.key,
852 error = &err as &dyn std::error::Error,
853 "channel response failure, channel is in inconsistent state until revoked"
854 );
855 }
856 }
857 } else {
858 tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
859 }
860 }
861
862 fn handle_open(&mut self, offer_id: OfferId, success: bool) {
863 let status = if success {
864 let channel = self
865 .inner
866 .channels
867 .get_mut(&offer_id)
868 .expect("channel exists");
869
870 if let Some(delay) = self.channel_unstick_delay {
873 if channel.unstick_state == ChannelUnstickState::None {
874 channel.unstick_state = ChannelUnstickState::Queued;
875 let seq = channel.seq;
876 let mut timer = PolledTimer::new(&self.driver);
877 self.channel_unstickers.push(Box::pin(async move {
878 timer.sleep(delay).await;
879 OfferInstanceId { offer_id, seq }
880 }));
881 } else {
882 channel.unstick_state = ChannelUnstickState::NeedsRequeue;
883 }
884 }
885
886 0
887 } else {
888 protocol::STATUS_UNSUCCESSFUL
889 };
890
891 self.server
892 .with_notifier(&mut self.inner)
893 .open_complete(offer_id, status);
894 }
895
896 fn handle_close(&mut self, offer_id: OfferId) {
897 let channel = self
898 .inner
899 .channels
900 .get_mut(&offer_id)
901 .expect("channel still exists");
902
903 match &mut channel.state {
904 ChannelState::Closing => {
905 channel.state = ChannelState::Closed;
906 self.server
907 .with_notifier(&mut self.inner)
908 .close_complete(offer_id);
909 }
910 _ => {
911 tracing::error!(?offer_id, "invalid close channel response");
912 }
913 };
914 }
915
916 fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
917 let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
918 self.server
919 .with_notifier(&mut self.inner)
920 .gpadl_create_complete(offer_id, gpadl_id, status);
921 }
922
    /// Reports that the device finished tearing down `gpadl_id`.
    fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
        self.server
            .with_notifier(&mut self.inner)
            .gpadl_teardown_complete(offer_id, gpadl_id);
    }
928
    /// Reports the device's status for a modify-channel request.
    fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
        self.server
            .with_notifier(&mut self.inner)
            .modify_channel_complete(offer_id, status);
    }
934
935 fn handle_restore_channel(
936 &mut self,
937 offer_id: OfferId,
938 open: bool,
939 ) -> anyhow::Result<RestoreResult> {
940 let gpadls = self.server.channel_gpadls(offer_id);
941
942 let open_request = open
945 .then(|| -> anyhow::Result<_> {
946 let params = self.server.get_restore_open_params(offer_id)?;
947 let (channel, interrupt) = self.inner.open_channel(offer_id, ¶ms)?;
948 Ok(OpenRequest::new(
949 params.open_data,
950 interrupt,
951 self.server
952 .get_version()
953 .expect("must be connected")
954 .feature_flags,
955 channel.flags,
956 ))
957 })
958 .transpose()?;
959
960 self.server
961 .with_notifier(&mut self.inner)
962 .restore_channel(offer_id, open_request.is_some())?;
963
964 let channel = self.inner.channels.get_mut(&offer_id).unwrap();
965 for gpadl in &gpadls {
966 if let Ok(buf) =
967 MultiPagedRangeBuf::new(gpadl.request.count.into(), gpadl.request.buf.clone())
968 {
969 channel.gpadls.add(gpadl.request.id, buf);
970 }
971 }
972
973 let result = RestoreResult {
974 open_request,
975 gpadls,
976 };
977 Ok(result)
978 }
979
    /// Handles one control request from [`VmbusServer`].
    async fn handle_request(&mut self, request: VmbusRequest) {
        tracing::debug!(?request, "handle_request");
        match request {
            VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
            VmbusRequest::Inspect(deferred) => {
                deferred.respond(|resp| {
                    resp.field("message_port", &self.inner.message_port)
                        .field("running", self.inner.running)
                        .field("hvsock_requests", self.inner.hvsock_requests)
                        .field("channel_unstick_delay", self.channel_unstick_delay)
                        // Writable trigger: "force" calls `unstick_channels(true)`,
                        // "true" calls `unstick_channels(false)`, "false" does nothing.
                        // Reads report `false`.
                        .field_mut_with("unstick_channels", |v| {
                            let v: inspect::ValueKind = if let Some(v) = v {
                                if v == "force" {
                                    self.unstick_channels(true);
                                    v.into()
                                } else {
                                    let v =
                                        v.parse().ok().context("expected false, true, or force")?;
                                    if v {
                                        self.unstick_channels(false);
                                    }
                                    v.into()
                                }
                            } else {
                                false.into()
                            };
                            anyhow::Ok(v)
                        })
                        .merge(&self.server.with_notifier(&mut self.inner));
                });
            }
            // Saved state always records that the lost-synic fix is present.
            VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
                server: self.server.save(),
                lost_synic_bug_fixed: true,
            }),
            VmbusRequest::Restore(rpc) => {
                rpc.handle(async |state| {
                    // Older saved state (without the fix) triggers an unstick of
                    // all channels on the next start.
                    self.unstick_on_start = !state.lost_synic_bug_fixed;
                    // Give the notification channel a copy of the state before
                    // restoring; a failure aborts the restore.
                    if let Some(sender) = &self.inner.saved_state_notify {
                        tracing::trace!("sending saved state to proxy");
                        if let Err(err) = sender
                            .call_failable(SavedStateRequest::Set, Box::new(state.server.clone()))
                            .await
                        {
                            tracing::error!(
                                err = &err as &dyn std::error::Error,
                                "failed to restore proxy saved state"
                            );
                            return Err(RestoreError::ServerError(err.into()));
                        }
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .restore(state.server)
                })
                .await
            }
            VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
                if self.inner.running {
                    self.inner.running = false;
                }
            }),
            VmbusRequest::Start => {
                if !self.inner.running {
                    self.inner.running = true;
                    // Tell the notification channel to drop any saved state.
                    // Panics if that request fails.
                    if let Some(sender) = self.inner.saved_state_notify.as_ref() {
                        tracing::trace!("sending clear saved state message to proxy");
                        sender
                            .call(SavedStateRequest::Clear, ())
                            .await
                            .expect("failed to clear proxy saved state");
                    }

                    self.server
                        .with_notifier(&mut self.inner)
                        .revoke_unclaimed_channels();
                    if self.unstick_on_start {
                        tracing::info!(
                            "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
                        );
                        self.unstick_channels(false);
                        self.unstick_on_start = false;
                    }
                }
            }
        }
    }
1070
1071 fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
1072 let needs_reset = self.inner.reset_done.is_empty();
1073 self.inner.reset_done.push(rpc);
1074 if needs_reset {
1075 self.server.with_notifier(&mut self.inner).reset();
1076 }
1077 }
1078
    /// Completes a pending modify-connection request with the relay's response.
    fn handle_relay_response(&mut self, response: ModifyConnectionResponse) {
        self.server
            .with_notifier(&mut self.inner)
            .complete_modify_connection(response);
    }
1084
    /// Handles the result of an hvsocket connect request. Decrements the
    /// outstanding-request count and forwards the result to the state machine.
    ///
    /// # Panics
    ///
    /// Panics if no request was outstanding.
    fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
        assert_ne!(self.inner.hvsock_requests, 0);
        self.inner.hvsock_requests -= 1;

        self.server
            .with_notifier(&mut self.inner)
            .send_tl_connect_result(result);
    }
1093
1094 fn handle_synic_message(&mut self, message: SynicMessage) {
1095 match self
1096 .server
1097 .with_notifier(&mut self.inner)
1098 .handle_synic_message(message)
1099 {
1100 Ok(()) => {}
1101 Err(err) => {
1102 tracing::warn!(
1103 error = &err as &dyn std::error::Error,
1104 "synic message error"
1105 );
1106 }
1107 }
1108 }
1109
    /// Handles an `InitiateContactRequest` from the external request source by
    /// passing it to the state machine's `initiate_contact`.
    fn handle_external_request(&mut self, request: InitiateContactRequest) {
        self.server
            .with_notifier(&mut self.inner)
            .initiate_contact(request);
    }
1121
    /// Main loop of the server task. Exits when the control request channel
    /// closes or every input is exhausted.
    ///
    /// Control requests, offers, relay responses, and per-channel server requests
    /// are always handled. Channel responses are handled while running or while a
    /// reset is in progress. Everything else is only polled while running and not
    /// resetting. Incoming synic messages are also paused while outgoing messages
    /// are pending or the hvsocket request limit has been reached.
    async fn run(
        &mut self,
        mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyConnectionResponse>
            + Unpin,
        mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
    ) {
        loop {
            let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
            // Each optional input is wrapped in an `OptionFuture`. When it holds
            // no future it is terminated and `select!` never picks it, which is
            // why the `unwrap()`s in the arms below cannot fail.
            let mut external_requests = OptionFuture::from(
                running_not_resetting
                    .then(|| {
                        self.external_requests
                            .as_mut()
                            .map(|r| r.select_next_some())
                    })
                    .flatten(),
            );

            // Flush queued outgoing messages to the guest before reading new ones.
            let has_pending_messages = self.server.has_pending_messages();
            let message_port = self.inner.message_port.as_mut();
            let mut flush_pending_messages =
                OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
                    poll_fn(|cx| {
                        self.server.poll_flush_pending_messages(|msg| {
                            message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
                        })
                    })
                    .fuse()
                }));

            // Backpressure: stop reading guest messages while outgoing messages are
            // queued or too many hvsocket requests are outstanding.
            let mut message_recv = OptionFuture::from(
                (running_not_resetting
                    && !has_pending_messages
                    && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
                    .then(|| self.message_recv.select_next_some()),
            );

            // Channel responses are still needed during a reset.
            let mut channel_response = OptionFuture::from(
                (self.inner.running || !self.inner.reset_done.is_empty())
                    .then(|| self.inner.channel_responses.select_next_some()),
            );

            let mut hvsock_response =
                OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));

            let mut channel_unstickers = OptionFuture::from(
                running_not_resetting.then(|| self.channel_unstickers.select_next_some()),
            );

            futures::select! {
                r = self.task_recv.recv().fuse() => {
                    // A closed control channel (see `VmbusServer::shutdown`) ends the task.
                    if let Ok(request) = r {
                        self.handle_request(request).await;
                    } else {
                        break;
                    }
                }
                r = self.offer_recv.select_next_some() => {
                    match r {
                        OfferRequest::Offer(rpc) => {
                            rpc.handle_failable_sync(|request| { self.handle_offer(request) })
                        },
                        OfferRequest::ForceReset(rpc) => {
                            self.handle_reset(rpc);
                        }
                    }
                }
                r = self.server_request_recv.select_next_some() => {
                    match r {
                        (id, Some(request)) => match request {
                            ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
                                self.handle_restore_channel(id.offer_id, open)
                            }),
                            ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
                                self.handle_revoke(id);
                            })
                        },
                        // The offering side dropped its request sender: revoke.
                        (id, None) => self.handle_revoke(id),
                    }
                }
                r = channel_response => {
                    let (id, seq, response) = r.unwrap();
                    self.handle_response(id, seq, response);
                }
                r = relay_response_recv.select_next_some() => {
                    self.handle_relay_response(r);
                },
                r = hvsock_response => {
                    self.handle_tl_connect_result(r.unwrap());
                }
                data = message_recv => {
                    let data = data.unwrap();
                    self.handle_synic_message(data);
                }
                r = external_requests => {
                    let r = r.unwrap();
                    self.handle_external_request(r);
                }
                r = channel_unstickers => {
                    self.unstick_channel_by_id(r.unwrap());
                }
                _r = flush_pending_messages => {}
                complete => break,
            }
        }
    }
1238
1239 fn unstick_channels(&self, force: bool) {
1243 let Some(version) = self.server.get_version() else {
1244 tracing::warn!("cannot unstick when not connected");
1245 return;
1246 };
1247
1248 for channel in self.inner.channels.values() {
1249 let gm = self.inner.get_gm_for_channel(version, channel);
1250 if let Err(err) = Self::unstick_channel(gm, channel, force, true) {
1251 tracing::warn!(
1252 channel = %channel.key,
1253 error = err.as_ref() as &dyn std::error::Error,
1254 "could not unstick channel"
1255 );
1256 }
1257 }
1258 }
1259
1260 fn unstick_channel_by_id(&mut self, id: OfferInstanceId) {
1263 let Some(version) = self.server.get_version() else {
1264 tracelimit::warn_ratelimited!("cannot unstick when not connected");
1265 return;
1266 };
1267
1268 if let Some(channel) = self.inner.channels.get_mut(&id.offer_id) {
1269 if channel.seq != id.seq {
1270 return;
1272 }
1273
1274 if channel.unstick_state == ChannelUnstickState::NeedsRequeue {
1277 channel.unstick_state = ChannelUnstickState::Queued;
1278 let mut timer = PolledTimer::new(&self.driver);
1279 let delay = self.channel_unstick_delay.unwrap();
1280 self.channel_unstickers.push(Box::pin(async move {
1281 timer.sleep(delay).await;
1282 id
1283 }));
1284
1285 return;
1286 }
1287
1288 channel.unstick_state = ChannelUnstickState::None;
1289 let gm = select_gm_for_channel(
1290 &self.inner.gm,
1291 self.inner.private_gm.as_ref(),
1292 version,
1293 channel,
1294 );
1295 if let Err(err) = Self::unstick_channel(gm, channel, false, false) {
1296 tracelimit::warn_ratelimited!(
1297 channel = %channel.key,
1298 error = err.as_ref() as &dyn std::error::Error,
1299 "could not unstick channel"
1300 );
1301 }
1302 }
1303 }
1304
1305 fn unstick_channel(
1306 gm: &GuestMemory,
1307 channel: &Channel,
1308 force: bool,
1309 unstick_host: bool,
1310 ) -> anyhow::Result<()> {
1311 if let ChannelState::Open(state) = &channel.state {
1312 if force {
1313 tracing::info!(channel = %channel.key, "waking host and guest");
1314 if unstick_host {
1315 channel.guest_to_host_event.0.deliver();
1316 }
1317 state.host_to_guest_interrupt.deliver();
1318 return Ok(());
1319 }
1320
1321 let gpadl = channel
1322 .gpadls
1323 .clone()
1324 .view()
1325 .map(state.open_params.open_data.ring_gpadl_id)
1326 .context("couldn't find ring gpadl")?;
1327
1328 let aligned = AlignedGpadlView::new(gpadl)
1329 .ok()
1330 .context("ring not aligned")?;
1331 let (in_gpadl, out_gpadl) = aligned
1332 .split(state.open_params.open_data.ring_offset)
1333 .ok()
1334 .context("couldn't split ring")?;
1335
1336 if let Err(err) = Self::unstick_incoming_ring(
1337 gm,
1338 channel,
1339 in_gpadl,
1340 unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1341 &state.host_to_guest_interrupt,
1342 ) {
1343 tracelimit::warn_ratelimited!(
1344 channel = %channel.key,
1345 error = err.as_ref() as &dyn std::error::Error,
1346 "could not unstick incoming ring"
1347 );
1348 }
1349 if let Err(err) = Self::unstick_outgoing_ring(
1350 gm,
1351 channel,
1352 out_gpadl,
1353 unstick_host.then_some(channel.guest_to_host_event.as_ref()),
1354 &state.host_to_guest_interrupt,
1355 ) {
1356 tracelimit::warn_ratelimited!(
1357 channel = %channel.key,
1358 error = err.as_ref() as &dyn std::error::Error,
1359 "could not unstick outgoing ring"
1360 );
1361 }
1362 }
1363 Ok(())
1364 }
1365
1366 fn unstick_incoming_ring(
1367 gm: &GuestMemory,
1368 channel: &Channel,
1369 in_gpadl: AlignedGpadlView,
1370 guest_to_host_event: Option<&ChannelEvent>,
1371 host_to_guest_interrupt: &Interrupt,
1372 ) -> anyhow::Result<()> {
1373 let control_page = lock_gpn_with_subrange(gm, in_gpadl.gpns()[0])?;
1374 if let Some(guest_to_host_event) = guest_to_host_event {
1375 if ring::reader_needs_signal(control_page.pages()[0]) {
1376 tracelimit::info_ratelimited!(channel = %channel.key, "waking host for incoming ring");
1377 guest_to_host_event.0.deliver();
1378 }
1379 }
1380
1381 let ring_size = gpadl_ring_size(&in_gpadl).try_into()?;
1382 if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1383 tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for incoming ring");
1384 host_to_guest_interrupt.deliver();
1385 }
1386 Ok(())
1387 }
1388
1389 fn unstick_outgoing_ring(
1390 gm: &GuestMemory,
1391 channel: &Channel,
1392 out_gpadl: AlignedGpadlView,
1393 guest_to_host_event: Option<&ChannelEvent>,
1394 host_to_guest_interrupt: &Interrupt,
1395 ) -> anyhow::Result<()> {
1396 let control_page = lock_gpn_with_subrange(gm, out_gpadl.gpns()[0])?;
1397 if ring::reader_needs_signal(control_page.pages()[0]) {
1398 tracelimit::info_ratelimited!(channel = %channel.key, "waking guest for outgoing ring");
1399 host_to_guest_interrupt.deliver();
1400 }
1401
1402 if let Some(guest_to_host_event) = guest_to_host_event {
1403 let ring_size = gpadl_ring_size(&out_gpadl).try_into()?;
1404 if ring::writer_needs_signal(control_page.pages()[0], ring_size) {
1405 tracelimit::info_ratelimited!(channel = %channel.key, "waking host for outgoing ring");
1406 guest_to_host_event.0.deliver();
1407 }
1408 }
1409 Ok(())
1410 }
1411}
1412
1413impl Notifier for ServerTaskInner {
1414 fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
1415 let channel = self
1416 .channels
1417 .get_mut(&offer_id)
1418 .expect("channel does not exist");
1419
1420 fn handle<I: 'static + Send, R: 'static + Send>(
1421 offer_id: OfferId,
1422 channel: &Channel,
1423 req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
1424 input: I,
1425 f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
1426 ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
1427 {
1428 let recv = channel.send.call(req, input);
1429 let seq = channel.seq;
1430 Box::pin(async move {
1431 let r = recv.await.map(f);
1432 (offer_id, seq, r)
1433 })
1434 }
1435
1436 let response = match action {
1437 channels::Action::Open(open_params, version) => {
1438 let seq = channel.seq;
1439 match self.open_channel(offer_id, &open_params) {
1440 Ok((channel, interrupt)) => handle(
1441 offer_id,
1442 channel,
1443 ChannelRequest::Open,
1444 OpenRequest::new(
1445 open_params.open_data,
1446 interrupt,
1447 version.feature_flags,
1448 channel.flags,
1449 ),
1450 ChannelResponse::Open,
1451 ),
1452 Err(err) => {
1453 tracelimit::error_ratelimited!(
1454 err = err.as_ref() as &dyn std::error::Error,
1455 ?offer_id,
1456 "could not open channel",
1457 );
1458
1459 Box::pin(future::ready((
1462 offer_id,
1463 seq,
1464 Ok(ChannelResponse::Open(false)),
1465 )))
1466 }
1467 }
1468 }
1469 channels::Action::Close => {
1470 if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1471 if let ChannelState::Open(ref state) = channel.state {
1472 channel_bitmap.unregister_channel(state.open_params.event_flag);
1473 }
1474 }
1475
1476 channel.state = ChannelState::Closing;
1477 handle(offer_id, channel, ChannelRequest::Close, (), |()| {
1478 ChannelResponse::Close
1479 })
1480 }
1481 channels::Action::Gpadl(gpadl_id, count, buf) => {
1482 channel.gpadls.add(
1483 gpadl_id,
1484 MultiPagedRangeBuf::new(count.into(), buf.clone()).unwrap(),
1485 );
1486 handle(
1487 offer_id,
1488 channel,
1489 ChannelRequest::Gpadl,
1490 GpadlRequest {
1491 id: gpadl_id,
1492 count,
1493 buf,
1494 },
1495 move |r| ChannelResponse::Gpadl(gpadl_id, r),
1496 )
1497 }
1498 channels::Action::TeardownGpadl {
1499 gpadl_id,
1500 post_restore,
1501 } => {
1502 if !post_restore {
1503 channel.gpadls.remove(gpadl_id, Box::new(|| ()));
1504 }
1505
1506 handle(
1507 offer_id,
1508 channel,
1509 ChannelRequest::TeardownGpadl,
1510 gpadl_id,
1511 move |()| ChannelResponse::TeardownGpadl(gpadl_id),
1512 )
1513 }
1514 channels::Action::Modify { target_vp } => {
1515 if let ChannelState::Open(state) = &mut channel.state {
1516 if let Err(err) = state.guest_event_port.set_target_vp(target_vp) {
1517 tracelimit::error_ratelimited!(
1518 error = &err as &dyn std::error::Error,
1519 channel = %channel.key,
1520 "could not modify channel",
1521 );
1522 let seq = channel.seq;
1523 Box::pin(async move {
1524 (
1525 offer_id,
1526 seq,
1527 Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
1528 )
1529 })
1530 } else {
1531 handle(
1532 offer_id,
1533 channel,
1534 ChannelRequest::Modify,
1535 ModifyRequest::TargetVp { target_vp },
1536 ChannelResponse::Modify,
1537 )
1538 }
1539 } else {
1540 unreachable!();
1541 }
1542 }
1543 };
1544 self.channel_responses.push(response);
1545 }
1546
1547 fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
1548 self.map_interrupt_page(request.interrupt_page)
1549 .context("Failed to map interrupt page.")?;
1550
1551 self.set_monitor_page(request.monitor_page)
1552 .context("Failed to map monitor page.")?;
1553
1554 if let Some(vp) = request.target_message_vp {
1555 self.message_port.set_target_vp(vp)?;
1556 }
1557
1558 if request.notify_relay {
1559 if self.enable_mnf {
1564 request.monitor_page = Update::Unchanged;
1565 }
1566
1567 self.relay_send.send(request.into());
1568 }
1569
1570 Ok(())
1571 }
1572
1573 fn forward_unhandled(&mut self, request: InitiateContactRequest) {
1574 if let Some(external_server) = &self.external_server_send {
1575 external_server.send(request);
1576 } else {
1577 tracing::warn!(?request, "nowhere to forward unhandled request")
1578 }
1579 }
1580
1581 fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
1582 let channel = self.channels.get(&offer_id).expect("should exist");
1583 let mut resp = req.respond();
1584 if let ChannelState::Open(state) = &channel.state {
1585 let mem = self.get_gm_for_channel(version.expect("must be connected"), channel);
1586 inspect_rings(
1587 &mut resp,
1588 mem,
1589 channel.gpadls.clone(),
1590 &state.open_params.open_data,
1591 );
1592 }
1593 }
1594
1595 fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
1596 if !self.running && !self.send_messages_while_stopped {
1609 if !matches!(target, MessageTarget::Default) {
1610 tracelimit::error_ratelimited!(?target, "dropping message while paused");
1611 }
1612 return false;
1613 }
1614
1615 let mut port_storage;
1616 let port = match target {
1617 MessageTarget::Default => self.message_port.as_mut(),
1618 MessageTarget::ReservedChannel(offer_id, target) => {
1619 if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
1620 port.as_mut()
1621 } else {
1622 return true;
1624 }
1625 }
1626 MessageTarget::Custom(target) => {
1627 port_storage = match self.synic.new_guest_message_port(
1628 self.redirect_vtl,
1629 target.vp,
1630 target.sint,
1631 ) {
1632 Ok(port) => port,
1633 Err(err) => {
1634 tracing::error!(
1635 ?err,
1636 ?self.redirect_vtl,
1637 ?target,
1638 "could not create message port"
1639 );
1640
1641 return true;
1643 }
1644 };
1645 port_storage.as_mut()
1646 }
1647 };
1648
1649 matches!(
1652 port.poll_post_message(
1653 &mut std::task::Context::from_waker(std::task::Waker::noop()),
1654 VMBUS_MESSAGE_TYPE,
1655 message.data()
1656 ),
1657 Poll::Ready(())
1658 )
1659 }
1660
1661 fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
1662 self.hvsock_requests += 1;
1663 self.hvsock_send.send(*request);
1664 }
1665
1666 fn reset_complete(&mut self) {
1667 if let Some(monitor) = self.synic.monitor_support() {
1668 if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
1669 tracing::warn!(?err, "resetting monitor page failed")
1670 }
1671 }
1672
1673 self.unreserve_channels();
1674 for done in self.reset_done.drain(..) {
1675 done.complete(());
1676 }
1677 }
1678
1679 fn unload_complete(&mut self) {
1680 self.unreserve_channels();
1681 }
1682}
1683
1684impl ServerTaskInner {
1685 fn open_channel(
1686 &mut self,
1687 offer_id: OfferId,
1688 open_params: &OpenParams,
1689 ) -> anyhow::Result<(&mut Channel, Interrupt)> {
1690 let channel = self
1691 .channels
1692 .get_mut(&offer_id)
1693 .expect("channel does not exist");
1694
1695 if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1697 channel_bitmap.register_channel(
1698 open_params.event_flag,
1699 channel.guest_to_host_event.0.clone(),
1700 );
1701 }
1702 let event_port = self
1707 .synic
1708 .add_event_port(
1709 open_params.connection_id,
1710 self.vtl,
1711 channel.guest_to_host_event.clone(),
1712 open_params.monitor_info,
1713 )
1714 .context("failed to create guest-to-host event port")?;
1715
1716 let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
1719 (0, 0)
1720 } else {
1721 (open_params.open_data.target_vp, open_params.event_flag)
1722 };
1723 let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
1724 (self.redirect_vtl, self.redirect_sint)
1725 } else {
1726 (self.vtl, SINT)
1727 };
1728
1729 let guest_event_port = self.synic.new_guest_event_port(
1730 VmbusServer::get_child_event_port_id(open_params.channel_id, SINT, self.vtl),
1731 target_vtl,
1732 target_vp,
1733 target_sint,
1734 event_flag,
1735 open_params.monitor_info,
1736 )?;
1737
1738 let interrupt = ChannelBitmap::create_interrupt(
1739 &self.channel_bitmap,
1740 guest_event_port.interrupt(),
1741 open_params.event_flag,
1742 );
1743
1744 channel.reserved_state.message_port = None;
1746
1747 if let Some(target) = open_params.reserved_target {
1749 channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
1750 self.redirect_vtl,
1751 target.vp,
1752 target.sint,
1753 )?);
1754
1755 channel.reserved_state.target = target;
1756 }
1757
1758 channel.state = ChannelState::Open(Box::new(ChannelOpenState {
1759 open_params: *open_params,
1760 _event_port: event_port,
1761 guest_event_port,
1762 host_to_guest_interrupt: interrupt.clone(),
1763 }));
1764 Ok((channel, interrupt))
1765 }
1766
1767 fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
1770 let interrupt_page = match interrupt_page {
1771 Update::Unchanged => return Ok(()),
1772 Update::Reset => {
1773 self.channel_bitmap = None;
1774 self.shared_event_port = None;
1775 return Ok(());
1776 }
1777 Update::Set(interrupt_page) => interrupt_page,
1778 };
1779
1780 assert_ne!(interrupt_page, 0);
1781
1782 if interrupt_page % PAGE_SIZE as u64 != 0 {
1783 anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
1784 }
1785
1786 let interrupt_page = lock_page_with_subrange(&self.gm, interrupt_page)?;
1789 let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
1790 self.channel_bitmap = Some(channel_bitmap.clone());
1791
1792 let interrupt = Interrupt::from_fn(move || {
1794 channel_bitmap.handle_shared_interrupt();
1795 });
1796
1797 self.shared_event_port = Some(self.synic.add_event_port(
1798 SHARED_EVENT_CONNECTION_ID,
1799 self.vtl,
1800 Arc::new(ChannelEvent(interrupt)),
1801 None,
1802 )?);
1803
1804 Ok(())
1805 }
1806
1807 fn set_monitor_page(&mut self, monitor_page: Update<MonitorPageGpas>) -> anyhow::Result<()> {
1808 let monitor_page = match monitor_page {
1809 Update::Unchanged => return Ok(()),
1810 Update::Reset => None,
1811 Update::Set(value) => Some(value),
1812 };
1813
1814 if self.channels.iter().any(|(_, c)| {
1816 matches!(
1817 &c.state,
1818 ChannelState::Open(state) if state.open_params.monitor_info.is_some()
1819 )
1820 }) {
1821 anyhow::bail!("attempt to change monitor page while open channels using mnf");
1822 }
1823
1824 if self.enable_mnf {
1825 if let Some(monitor) = self.synic.monitor_support() {
1826 if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
1827 anyhow::bail!(
1828 "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
1829 );
1830 }
1831 }
1832 }
1833
1834 Ok(())
1835 }
1836
1837 fn get_reserved_channel_message_port(
1838 &mut self,
1839 offer_id: OfferId,
1840 new_target: ConnectionTarget,
1841 ) -> Option<&mut Box<dyn GuestMessagePort>> {
1842 let channel = self
1843 .channels
1844 .get_mut(&offer_id)
1845 .expect("channel does not exist");
1846
1847 assert!(
1848 channel.reserved_state.message_port.is_some(),
1849 "channel is not reserved"
1850 );
1851
1852 if channel.reserved_state.target.sint != new_target.sint {
1855 channel.reserved_state.message_port = None;
1857 let message_port = self
1858 .synic
1859 .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
1860 .inspect_err(|err| {
1861 tracing::error!(
1862 ?err,
1863 ?self.redirect_vtl,
1864 ?new_target,
1865 "could not create reserved channel message port"
1866 )
1867 })
1868 .ok()?;
1869
1870 channel.reserved_state.message_port = Some(message_port);
1871 channel.reserved_state.target = new_target;
1872 } else if channel.reserved_state.target.vp != new_target.vp {
1873 let message_port = channel.reserved_state.message_port.as_mut().unwrap();
1874
1875 if let Err(err) = message_port.set_target_vp(new_target.vp) {
1878 tracing::error!(
1879 ?err,
1880 ?self.redirect_vtl,
1881 ?new_target,
1882 "could not update reserved channel message port"
1883 );
1884 }
1885
1886 channel.reserved_state.target = new_target;
1887 return Some(message_port);
1888 }
1889
1890 Some(channel.reserved_state.message_port.as_mut().unwrap())
1891 }
1892
1893 fn unreserve_channels(&mut self) {
1894 for channel in self.channels.values_mut() {
1896 if let ChannelState::Closed = channel.state {
1897 channel.reserved_state.message_port = None;
1898 }
1899 }
1900 }
1901
1902 fn get_gm_for_channel(&self, version: VersionInfo, channel: &Channel) -> &GuestMemory {
1903 select_gm_for_channel(&self.gm, self.private_gm.as_ref(), version, channel)
1904 }
1905}
1906
1907fn select_gm_for_channel<'a>(
1908 gm: &'a GuestMemory,
1909 private_gm: Option<&'a GuestMemory>,
1910 version: VersionInfo,
1911 channel: &Channel,
1912) -> &'a GuestMemory {
1913 if channel.flags.confidential_ring_buffer() && version.feature_flags.confidential_channels() {
1914 if let Some(private_gm) = private_gm {
1915 return private_gm;
1916 }
1917 }
1918
1919 gm
1920}
1921
1922#[derive(Clone)]
1924pub struct VmbusServerControl {
1925 mem: GuestMemory,
1926 private_mem: Option<GuestMemory>,
1927 send: mesh::Sender<OfferRequest>,
1928 use_event: bool,
1929 force_confidential_external_memory: bool,
1930}
1931
1932impl VmbusServerControl {
1933 pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
1936 let flags = offer_info.params.flags;
1937 self.send
1938 .call_failable(OfferRequest::Offer, offer_info)
1939 .await?;
1940 Ok(OfferResources::new(
1941 self.mem.clone(),
1942 if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
1943 self.private_mem.clone()
1944 } else {
1945 None
1946 },
1947 ))
1948 }
1949
1950 pub async fn force_reset(&self) -> anyhow::Result<()> {
1953 self.send
1954 .call(OfferRequest::ForceReset, ())
1955 .await
1956 .context("vmbus server is gone")
1957 }
1958
1959 async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
1960 let mut offer_info = OfferInfo {
1961 params: request.params.into(),
1962 event: request.event,
1963 request_send: request.request_send,
1964 server_request_recv: request.server_request_recv,
1965 };
1966
1967 if self.force_confidential_external_memory {
1968 tracing::warn!(
1969 key = %offer_info.params.key(),
1970 "forcing confidential external memory for channel"
1971 );
1972
1973 offer_info
1974 .params
1975 .flags
1976 .set_confidential_external_memory(true);
1977 }
1978
1979 self.offer_core(offer_info).await
1980 }
1981}
1982
1983fn inspect_rings(
1985 resp: &mut inspect::Response<'_>,
1986 gm: &GuestMemory,
1987 gpadl_map: Arc<GpadlMap>,
1988 open_data: &OpenData,
1989) -> Option<()> {
1990 let gpadl = gpadl_map
1991 .view()
1992 .map(GpadlId(open_data.ring_gpadl_id.0))
1993 .ok()?;
1994
1995 let aligned = AlignedGpadlView::new(gpadl).ok()?;
1996 let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
1997 resp.child("incoming_ring", |req| inspect_ring(req, &in_gpadl, gm));
1998 resp.child("outgoing_ring", |req| inspect_ring(req, &out_gpadl, gm));
1999 Some(())
2000}
2001
2002fn inspect_ring(req: inspect::Request<'_>, gpadl: &AlignedGpadlView, gm: &GuestMemory) {
2004 let mut resp = req.respond();
2005
2006 resp.hex("ring_size", gpadl_ring_size(gpadl));
2007
2008 if let Ok(pages) = lock_gpn_with_subrange(gm, gpadl.gpns()[0]) {
2011 ring::inspect_ring(pages.pages()[0], &mut resp);
2012 }
2013}
2014
2015fn gpadl_ring_size(gpadl: &AlignedGpadlView) -> usize {
2016 (gpadl.gpns().len() - 1) * PAGE_SIZE
2018}
2019
2020fn lock_page_with_subrange(gm: &GuestMemory, offset: u64) -> anyhow::Result<guestmem::LockedPages> {
2025 Ok(gm
2026 .lockable_subrange(offset, PAGE_SIZE as u64)?
2027 .lock_gpns(false, &[0])?)
2028}
2029
2030fn lock_gpn_with_subrange(gm: &GuestMemory, gpn: u64) -> anyhow::Result<guestmem::LockedPages> {
2035 lock_page_with_subrange(gm, gpn * PAGE_SIZE as u64)
2036}
2037
2038pub(crate) struct MessageSender {
2039 send: mpsc::Sender<SynicMessage>,
2040 multiclient: bool,
2041}
2042
2043impl MessageSender {
2044 fn poll_handle_message(
2045 &self,
2046 cx: &mut std::task::Context<'_>,
2047 msg: &[u8],
2048 trusted: bool,
2049 ) -> Poll<Result<(), SendError>> {
2050 let mut send = self.send.clone();
2051 ready!(send.poll_ready(cx))?;
2052 send.start_send(SynicMessage {
2053 data: msg.to_vec(),
2054 multiclient: self.multiclient,
2055 trusted,
2056 })?;
2057
2058 Poll::Ready(Ok(()))
2059 }
2060}
2061
2062impl MessagePort for MessageSender {
2063 fn poll_handle_message(
2064 &self,
2065 cx: &mut std::task::Context<'_>,
2066 msg: &[u8],
2067 trusted: bool,
2068 ) -> Poll<()> {
2069 if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
2070 tracelimit::error_ratelimited!(
2071 error = &err as &dyn std::error::Error,
2072 "failed to send message"
2073 );
2074 }
2075
2076 Poll::Ready(())
2077 }
2078}
2079
2080#[async_trait]
2081impl ParentBus for VmbusServerControl {
2082 async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
2083 self.offer(request).await
2084 }
2085
2086 fn clone_bus(&self) -> Box<dyn ParentBus> {
2087 Box::new(self.clone())
2088 }
2089
2090 fn use_event(&self) -> bool {
2091 self.use_event
2092 }
2093}
2094
2095#[cfg(test)]
2096mod tests {
2097 use super::*;
2098 use inspect::InspectMut;
2099 use mesh::CancelReason;
2100 use pal_async::DefaultDriver;
2101 use pal_async::async_test;
2102 use pal_async::timer::Instant;
2103 use pal_async::timer::PolledTimer;
2104 use parking_lot::Mutex;
2105 use protocol::UserDefinedData;
2106 use std::time::Duration;
2107 use test_with_tracing::test;
2108 use vmbus_channel::bus::OfferParams;
2109 use vmbus_channel::channel::ChannelOpenError;
2110 use vmbus_channel::channel::DeviceResources;
2111 use vmbus_channel::channel::SaveRestoreVmbusDevice;
2112 use vmbus_channel::channel::VmbusDevice;
2113 use vmbus_channel::channel::offer_channel;
2114 use vmbus_core::protocol::ChannelId;
2115 use vmbus_core::protocol::VmbusMessage;
2116 use vmcore::synic::MonitorInfo;
2117 use vmcore::synic::SynicPortAccess;
2118 use zerocopy::FromBytes;
2119 use zerocopy::Immutable;
2120 use zerocopy::IntoBytes;
2121 use zerocopy::KnownLayout;
2122
2123 struct MockSynicInner {
2124 message_port: Option<Arc<dyn MessagePort>>,
2125 }
2126
2127 struct MockSynic {
2128 inner: Mutex<MockSynicInner>,
2129 message_send: mesh::Sender<Vec<u8>>,
2130 spawner: DefaultDriver,
2131 }
2132
2133 impl MockSynic {
2134 fn new(message_send: mesh::Sender<Vec<u8>>, spawner: DefaultDriver) -> Self {
2135 Self {
2136 inner: Mutex::new(MockSynicInner { message_port: None }),
2137 message_send,
2138 spawner,
2139 }
2140 }
2141
2142 fn send_message(&self, msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout) {
2143 self.send_message_core(OutgoingMessage::new(&msg), false);
2144 }
2145
2146 fn send_message_trusted(
2147 &self,
2148 msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout,
2149 ) {
2150 self.send_message_core(OutgoingMessage::new(&msg), true);
2151 }
2152
2153 fn send_message_core(&self, msg: OutgoingMessage, trusted: bool) {
2154 assert_eq!(
2155 self.inner
2156 .lock()
2157 .message_port
2158 .as_ref()
2159 .unwrap()
2160 .poll_handle_message(
2161 &mut std::task::Context::from_waker(std::task::Waker::noop()),
2162 msg.data(),
2163 trusted,
2164 ),
2165 Poll::Ready(())
2166 );
2167 }
2168 }
2169
2170 #[derive(Debug)]
2171 struct MockGuestPort {}
2172
2173 impl GuestEventPort for MockGuestPort {
2174 fn interrupt(&self) -> Interrupt {
2175 Interrupt::null()
2176 }
2177
2178 fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
2179 Ok(())
2180 }
2181 }
2182
2183 struct MockGuestMessagePort {
2184 send: mesh::Sender<Vec<u8>>,
2185 spawner: DefaultDriver,
2186 timer: Option<(PolledTimer, Instant)>,
2187 }
2188
2189 impl GuestMessagePort for MockGuestMessagePort {
2190 fn poll_post_message(
2191 &mut self,
2192 cx: &mut std::task::Context<'_>,
2193 _typ: u32,
2194 payload: &[u8],
2195 ) -> Poll<()> {
2196 if let Some((timer, deadline)) = self.timer.as_mut() {
2197 ready!(timer.sleep_until(*deadline).poll_unpin(cx));
2198 self.timer = None;
2199 }
2200
2201 let mut pending_chance = [0; 1];
2203 getrandom::fill(&mut pending_chance).unwrap();
2204 if pending_chance[0] % 4 == 0 {
2205 let mut timer = PolledTimer::new(&self.spawner);
2206 let deadline = Instant::now() + Duration::from_millis(10);
2207 match timer.sleep_until(deadline).poll_unpin(cx) {
2208 Poll::Ready(_) => {}
2209 Poll::Pending => {
2210 self.timer = Some((timer, deadline));
2211 return Poll::Pending;
2212 }
2213 }
2214 }
2215
2216 self.send.send(payload.into());
2217 Poll::Ready(())
2218 }
2219
2220 fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
2221 Ok(())
2222 }
2223 }
2224
2225 impl Inspect for MockGuestMessagePort {
2226 fn inspect(&self, _req: inspect::Request<'_>) {}
2227 }
2228
2229 impl SynicPortAccess for MockSynic {
2230 fn add_message_port(
2231 &self,
2232 connection_id: u32,
2233 _minimum_vtl: Vtl,
2234 port: Arc<dyn MessagePort>,
2235 ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
2236 self.inner.lock().message_port = Some(port);
2237 Ok(Box::new(connection_id))
2238 }
2239
2240 fn add_event_port(
2241 &self,
2242 connection_id: u32,
2243 _minimum_vtl: Vtl,
2244 _port: Arc<dyn EventPort>,
2245 _monitor_info: Option<MonitorInfo>,
2246 ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
2247 Ok(Box::new(connection_id))
2248 }
2249
2250 fn new_guest_message_port(
2251 &self,
2252 _vtl: Vtl,
2253 _vp: u32,
2254 _sint: u8,
2255 ) -> Result<Box<(dyn GuestMessagePort)>, vmcore::synic::HypervisorError> {
2256 Ok(Box::new(MockGuestMessagePort {
2257 send: self.message_send.clone(),
2258 spawner: self.spawner.clone(),
2259 timer: None,
2260 }))
2261 }
2262
2263 fn new_guest_event_port(
2264 &self,
2265 _port_id: u32,
2266 _vtl: Vtl,
2267 _vp: u32,
2268 _sint: u8,
2269 _flag: u16,
2270 _monitor_info: Option<MonitorInfo>,
2271 ) -> Result<Box<(dyn GuestEventPort)>, vmcore::synic::HypervisorError> {
2272 Ok(Box::new(MockGuestPort {}))
2273 }
2274
2275 fn prefer_os_events(&self) -> bool {
2276 false
2277 }
2278 }
2279
2280 struct TestChannel {
2281 request_recv: mesh::Receiver<ChannelRequest>,
2282 server_request_send: mesh::Sender<ChannelServerRequest>,
2283 _resources: OfferResources,
2284 }
2285
2286 impl TestChannel {
2287 async fn next_request(&mut self) -> ChannelRequest {
2288 self.request_recv.next().await.unwrap()
2289 }
2290
2291 async fn handle_gpadl(&mut self) {
2292 let ChannelRequest::Gpadl(rpc) = self.next_request().await else {
2293 panic!("Wrong request");
2294 };
2295
2296 rpc.complete(true);
2297 }
2298
2299 async fn handle_open(&mut self, f: fn(&OpenRequest)) {
2300 let ChannelRequest::Open(rpc) = self.next_request().await else {
2301 panic!("Wrong request");
2302 };
2303
2304 f(rpc.input());
2305 rpc.complete(true);
2306 }
2307
2308 async fn handle_gpadl_teardown(&mut self) {
2309 let rpc = self.get_gpadl_teardown().await;
2310 rpc.complete(());
2311 }
2312
2313 async fn get_gpadl_teardown(&mut self) -> Rpc<GpadlId, ()> {
2314 let ChannelRequest::TeardownGpadl(rpc) = self.next_request().await else {
2315 panic!("Wrong request");
2316 };
2317
2318 rpc
2319 }
2320
2321 async fn restore(&self) {
2322 self.server_request_send
2323 .call(ChannelServerRequest::Restore, false)
2324 .await
2325 .unwrap()
2326 .unwrap();
2327 }
2328 }
2329
2330 struct TestEnv {
2331 vmbus: VmbusServer,
2332 synic: Arc<MockSynic>,
2333 message_recv: mesh::Receiver<Vec<u8>>,
2334 trusted: bool,
2335 }
2336
2337 impl TestEnv {
2338 fn new(spawner: DefaultDriver) -> Self {
2339 let (message_send, message_recv) = mesh::channel();
2340 let synic = Arc::new(MockSynic::new(message_send, spawner.clone()));
2341 let gm = GuestMemory::empty();
2342 let vmbus = VmbusServerBuilder::new(spawner, synic.clone(), gm)
2343 .build()
2344 .unwrap();
2345
2346 Self {
2347 vmbus,
2348 synic,
2349 message_recv,
2350 trusted: false,
2351 }
2352 }
2353
2354 async fn offer(&self, id: u32, allow_confidential_external_memory: bool) -> TestChannel {
2355 let guid = Guid {
2356 data1: id,
2357 ..Guid::ZERO
2358 };
2359 let (request_send, request_recv) = mesh::channel();
2360 let (server_request_send, server_request_recv) = mesh::channel();
2361 let offer = OfferInput {
2362 event: Interrupt::from_fn(|| {}),
2363 request_send,
2364 server_request_recv,
2365 params: OfferParams {
2366 interface_name: "test".into(),
2367 instance_id: guid,
2368 interface_id: guid,
2369 mmio_megabytes: 0,
2370 mmio_megabytes_optional: 0,
2371 channel_type: vmbus_channel::bus::ChannelType::Device {
2372 pipe_packets: false,
2373 },
2374 subchannel_index: 0,
2375 mnf_interrupt_latency: None,
2376 offer_order: None,
2377 allow_confidential_external_memory,
2378 },
2379 };
2380
2381 let control = self.vmbus.control();
2382 let _resources = control.add_child(offer).await.unwrap();
2383
2384 TestChannel {
2385 request_recv,
2386 server_request_send,
2387 _resources,
2388 }
2389 }
2390
2391 async fn gpadl(&mut self, channel_id: u32, gpadl_id: u32, channel: &mut TestChannel) {
2392 self.synic.send_message_core(
2393 OutgoingMessage::with_data(
2394 &protocol::GpadlHeader {
2395 channel_id: ChannelId(channel_id),
2396 gpadl_id: GpadlId(gpadl_id),
2397 count: 1,
2398 len: 16,
2399 },
2400 [1u64, 0u64].as_bytes(),
2401 ),
2402 self.trusted,
2403 );
2404
2405 channel.handle_gpadl().await;
2406 self.expect_response(protocol::MessageType::GPADL_CREATED)
2407 .await;
2408 }
2409
2410 async fn open_channel(
2411 &mut self,
2412 channel_id: u32,
2413 ring_gpadl_id: u32,
2414 channel: &mut TestChannel,
2415 f: fn(&OpenRequest),
2416 ) {
2417 self.gpadl(channel_id, ring_gpadl_id, channel).await;
2418 self.synic.send_message_core(
2419 OutgoingMessage::new(&protocol::OpenChannel {
2420 channel_id: ChannelId(channel_id),
2421 open_id: 0,
2422 ring_buffer_gpadl_id: GpadlId(ring_gpadl_id),
2423 target_vp: 0,
2424 downstream_ring_buffer_page_offset: 0,
2425 user_data: UserDefinedData::default(),
2426 }),
2427 self.trusted,
2428 );
2429
2430 channel.handle_open(f).await;
2431 self.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT)
2432 .await;
2433 }
2434
2435 async fn expect_response(&mut self, expected: protocol::MessageType) {
2436 let data = self.message_recv.next().await.unwrap();
2437 let header = protocol::MessageHeader::read_from_prefix(&data).unwrap().0; assert_eq!(expected, header.message_type())
2439 }
2440
2441 async fn get_response<T: VmbusMessage + FromBytes + Immutable + KnownLayout>(
2442 &mut self,
2443 ) -> T {
2444 let data = self.message_recv.next().await.unwrap();
2445 let (header, message) = protocol::MessageHeader::read_from_prefix(&data).unwrap(); assert_eq!(T::MESSAGE_TYPE, header.message_type());
2447 T::read_from_prefix(message).unwrap().0 }
2449
2450 fn initiate_contact(
2451 &mut self,
2452 version: protocol::Version,
2453 feature_flags: protocol::FeatureFlags,
2454 trusted: bool,
2455 ) {
2456 self.synic.send_message_core(
2457 OutgoingMessage::new(&protocol::InitiateContact {
2458 version_requested: version as u32,
2459 target_message_vp: 0,
2460 child_to_parent_monitor_page_gpa: 0,
2461 parent_to_child_monitor_page_gpa: 0,
2462 interrupt_page_or_target_info: protocol::TargetInfo::new()
2463 .with_sint(2)
2464 .with_vtl(0)
2465 .with_feature_flags(feature_flags.into())
2466 .into(),
2467 }),
2468 trusted,
2469 );
2470
2471 self.trusted = trusted;
2472 }
2473
2474 async fn connect(
2475 &mut self,
2476 offer_count: u32,
2477 feature_flags: protocol::FeatureFlags,
2478 trusted: bool,
2479 ) {
2480 self.initiate_contact(protocol::Version::Copper, feature_flags, trusted);
2481
2482 self.expect_response(protocol::MessageType::VERSION_RESPONSE)
2483 .await;
2484
2485 self.synic
2486 .send_message_core(OutgoingMessage::new(&protocol::RequestOffers {}), trusted);
2487
2488 for _ in 0..offer_count {
2489 self.expect_response(protocol::MessageType::OFFER_CHANNEL)
2490 .await;
2491 }
2492
2493 self.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
2494 .await;
2495 }
2496 }
2497
2498 #[async_test]
2499 async fn test_save_restore(spawner: DefaultDriver) {
2500 let mut env = TestEnv::new(spawner);
2505 let mut channel = env.offer(1, false).await;
2506 env.vmbus.start();
2507 env.connect(1, protocol::FeatureFlags::new(), false).await;
2508
2509 env.gpadl(1, 10, &mut channel).await;
2511
2512 env.synic.send_message(protocol::GpadlTeardown {
2514 channel_id: ChannelId(1),
2515 gpadl_id: GpadlId(10),
2516 });
2517
2518 let rpc = channel.get_gpadl_teardown().await;
2521 env.vmbus.stop().await;
2522 let saved_state = env.vmbus.save().await;
2523 env.vmbus.start();
2524
2525 rpc.complete(());
2527 env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
2528 .await;
2529
2530 env.synic.send_message(protocol::RelIdReleased {
2531 channel_id: ChannelId(1),
2532 });
2533
2534 env.vmbus.reset().await;
2535 env.vmbus.stop().await;
2536
2537 env.vmbus.restore(saved_state).await.unwrap();
2540 channel.restore().await;
2541 env.vmbus.start();
2542
2543 channel.handle_gpadl_teardown().await;
2545 env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
2546 .await;
2547
2548 env.synic.send_message(protocol::RelIdReleased {
2549 channel_id: ChannelId(1),
2550 });
2551 }
2552
2553 struct TestDeviceState {
2554 id: u32,
2555 started: bool,
2556 resources: Option<DeviceResources>,
2557 open_requests: HashMap<u16, OpenRequest>,
2558 target_vps: HashMap<u16, u32>,
2559 }
2560
2561 impl TestDeviceState {
2562 pub fn id(this: &Arc<Mutex<Self>>) -> u32 {
2563 this.lock().id
2564 }
2565
2566 pub fn started(this: &Arc<Mutex<Self>>) -> bool {
2567 this.lock().started
2568 }
2569 pub fn set_started(this: &Arc<Mutex<Self>>, started: bool) {
2570 this.lock().started = started;
2571 }
2572
2573 pub fn open_request(this: &Arc<Mutex<Self>>, channel_idx: u16) -> Option<OpenRequest> {
2574 this.lock().open_requests.get(&channel_idx).cloned()
2575 }
2576 pub fn set_open_request(
2577 this: &Arc<Mutex<Self>>,
2578 channel_idx: u16,
2579 open_request: OpenRequest,
2580 ) {
2581 assert!(
2582 this.lock()
2583 .open_requests
2584 .insert(channel_idx, open_request)
2585 .is_none()
2586 );
2587 }
2588 pub fn remove_open_request(
2589 this: &Arc<Mutex<Self>>,
2590 channel_idx: u16,
2591 ) -> Option<OpenRequest> {
2592 this.lock().open_requests.remove(&channel_idx)
2593 }
2594
2595 pub fn target_vp(this: &Arc<Mutex<Self>>, channel_idx: u16) -> Option<u32> {
2596 this.lock().target_vps.get(&channel_idx).copied()
2597 }
2598 pub fn set_target_vp(this: &Arc<Mutex<Self>>, channel_idx: u16, target_vp: u32) {
2599 let _ = this.lock().target_vps.insert(channel_idx, target_vp);
2600 }
2601 }
2602
2603 #[derive(InspectMut)]
2604 struct TestDevice {
2605 #[inspect(skip)]
2606 pub state: Arc<Mutex<TestDeviceState>>,
2607 }
2608
2609 impl TestDevice {
2610 pub fn new_and_state(id: u32) -> (Self, Arc<Mutex<TestDeviceState>>) {
2611 let state = TestDeviceState {
2612 id,
2613 resources: None,
2614 open_requests: HashMap::new(),
2615 target_vps: HashMap::new(),
2616 started: false,
2617 };
2618 let state = Arc::new(Mutex::new(state));
2619 let this = Self {
2620 state: state.clone(),
2621 };
2622 (this, state)
2623 }
2624 }
2625
2626 #[async_trait]
2627 impl VmbusDevice for TestDevice {
2628 fn offer(&self) -> OfferParams {
2629 let guid = Guid {
2630 data1: TestDeviceState::id(&self.state),
2631 ..Guid::ZERO
2632 };
2633
2634 OfferParams {
2635 interface_name: "test".into(),
2636 instance_id: guid,
2637 interface_id: guid,
2638 channel_type: vmbus_channel::bus::ChannelType::Device {
2639 pipe_packets: false,
2640 },
2641 ..Default::default()
2642 }
2643 }
2644
2645 fn max_subchannels(&self) -> u16 {
2646 0
2647 }
2648
2649 fn install(&mut self, resources: DeviceResources) {
2650 self.state.lock().resources = Some(resources);
2651 }
2652
2653 async fn open(
2654 &mut self,
2655 channel_idx: u16,
2656 open_request: &OpenRequest,
2657 ) -> Result<(), ChannelOpenError> {
2658 tracing::info!("OPEN");
2659 TestDeviceState::set_open_request(&self.state, channel_idx, open_request.clone());
2660 Ok(())
2661 }
2662
2663 async fn close(&mut self, channel_idx: u16) {
2664 tracing::info!("CLOSE");
2665 assert!(TestDeviceState::remove_open_request(&self.state, channel_idx).is_some());
2666 }
2667
2668 async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
2669 TestDeviceState::set_target_vp(&self.state, channel_idx, target_vp);
2670 }
2671
2672 fn start(&mut self) {
2673 tracing::info!("START");
2674 TestDeviceState::set_started(&self.state, true);
2675 }
2676
2677 async fn stop(&mut self) {
2678 tracing::info!("STOP");
2679 TestDeviceState::set_started(&self.state, false);
2680 }
2681
2682 fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
2683 None
2684 }
2685 }
2686
2687 #[async_test]
2688 async fn test_stopped_child(spawner: DefaultDriver) {
2689 let mut env = TestEnv::new(spawner.clone());
2693 let (test_device, test_device_state) = TestDevice::new_and_state(1);
2694 let control = env.vmbus.control();
2695 let channel = offer_channel(&spawner, control.as_ref(), test_device)
2696 .await
2697 .expect("test device failed to offer");
2698
2699 env.vmbus.start();
2700 env.connect(1, protocol::FeatureFlags::new(), false).await;
2701
2702 channel.stop().await;
2704
2705 assert_eq!(TestDeviceState::started(&test_device_state), false);
2706
2707 env.synic.send_message_core(
2710 OutgoingMessage::with_data(
2711 &protocol::GpadlHeader {
2712 channel_id: ChannelId(1),
2713 gpadl_id: GpadlId(1),
2714 count: 1,
2715 len: 16,
2716 },
2717 [1u64, 0u64].as_bytes(),
2718 ),
2719 false,
2720 );
2721 env.expect_response(protocol::MessageType::GPADL_CREATED)
2722 .await;
2723
2724 env.synic.send_message_core(
2726 OutgoingMessage::new(&protocol::OpenChannel {
2727 channel_id: ChannelId(1),
2728 open_id: 0,
2729 ring_buffer_gpadl_id: GpadlId(1),
2730 target_vp: 0,
2731 downstream_ring_buffer_page_offset: 0,
2732 user_data: UserDefinedData::default(),
2733 }),
2734 false,
2735 );
2736 let wait_for_response = mesh::CancelContext::new()
2737 .with_timeout(Duration::from_millis(150))
2738 .until_cancelled(env.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT))
2739 .await;
2740 assert!(matches!(
2741 wait_for_response,
2742 Err(CancelReason::DeadlineExceeded)
2743 ));
2744 assert!(TestDeviceState::open_request(&test_device_state, 0).is_none());
2745
2746 channel.start();
2748 env.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT)
2749 .await;
2750 assert!(TestDeviceState::open_request(&test_device_state, 0).is_some());
2751
2752 assert!(TestDeviceState::target_vp(&test_device_state, 0).is_none());
2754 channel.stop().await;
2755 env.synic.send_message_core(
2756 OutgoingMessage::new(&protocol::ModifyChannel {
2757 channel_id: ChannelId(1),
2758 target_vp: 2,
2759 }),
2760 false,
2761 );
2762 let wait_for_response = mesh::CancelContext::new()
2763 .with_timeout(Duration::from_millis(150))
2764 .until_cancelled(env.expect_response(protocol::MessageType::MODIFY_CHANNEL_RESPONSE))
2765 .await;
2766 assert!(matches!(
2767 wait_for_response,
2768 Err(CancelReason::DeadlineExceeded)
2769 ));
2770
2771 channel.start();
2773 env.expect_response(protocol::MessageType::MODIFY_CHANNEL_RESPONSE)
2774 .await;
2775 assert_eq!(
2776 TestDeviceState::target_vp(&test_device_state, 0)
2777 .expect("Modify channel request received"),
2778 2
2779 );
2780
2781 channel.stop().await;
2785 env.vmbus.reset().await;
2786 assert!(TestDeviceState::open_request(&test_device_state, 0).is_none());
2787
2788 env.vmbus.stop().await;
2789 }
2790
2791 #[async_test]
2792 async fn test_confidential_connection(spawner: DefaultDriver) {
2793 let mut env = TestEnv::new(spawner);
2794 let mut channel = env.offer(1, false).await;
2796 let mut channel2 = env.offer(2, true).await;
2797
2798 let (request_send, request_recv) = mesh::channel();
2800 let (server_request_send, server_request_recv) = mesh::channel();
2801 let id = Guid {
2802 data1: 3,
2803 ..Guid::ZERO
2804 };
2805 let control = env.vmbus.control();
2806 let relay_resources = control
2807 .offer_core(OfferInfo {
2808 params: OfferParamsInternal {
2809 interface_name: "test".into(),
2810 instance_id: id,
2811 interface_id: id,
2812 mmio_megabytes: 0,
2813 mmio_megabytes_optional: 0,
2814 subchannel_index: 0,
2815 use_mnf: MnfUsage::Disabled,
2816 offer_order: None,
2817 flags: protocol::OfferFlags::new().with_enumerate_device_interface(true),
2818 ..Default::default()
2819 },
2820 event: Interrupt::from_fn(|| {}),
2821 request_send,
2822 server_request_recv,
2823 })
2824 .await
2825 .unwrap();
2826
2827 let mut relay_channel = TestChannel {
2828 request_recv,
2829 server_request_send,
2830 _resources: relay_resources,
2831 };
2832
2833 env.vmbus.start();
2834 env.initiate_contact(
2835 protocol::Version::Copper,
2836 protocol::FeatureFlags::new().with_confidential_channels(true),
2837 true,
2838 );
2839
2840 env.expect_response(protocol::MessageType::VERSION_RESPONSE)
2841 .await;
2842
2843 env.synic.send_message_trusted(protocol::RequestOffers {});
2844
2845 let offer = env.get_response::<protocol::OfferChannel>().await;
2847 assert!(offer.flags.confidential_ring_buffer());
2848 assert!(!offer.flags.confidential_external_memory());
2849 let offer = env.get_response::<protocol::OfferChannel>().await;
2850 assert!(offer.flags.confidential_ring_buffer());
2851 assert!(offer.flags.confidential_external_memory());
2852
2853 let offer = env.get_response::<protocol::OfferChannel>().await;
2855 assert!(!offer.flags.confidential_ring_buffer());
2856 assert!(!offer.flags.confidential_external_memory());
2857
2858 env.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
2859 .await;
2860
2861 env.open_channel(1, 1, &mut channel, |request| {
2864 assert!(request.use_confidential_ring);
2865 assert!(!request.use_confidential_external_memory);
2866 })
2867 .await;
2868
2869 env.open_channel(2, 2, &mut channel2, |request| {
2870 assert!(request.use_confidential_ring);
2871 assert!(request.use_confidential_external_memory);
2872 })
2873 .await;
2874
2875 env.open_channel(3, 3, &mut relay_channel, |request| {
2876 assert!(!request.use_confidential_ring);
2877 assert!(!request.use_confidential_external_memory);
2878 })
2879 .await;
2880 }
2881
2882 #[async_test]
2883 async fn test_confidential_channels_unsupported(spawner: DefaultDriver) {
2884 let mut env = TestEnv::new(spawner);
2885 let mut channel = env.offer(1, false).await;
2886 let mut channel2 = env.offer(2, true).await;
2887
2888 env.vmbus.start();
2889 env.connect(2, protocol::FeatureFlags::new(), true).await;
2890
2891 env.open_channel(1, 1, &mut channel, |request| {
2894 assert!(!request.use_confidential_ring);
2895 assert!(!request.use_confidential_external_memory);
2896 })
2897 .await;
2898
2899 env.open_channel(2, 2, &mut channel2, |request| {
2900 assert!(!request.use_confidential_ring);
2901 assert!(!request.use_confidential_external_memory);
2902 })
2903 .await;
2904 }
2905
2906 #[async_test]
2907 async fn test_confidential_channels_untrusted(spawner: DefaultDriver) {
2908 let mut env = TestEnv::new(spawner);
2909 let mut channel = env.offer(1, false).await;
2910 let mut channel2 = env.offer(2, true).await;
2911
2912 env.vmbus.start();
2913 env.connect(
2916 2,
2917 protocol::FeatureFlags::new().with_confidential_channels(true),
2918 false,
2919 )
2920 .await;
2921
2922 env.open_channel(1, 1, &mut channel, |request| {
2925 assert!(!request.use_confidential_ring);
2926 assert!(!request.use_confidential_external_memory);
2927 })
2928 .await;
2929
2930 env.open_channel(2, 2, &mut channel2, |request| {
2931 assert!(!request.use_confidential_ring);
2932 assert!(!request.use_confidential_external_memory);
2933 })
2934 .await;
2935 }
2936}