1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8pub mod channels;
9pub mod event;
10pub mod hvsock;
11mod monitor;
12mod proxyintegration;
13
/// Re-export of [`guid::Guid`] so users of this crate can name the GUID type directly.
pub type Guid = guid::Guid;
16
17use anyhow::Context;
18use async_trait::async_trait;
19use channel_bitmap::ChannelBitmap;
20use channels::ConnectionTarget;
21pub use channels::InitiateContactRequest;
22use channels::MessageTarget;
23pub use channels::MnfUsage;
24use channels::ModifyConnectionRequest;
25pub use channels::ModifyConnectionResponse;
26use channels::Notifier;
27use channels::OfferId;
28pub use channels::OfferParamsInternal;
29use channels::OpenParams;
30use channels::RestoreError;
31pub use channels::Update;
32use futures::FutureExt;
33use futures::StreamExt;
34use futures::channel::mpsc;
35use futures::channel::mpsc::SendError;
36use futures::future::OptionFuture;
37use futures::future::poll_fn;
38use futures::stream::SelectAll;
39use guestmem::GuestMemory;
40use hvdef::Vtl;
41use inspect::Inspect;
42use mesh::payload::Protobuf;
43use mesh::rpc::FailableRpc;
44use mesh::rpc::Rpc;
45use mesh::rpc::RpcError;
46use mesh::rpc::RpcSend;
47use pal_async::task::Spawn;
48use pal_async::task::Task;
49use pal_event::Event;
50#[cfg(windows)]
51pub use proxyintegration::ProxyIntegration;
52#[cfg(windows)]
53pub use proxyintegration::ProxyServerInfo;
54use ring::PAGE_SIZE;
55use std::collections::HashMap;
56use std::future;
57use std::future::Future;
58use std::pin::Pin;
59use std::sync::Arc;
60use std::task::Poll;
61use std::task::ready;
62use std::time::Duration;
63use unicycle::FuturesUnordered;
64use vmbus_channel::bus::ChannelRequest;
65use vmbus_channel::bus::ChannelServerRequest;
66use vmbus_channel::bus::GpadlRequest;
67use vmbus_channel::bus::ModifyRequest;
68use vmbus_channel::bus::OfferInput;
69use vmbus_channel::bus::OfferKey;
70use vmbus_channel::bus::OfferResources;
71use vmbus_channel::bus::OpenData;
72use vmbus_channel::bus::OpenRequest;
73use vmbus_channel::bus::OpenResult;
74use vmbus_channel::bus::ParentBus;
75use vmbus_channel::bus::RestoreResult;
76use vmbus_channel::gpadl::GpadlMap;
77use vmbus_channel::gpadl_ring::AlignedGpadlView;
78use vmbus_channel::gpadl_ring::GpadlRingMem;
79use vmbus_core::HvsockConnectRequest;
80use vmbus_core::HvsockConnectResult;
81use vmbus_core::MaxVersionInfo;
82use vmbus_core::OutgoingMessage;
83use vmbus_core::TaggedStream;
84use vmbus_core::VersionInfo;
85use vmbus_core::protocol;
86pub use vmbus_core::protocol::GpadlId;
87#[cfg(windows)]
88use vmbus_proxy::ProxyHandle;
89use vmbus_ring as ring;
90use vmbus_ring::gparange::MultiPagedRangeBuf;
91use vmcore::interrupt::Interrupt;
92use vmcore::save_restore::SavedStateRoot;
93use vmcore::synic::EventPort;
94use vmcore::synic::GuestEventPort;
95use vmcore::synic::GuestMessagePort;
96use vmcore::synic::MessagePort;
97use vmcore::synic::MonitorPageGpas;
98use vmcore::synic::SynicPortAccess;
99
/// SINT used for vmbus messages when message redirection is not enabled.
const SINT: u8 = 2;
/// SINT used for vmbus messages when message redirection is enabled.
pub const REDIRECT_SINT: u8 = 7;
/// VTL whose synic receives vmbus messages when message redirection is enabled.
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
/// Connection ID for the shared event port.
// NOTE(review): not referenced in this chunk; presumably used when `shared_event_port` is set up.
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
/// Base port ID that `VmbusServer::get_child_event_port_id` combines with the channel ID, SINT,
/// and VTL.
const EVENT_PORT_ID: u32 = 2;
/// Synic message type used when posting vmbus messages to the guest.
const VMBUS_MESSAGE_TYPE: u32 = 1;

/// Maximum number of outstanding hvsock connect requests. While this many are pending, the
/// server task stops reading new synic messages.
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
108
/// Handle to a vmbus server.
///
/// All server state lives in a spawned task. This handle communicates with that task through
/// `task_send`.
pub struct VmbusServer {
    /// Sends reset/inspect/save/restore/start/stop requests to the server task.
    task_send: mesh::Sender<VmbusRequest>,
    /// Shared control object returned by `VmbusServer::control`.
    control: Arc<VmbusServerControl>,
    /// Held only to keep the registered synic message port alive (presumably dropping it
    /// unregisters the port).
    _message_port: Box<dyn Sync + Send>,
    /// Held only to keep the multiclient synic message port alive, if one was registered.
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    /// The spawned server task; awaited by `VmbusServer::shutdown`.
    task: Task<ServerTask>,
}
116
/// Builder for a [`VmbusServer`]; created with `VmbusServer::builder`.
pub struct VmbusServerBuilder<'a, T: Spawn> {
    /// Used to spawn the server task in `build`.
    spawner: &'a T,
    /// Used to register synic message ports and create the guest message port.
    synic: Arc<dyn SynicPortAccess>,
    /// Guest memory handed to the control object and the server task.
    gm: GuestMemory,
    /// Optional private guest memory handed to the control object and the server task.
    private_gm: Option<GuestMemory>,
    /// VTL the server runs for (defaults to VTL0).
    vtl: Vtl,
    /// Optional relay for hvsock connect requests.
    hvsock_notify: Option<HvsockServerChannelHalf>,
    /// Optional relay for modify-connection requests.
    server_relay: Option<VmbusServerChannelHalf>,
    /// Optional receiver of saved-state set/clear notifications.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Optional sender stored in the server task as `external_server_send`.
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    /// Optional source of externally originated `InitiateContactRequest`s.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Whether vmbus messages use `REDIRECT_VTL`/`REDIRECT_SINT`.
    use_message_redirect: bool,
    /// Offset applied to channel IDs (0 or 1024).
    channel_id_offset: u16,
    /// Optional protocol version cap.
    max_version: Option<MaxVersionInfo>,
    /// Passed along with `max_version` to `set_compatibility_version`.
    delay_max_version: bool,
    /// Whether offers may use MNF (subject to synic monitor support).
    enable_mnf: bool,
    /// Passed through to [`VmbusServerControl`].
    force_confidential_external_memory: bool,
    /// Stored in the server task state; its use is outside this chunk.
    send_messages_while_stopped: bool,
}
136
/// Notifications sent to the saved-state listener configured with
/// `VmbusServerBuilder::saved_state_notify`.
#[derive(mesh::MeshPayload)]
pub enum SavedStateRequest {
    /// Provides the server's saved channel state; sent while restoring.
    Set(FailableRpc<Box<channels::SavedState>, ()>),
    /// Clears previously provided saved state; sent when the server starts.
    Clear(Rpc<(), ()>),
}
143
/// The vmbus-server end of a [`RelayChannel`]: sends requests to the relay and receives its
/// responses.
pub struct ServerChannelHalf<Request, Response> {
    request_send: mesh::Sender<Request>,
    response_receive: mesh::Receiver<Response>,
}
149
/// The relay end of a [`RelayChannel`]: receives requests from the vmbus server and sends back
/// responses.
pub struct RelayChannelHalf<Request, Response> {
    /// Requests sent by the vmbus server.
    pub request_receive: mesh::Receiver<Request>,
    /// Responses delivered back to the vmbus server.
    pub response_send: mesh::Sender<Response>,
}
155
/// A request/response channel between the vmbus server and a relay, split into its two ends.
pub struct RelayChannel<Request, Response> {
    /// The end owned by the relay.
    pub relay_half: RelayChannelHalf<Request, Response>,
    /// The end handed to the vmbus server builder.
    pub server_half: ServerChannelHalf<Request, Response>,
}
161
162impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
163 pub fn new() -> Self {
165 let (request_send, request_receive) = mesh::channel();
166 let (response_send, response_receive) = mesh::channel();
167 Self {
168 relay_half: RelayChannelHalf {
169 request_receive,
170 response_send,
171 },
172 server_half: ServerChannelHalf {
173 request_send,
174 response_receive,
175 },
176 }
177 }
178}
179
/// Server end of the modify-connection relay channel.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
/// Relay end of the modify-connection relay channel.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
/// Modify-connection relay channel (both ends).
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyConnectionResponse>;
/// Server end of the hvsock connect relay channel.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay end of the hvsock connect relay channel.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Hvsock connect relay channel (both ends).
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
186
/// A modify-connection request in the form forwarded to a relay.
///
/// Unlike [`ModifyConnectionRequest`], this does not carry the interrupt page itself. It
/// carries only whether one is in use.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// The requested protocol version, if any.
    pub version: Option<u32>,
    /// The monitor page update, copied from the original request.
    pub monitor_page: Update<MonitorPageGpas>,
    /// `Some(true)` if an interrupt page was set, `Some(false)` if it was reset, and `None` if
    /// it was left unchanged.
    pub use_interrupt_page: Option<bool>,
}
197
198impl From<ModifyConnectionRequest> for ModifyRelayRequest {
199 fn from(value: ModifyConnectionRequest) -> Self {
200 Self {
201 version: value.version,
202 monitor_page: value.monitor_page,
203 use_interrupt_page: match value.interrupt_page {
204 Update::Unchanged => None,
205 Update::Reset => Some(false),
206 Update::Set(_) => Some(true),
207 },
208 }
209 }
210}
211
/// Requests sent from a [`VmbusServer`] handle to its server task.
#[derive(Debug)]
enum VmbusRequest {
    /// Resets server channel state; see `ServerTask::handle_reset`.
    Reset(Rpc<(), ()>),
    /// Inspects server state.
    Inspect(inspect::Deferred),
    /// Returns the server's saved state.
    Save(Rpc<(), SavedState>),
    /// Restores the server from saved state.
    Restore(Rpc<Box<SavedState>, Result<(), RestoreError>>),
    /// Marks the server as running (no reply).
    Start,
    /// Marks the server as stopped.
    Stop(Rpc<(), ()>),
}
221
/// Everything the server task needs to register an offered channel.
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    /// Offer parameters passed to the channel server.
    pub params: OfferParamsInternal,
    /// Sends [`ChannelRequest`]s to the channel's device.
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receives server requests (restore, revoke) from the channel's device.
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
228
/// Requests sent to the server task over the offer channel.
#[expect(clippy::large_enum_variant)]
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offers a new channel; handled by `ServerTask::handle_offer`.
    Offer(FailableRpc<OfferInfo, ()>),
    /// Requests a reset; handled the same way as `VmbusRequest::Reset`.
    ForceReset(Rpc<(), ()>),
}
235
236impl Inspect for VmbusServer {
237 fn inspect(&self, req: inspect::Request<'_>) {
238 self.task_send.send(VmbusRequest::Inspect(req.defer()));
239 }
240}
241
/// Guest-to-host event port for an open channel; signaling it delivers the wrapped interrupt.
struct ChannelEvent(Interrupt);
243
244impl EventPort for ChannelEvent {
245 fn handle_event(&self, _flag: u16) {
246 self.0.deliver();
247 }
248
249 fn os_event(&self) -> Option<&Event> {
250 self.0.event()
251 }
252}
253
/// Saved state of the vmbus server.
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    /// Channel and connection state saved by `channels::Server::save`.
    #[mesh(1)]
    server: channels::SavedState,
    /// Whether the state was saved by a server that has the lost-synic bug fix. When `false`,
    /// open channels are unstuck the next time the server starts. Newly saved state always
    /// sets this to `true`.
    #[mesh(2)]
    lost_synic_bug_fixed: bool,
}
265
/// Message connection ID used by a VTL0 server that does not redirect messages.
const MESSAGE_CONNECTION_ID: u32 = 1;
/// Multiclient message connection ID. It is used directly for the extra VTL0 port and as the
/// base of child message connection IDs.
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
268
impl<'a, T: Spawn> VmbusServerBuilder<'a, T> {
    /// Creates a builder with default settings.
    ///
    /// Defaults:
    /// - VTL0;
    /// - no private guest memory;
    /// - no hvsock or modify-connection relay;
    /// - no saved-state notifier;
    /// - no external server or external requests;
    /// - no message redirection;
    /// - no channel ID offset;
    /// - no version cap;
    /// - MNF disabled;
    /// - confidential external memory not forced;
    /// - messages not sent while stopped.
    pub fn new(spawner: &'a T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
        Self {
            spawner,
            synic,
            gm,
            private_gm: None,
            vtl: Vtl::Vtl0,
            hvsock_notify: None,
            server_relay: None,
            saved_state_notify: None,
            external_server: None,
            external_requests: None,
            use_message_redirect: false,
            channel_id_offset: 0,
            max_version: None,
            delay_max_version: false,
            enable_mnf: false,
            force_confidential_external_memory: false,
            send_messages_while_stopped: false,
        }
    }

    /// Sets optional private guest memory, passed to both the control object and the server
    /// task.
    pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
        self.private_gm = private_gm;
        self
    }

    /// Sets the VTL the server runs for.
    pub fn vtl(mut self, vtl: Vtl) -> Self {
        self.vtl = vtl;
        self
    }

    /// Sets the channel used to forward hvsock connect requests to a relay.
    ///
    /// If `None`, every request is answered locally with
    /// `HvsockConnectResult::from_request(&request, false)`.
    pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
        self.hvsock_notify = hvsock_notify;
        self
    }

    /// Sets the sender that receives `SavedStateRequest::Set` during restore and
    /// `SavedStateRequest::Clear` when the server starts.
    pub fn saved_state_notify(
        mut self,
        saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    ) -> Self {
        self.saved_state_notify = saved_state_notify;
        self
    }

    /// Sets the channel used to forward modify-connection requests to a relay.
    ///
    /// If `None`, every request is answered locally as supported and successful, with all
    /// feature flag bits set.
    pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
        self.server_relay = server_relay;
        self
    }

    /// Sets a source of externally originated `InitiateContactRequest`s. They are only
    /// processed while the server is running and not resetting.
    pub fn external_requests(
        mut self,
        external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    ) -> Self {
        self.external_requests = external_requests;
        self
    }

    /// Sets the sender stored in the server task as `external_server_send`.
    // NOTE(review): presumably used to forward contact requests to an external server; the
    // consumer is outside this chunk.
    pub fn external_server(
        mut self,
        external_server: Option<mesh::Sender<InitiateContactRequest>>,
    ) -> Self {
        self.external_server = external_server;
        self
    }

    /// If `true`, vmbus messages use `REDIRECT_VTL`/`REDIRECT_SINT` and a child message
    /// connection ID, and no separate multiclient port is registered.
    pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
        self.use_message_redirect = use_message_redirect;
        self
    }

    /// If `true`, channel IDs are offset by 1024; otherwise there is no offset.
    pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
        self.channel_id_offset = if enable { 1024 } else { 0 };
        self
    }

    /// Sets an optional protocol version cap, applied via `set_compatibility_version`.
    pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
        self.max_version = max_version;
        self
    }

    /// Passed to `set_compatibility_version` together with the max version. It has no effect
    /// unless a max version is set.
    pub fn delay_max_version(mut self, delay: bool) -> Self {
        self.delay_max_version = delay;
        self
    }

    /// Allows offers to use MNF when the synic reports monitor support. Otherwise, offers that
    /// request MNF have it disabled.
    pub fn enable_mnf(mut self, enable: bool) -> Self {
        self.enable_mnf = enable;
        self
    }

    /// Passed through to [`VmbusServerControl`].
    pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
        self.force_confidential_external_memory = force;
        self
    }

    /// Stored in the server task state.
    // NOTE(review): the consumer of this flag is outside this chunk.
    pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
        self.send_messages_while_stopped = send;
        self
    }

    /// Registers the synic message port(s), wires up the relay and hvsock channels, and spawns
    /// the server task.
    ///
    /// The server starts out stopped; call `VmbusServer::start` to begin processing guest
    /// messages.
    ///
    /// # Errors
    ///
    /// Fails if a synic message port or the guest message port cannot be created.
    pub fn build(self) -> anyhow::Result<VmbusServer> {
        // Messages received on the synic port(s) reach the server task over this bounded
        // channel.
        #[expect(clippy::disallowed_methods)]
        let (message_send, message_recv) = mpsc::channel(64);
        let message_sender = Arc::new(MessageSender {
            send: message_send.clone(),
            multiclient: self.use_message_redirect,
        });

        // With redirection, guest messages go through the redirect VTL/SINT instead of this
        // server's VTL and the default SINT.
        let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
            (REDIRECT_VTL, REDIRECT_SINT)
        } else {
            (self.vtl, SINT)
        };

        // Only a non-redirecting VTL0 server uses the fixed message connection ID. Otherwise a
        // child connection ID is derived from the SINT and VTL.
        let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            MESSAGE_CONNECTION_ID
        } else {
            VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
        };

        let _message_port = self
            .synic
            .add_message_port(connection_id, redirect_vtl, message_sender)
            .context("failed to create vmbus synic ports")?;

        // A non-redirecting VTL0 server also listens on the multiclient connection ID. That
        // port's sender is created with `multiclient: true`.
        let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
            let multiclient_message_sender = Arc::new(MessageSender {
                send: message_send,
                multiclient: true,
            });

            Some(
                self.synic
                    .add_message_port(
                        MULTICLIENT_MESSAGE_CONNECTION_ID,
                        self.vtl,
                        multiclient_message_sender,
                    )
                    .context("failed to create vmbus synic ports")?,
            )
        } else {
            None
        };

        // Offers and forced resets from the control object arrive on `offer_recv`.
        let (offer_send, offer_recv) = mesh::mpsc_channel();
        let control = Arc::new(VmbusServerControl {
            mem: self.gm.clone(),
            private_mem: self.private_gm.clone(),
            send: offer_send,
            use_event: self.synic.prefer_os_events(),
            force_confidential_external_memory: self.force_confidential_external_memory,
        });

        let mut server = channels::Server::new(self.vtl, connection_id, self.channel_id_offset);

        if let Some(version) = self.max_version {
            server.set_compatibility_version(version, self.delay_max_version);
        }
        // Without a relay, answer every modify-connection request locally as successful, with
        // all feature flag bits set.
        let (relay_request_send, relay_response_recv) =
            if let Some(server_relay) = self.server_relay {
                let r = server_relay.response_receive.boxed().fuse();
                (server_relay.request_send, r)
            } else {
                let (req_send, req_recv) = mesh::channel();
                let resp_recv = req_recv
                    .map(|_| {
                        ModifyConnectionResponse::Supported(
                            protocol::ConnectionState::SUCCESSFUL,
                            protocol::FeatureFlags::from_bits(u32::MAX),
                        )
                    })
                    .boxed()
                    .fuse();
                (req_send, resp_recv)
            };

        // Without an hvsock relay, answer every connect request locally with
        // `from_request(&r, false)` (presumably a rejection).
        let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
            let r = hvsock_notify.response_receive.boxed().fuse();
            (hvsock_notify.request_send, r)
        } else {
            let (req_send, req_recv) = mesh::channel();
            let resp_recv = req_recv
                .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
                .boxed()
                .fuse();
            (req_send, resp_recv)
        };

        let inner = ServerTaskInner {
            running: false,
            send_messages_while_stopped: self.send_messages_while_stopped,
            gm: self.gm,
            private_gm: self.private_gm,
            vtl: self.vtl,
            redirect_vtl,
            redirect_sint,
            message_port: self
                .synic
                .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
            synic: self.synic,
            hvsock_requests: 0,
            hvsock_send,
            saved_state_notify: self.saved_state_notify,
            channels: HashMap::new(),
            channel_responses: FuturesUnordered::new(),
            relay_send: relay_request_send,
            external_server_send: self.external_server,
            channel_bitmap: None,
            shared_event_port: None,
            reset_done: Vec::new(),
            enable_mnf: self.enable_mnf,
        };

        let (task_send, task_recv) = mesh::channel();
        let mut server_task = ServerTask {
            server,
            task_recv,
            offer_recv,
            message_recv,
            server_request_recv: SelectAll::new(),
            inner,
            external_requests: self.external_requests,
            next_seq: 0,
            unstick_on_start: false,
        };

        // The task hands back its state when `run` returns, i.e. after `task_send` is dropped.
        let task = self.spawner.spawn("vmbus server", async move {
            server_task.run(relay_response_recv, hvsock_recv).await;
            server_task
        });

        Ok(VmbusServer {
            task_send,
            control,
            _message_port,
            _multiclient_message_port,
            task,
        })
    }
}
559
560impl VmbusServer {
561 pub fn builder<T: Spawn>(
563 spawner: &T,
564 synic: Arc<dyn SynicPortAccess>,
565 gm: GuestMemory,
566 ) -> VmbusServerBuilder<'_, T> {
567 VmbusServerBuilder::new(spawner, synic, gm)
568 }
569
570 pub async fn save(&self) -> SavedState {
571 self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
572 }
573
574 pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
575 self.task_send
576 .call(VmbusRequest::Restore, Box::new(state))
577 .await
578 .unwrap()
579 }
580
581 pub async fn stop(&self) {
583 self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
584 }
585
586 pub fn start(&self) {
588 self.task_send.send(VmbusRequest::Start);
589 }
590
591 pub async fn reset(&self) {
593 tracing::debug!("resetting channel state");
594 self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
595 }
596
597 pub async fn shutdown(self) {
599 drop(self.task_send);
600 let _ = self.task.await;
601 }
602
603 pub fn control(&self) -> Arc<VmbusServerControl> {
605 self.control.clone()
606 }
607
608 fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
611 MULTICLIENT_MESSAGE_CONNECTION_ID
612 | (vtl as u32) << 22
613 | vp_index << 8
614 | (sint_index as u32) << 4
615 }
616
617 fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
618 EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
619 }
620}
621
/// Per-channel information used when restoring a channel.
// NOTE(review): neither constructed nor consumed in this chunk; field meanings below are
// inferred from the types and should be confirmed.
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    /// Open parameters, if the channel was open.
    open_data: Option<OpenData>,
    /// GPADLs to restore, presumably as (id, range count, range buffer).
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    /// Host-to-guest interrupt, if the channel was open (assumed).
    interrupt: Option<Interrupt>,
}
628
/// A synic message delivered from a message port to the server task.
#[derive(Default)]
pub struct SynicMessage {
    /// Raw message payload.
    data: Vec<u8>,
    /// Whether the message came through a sender configured as multiclient. That is the
    /// multiclient port, or the main port when message redirection is enabled.
    multiclient: bool,
    /// Whether the message is considered trusted.
    // NOTE(review): set by the message sender outside this chunk; confirm what "trusted" means.
    trusted: bool,
}
635
/// State owned by the spawned vmbus server task.
struct ServerTask {
    /// The channel/connection protocol server.
    server: channels::Server,
    /// Requests from the [`VmbusServer`] handle. The task exits when this channel closes.
    task_recv: mesh::Receiver<VmbusRequest>,
    /// Offers and forced resets from [`VmbusServerControl`].
    offer_recv: mesh::Receiver<OfferRequest>,
    /// Messages received on the synic message port(s).
    message_recv: mpsc::Receiver<SynicMessage>,
    /// Per-channel server requests, tagged with the channel's offer ID.
    server_request_recv: SelectAll<TaggedStream<OfferId, mesh::Receiver<ChannelServerRequest>>>,
    /// State that is also accessed through the [`Notifier`] implementation.
    inner: ServerTaskInner,
    /// Optional source of externally originated contact requests.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Sequence number assigned to the next offered channel.
    next_seq: u64,
    /// Set when restoring state saved without the lost-synic fix. When set, channels are
    /// unstuck on the next start.
    unstick_on_start: bool,
}
648
/// Server task state that is also accessed through the [`Notifier`] implementation.
struct ServerTaskInner {
    /// Whether the server has been started. Guest-facing work is only processed while running.
    running: bool,
    /// Copied from the builder; consumed outside this chunk.
    send_messages_while_stopped: bool,
    /// Guest memory; also used to inspect ring buffers when unsticking channels.
    gm: GuestMemory,
    /// Optional private guest memory.
    private_gm: Option<GuestMemory>,
    /// Synic access for ports and monitor support.
    synic: Arc<dyn SynicPortAccess>,
    /// VTL the server runs for.
    vtl: Vtl,
    /// VTL used for guest messages (`REDIRECT_VTL` when redirecting, otherwise `vtl`).
    redirect_vtl: Vtl,
    /// SINT used for guest messages (`REDIRECT_SINT` when redirecting, otherwise `SINT`).
    redirect_sint: u8,
    /// Port used to post vmbus messages to the guest.
    message_port: Box<dyn GuestMessagePort>,
    /// Number of outstanding hvsock connect requests.
    hvsock_requests: usize,
    /// Forwards hvsock connect requests to the relay (or to the local responder).
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    /// Optional saved-state notifier.
    saved_state_notify: Option<mesh::Sender<SavedStateRequest>>,
    /// Offered channels, keyed by offer ID.
    channels: HashMap<OfferId, Channel>,
    /// In-flight device responses, each tagged with the channel's offer ID and sequence number.
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    /// Optional sender for `external_server` requests.
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    /// Forwards modify-connection requests to the relay (or to the local responder).
    relay_send: mesh::Sender<ModifyRelayRequest>,
    /// Starts as `None`; set outside this chunk.
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    /// Starts as `None`; held only to keep a shared event port alive once set.
    shared_event_port: Option<Box<dyn Send>>,
    /// Reset RPCs awaiting completion. A reset is in progress while this is non-empty.
    reset_done: Vec<Rpc<(), ()>>,
    /// Whether MNF may be used for offers.
    enable_mnf: bool,
}
673
/// A device's response to a channel request.
#[derive(Debug)]
enum ChannelResponse {
    /// Result of an open; `None` means the open failed.
    Open(Option<OpenResult>),
    /// The channel was closed.
    Close,
    /// Result of creating a GPADL: its ID and whether creation succeeded.
    Gpadl(GpadlId, bool),
    /// The GPADL was torn down.
    TeardownGpadl(GpadlId),
    /// Status of a modify-channel request.
    Modify(i32),
}
682
/// Server task bookkeeping for one offered channel.
struct Channel {
    /// The channel's offer key (used in logs).
    key: OfferKey,
    /// Sends [`ChannelRequest`]s to the device.
    send: mesh::Sender<ChannelRequest>,
    /// Sequence number assigned at offer time. Responses carrying a different sequence number
    /// are dropped as stale.
    seq: u64,
    /// Current open/close state.
    state: ChannelState,
    /// GPADLs known for this channel.
    gpadls: Arc<GpadlMap>,
    /// Offer flags, copied from the offer parameters.
    flags: protocol::OfferFlags,
    /// Reserved-channel message state. It starts with no port and target VP 0 / SINT 0.
    reserved_state: ReservedState,
}
696
/// Message routing for a reserved channel.
// NOTE(review): populated outside this chunk; presumably set when the guest opens the channel
// as reserved.
struct ReservedState {
    /// Dedicated guest message port, if one has been created.
    message_port: Option<Box<dyn GuestMessagePort>>,
    /// VP and SINT targeted by the dedicated port.
    target: ConnectionTarget,
}
701
/// Open/close state of an offered channel, as tracked by the server task.
enum ChannelState {
    /// The channel is not open.
    Closed,
    /// An open is in progress (presumably awaiting the device's open response).
    Opening {
        open_params: OpenParams,
        guest_event_port: Box<dyn GuestEventPort>,
        host_to_guest_interrupt: Interrupt,
    },
    /// The channel is open.
    Open {
        open_params: OpenParams,
        /// Held only to keep the guest-to-host event port registration alive.
        _event_port: Box<dyn Send>,
        guest_event_port: Box<dyn GuestEventPort>,
        /// Wakes the guest; also used by the unstick logic.
        host_to_guest_interrupt: Interrupt,
        /// Wakes the host side; also used by the unstick logic.
        guest_to_host_event: Arc<ChannelEvent>,
    },
    /// A close is in progress. The close response completes it to `Closed`.
    Closing,
    /// The open failed and the channel is being closed. When the close response arrives, the
    /// open is reported to the guest as unsuccessful.
    FailedOpen,
}
719
720impl ServerTask {
721 fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
722 let key = info.params.key();
723 let flags = info.params.flags;
724
725 if self.inner.enable_mnf && self.inner.synic.monitor_support().is_some() {
726 if info.params.use_mnf.is_relayed() {
731 info.params.use_mnf = MnfUsage::Enabled {
732 latency: Duration::ZERO,
733 }
734 }
735 } else if info.params.use_mnf.is_enabled() {
736 info.params.use_mnf = MnfUsage::Disabled;
739 }
740
741 let offer_id = self
742 .server
743 .with_notifier(&mut self.inner)
744 .offer_channel(info.params)
745 .context("channel offer failed")?;
746
747 tracing::debug!(?offer_id, %key, "offered channel");
748
749 let id = self.next_seq;
750 self.next_seq += 1;
751 self.inner.channels.insert(
752 offer_id,
753 Channel {
754 key,
755 send: info.request_send,
756 state: ChannelState::Closed,
757 gpadls: GpadlMap::new(),
758 seq: id,
759 flags,
760 reserved_state: ReservedState {
761 message_port: None,
762 target: ConnectionTarget { vp: 0, sint: 0 },
763 },
764 },
765 );
766
767 self.server_request_recv
768 .push(TaggedStream::new(offer_id, info.server_request_recv));
769
770 Ok(())
771 }
772
773 fn handle_revoke(&mut self, offer_id: OfferId) {
774 if self.inner.channels.remove(&offer_id).is_some() {
777 tracing::info!(?offer_id, "revoking channel");
778 self.server
779 .with_notifier(&mut self.inner)
780 .revoke_channel(offer_id);
781 }
782 }
783
784 fn handle_response(
785 &mut self,
786 offer_id: OfferId,
787 seq: u64,
788 response: Result<ChannelResponse, RpcError>,
789 ) {
790 let channel = self
792 .inner
793 .channels
794 .get(&offer_id)
795 .filter(|channel| channel.seq == seq);
796
797 if let Some(channel) = channel {
798 match response {
799 Ok(response) => match response {
800 ChannelResponse::Open(result) => self.handle_open(offer_id, result),
801 ChannelResponse::Close => self.handle_close(offer_id),
802 ChannelResponse::Gpadl(gpadl_id, ok) => {
803 self.handle_gpadl_create(offer_id, gpadl_id, ok)
804 }
805 ChannelResponse::TeardownGpadl(gpadl_id) => {
806 self.handle_gpadl_teardown(offer_id, gpadl_id)
807 }
808 ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
809 },
810 Err(err) => {
811 tracing::error!(
812 key = %channel.key,
813 error = &err as &dyn std::error::Error,
814 "channel response failure, channel is in inconsistent state until revoked"
815 );
816 }
817 }
818 } else {
819 tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
820 }
821 }
822
823 fn handle_open(&mut self, offer_id: OfferId, result: Option<OpenResult>) {
824 let status = if result.is_some() {
825 0
826 } else {
827 protocol::STATUS_UNSUCCESSFUL
828 };
829 if let Err(err) = self.inner.complete_open(offer_id, result) {
830 tracelimit::error_ratelimited!(
831 error = err.as_ref() as &dyn std::error::Error,
832 "failed to complete open"
833 );
834 self.inner.notify(offer_id, channels::Action::Close);
838 } else {
839 self.server
840 .with_notifier(&mut self.inner)
841 .open_complete(offer_id, status);
842 }
843 }
844
845 fn handle_close(&mut self, offer_id: OfferId) {
846 let channel = self
847 .inner
848 .channels
849 .get_mut(&offer_id)
850 .expect("channel still exists");
851
852 match &mut channel.state {
853 ChannelState::Closing => {
854 channel.state = ChannelState::Closed;
855 self.server
856 .with_notifier(&mut self.inner)
857 .close_complete(offer_id);
858 }
859 ChannelState::FailedOpen => {
860 channel.state = ChannelState::Closed;
863 self.server
864 .with_notifier(&mut self.inner)
865 .open_complete(offer_id, protocol::STATUS_UNSUCCESSFUL);
866 }
867 _ => {
868 tracing::error!(?offer_id, "invalid close channel response");
869 }
870 };
871 }
872
873 fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
874 let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
875 self.server
876 .with_notifier(&mut self.inner)
877 .gpadl_create_complete(offer_id, gpadl_id, status);
878 }
879
880 fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
881 self.server
882 .with_notifier(&mut self.inner)
883 .gpadl_teardown_complete(offer_id, gpadl_id);
884 }
885
886 fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
887 self.server
888 .with_notifier(&mut self.inner)
889 .modify_channel_complete(offer_id, status);
890 }
891
892 fn handle_restore_channel(
893 &mut self,
894 offer_id: OfferId,
895 open: Option<OpenResult>,
896 ) -> anyhow::Result<RestoreResult> {
897 let gpadls = self.server.channel_gpadls(offer_id);
898
899 let open_request = open
902 .map(|result| -> anyhow::Result<_> {
903 let params = self.server.get_restore_open_params(offer_id)?;
904 let (_, interrupt) = self.inner.open_channel(offer_id, ¶ms)?;
905 let channel = self.inner.complete_open(offer_id, Some(result))?;
906 Ok(OpenRequest::new(
907 params.open_data,
908 interrupt,
909 self.server
910 .get_version()
911 .expect("must be connected")
912 .feature_flags,
913 channel.flags,
914 ))
915 })
916 .transpose()?;
917
918 self.server
919 .with_notifier(&mut self.inner)
920 .restore_channel(offer_id, open_request.is_some())?;
921
922 let channel = self.inner.channels.get_mut(&offer_id).unwrap();
923 for gpadl in &gpadls {
924 if let Ok(buf) =
925 MultiPagedRangeBuf::new(gpadl.request.count.into(), gpadl.request.buf.clone())
926 {
927 channel.gpadls.add(gpadl.request.id, buf);
928 }
929 }
930
931 let result = RestoreResult {
932 open_request,
933 gpadls,
934 };
935 Ok(result)
936 }
937
938 async fn handle_request(&mut self, request: VmbusRequest) {
939 tracing::debug!(?request, "handle_request");
940 match request {
941 VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
942 VmbusRequest::Inspect(deferred) => {
943 deferred.respond(|resp| {
944 resp.field("message_port", &self.inner.message_port)
945 .field("running", self.inner.running)
946 .field("hvsock_requests", self.inner.hvsock_requests)
947 .field_mut_with("unstick_channels", |v| {
948 let v: inspect::ValueKind = if let Some(v) = v {
949 if v == "force" {
950 self.unstick_channels(true);
951 v.into()
952 } else {
953 let v =
954 v.parse().ok().context("expected false, true, or force")?;
955 if v {
956 self.unstick_channels(false);
957 }
958 v.into()
959 }
960 } else {
961 false.into()
962 };
963 anyhow::Ok(v)
964 })
965 .merge(&self.server.with_notifier(&mut self.inner));
966 });
967 }
968 VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
969 server: self.server.save(),
970 lost_synic_bug_fixed: true,
971 }),
972 VmbusRequest::Restore(rpc) => {
973 rpc.handle(async |state| {
974 self.unstick_on_start = !state.lost_synic_bug_fixed;
975 if let Some(sender) = &self.inner.saved_state_notify {
976 tracing::trace!("sending saved state to proxy");
977 if let Err(err) = sender
978 .call_failable(SavedStateRequest::Set, Box::new(state.server.clone()))
979 .await
980 {
981 tracing::error!(
982 err = &err as &dyn std::error::Error,
983 "failed to restore proxy saved state"
984 );
985 return Err(RestoreError::ServerError(err.into()));
986 }
987 }
988
989 self.server
990 .with_notifier(&mut self.inner)
991 .restore(state.server)
992 })
993 .await
994 }
995 VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
996 if self.inner.running {
997 self.inner.running = false;
998 }
999 }),
1000 VmbusRequest::Start => {
1001 if !self.inner.running {
1002 self.inner.running = true;
1003 if let Some(sender) = self.inner.saved_state_notify.as_ref() {
1004 tracing::trace!("sending clear saved state message to proxy");
1007 sender
1008 .call(SavedStateRequest::Clear, ())
1009 .await
1010 .expect("failed to clear proxy saved state");
1011 }
1012
1013 self.server
1014 .with_notifier(&mut self.inner)
1015 .revoke_unclaimed_channels();
1016 if self.unstick_on_start {
1017 tracing::info!(
1018 "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
1019 );
1020 self.unstick_channels(false);
1021 self.unstick_on_start = false;
1022 }
1023 }
1024 }
1025 }
1026 }
1027
1028 fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
1029 let needs_reset = self.inner.reset_done.is_empty();
1030 self.inner.reset_done.push(rpc);
1031 if needs_reset {
1032 self.server.with_notifier(&mut self.inner).reset();
1033 }
1034 }
1035
1036 fn handle_relay_response(&mut self, response: ModifyConnectionResponse) {
1037 self.server
1038 .with_notifier(&mut self.inner)
1039 .complete_modify_connection(response);
1040 }
1041
1042 fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
1043 assert_ne!(self.inner.hvsock_requests, 0);
1044 self.inner.hvsock_requests -= 1;
1045
1046 self.server
1047 .with_notifier(&mut self.inner)
1048 .send_tl_connect_result(result);
1049 }
1050
1051 fn handle_synic_message(&mut self, message: SynicMessage) {
1052 match self
1053 .server
1054 .with_notifier(&mut self.inner)
1055 .handle_synic_message(message)
1056 {
1057 Ok(()) => {}
1058 Err(err) => {
1059 tracing::warn!(
1060 error = &err as &dyn std::error::Error,
1061 "synic message error"
1062 );
1063 }
1064 }
1065 }
1066
1067 fn handle_external_request(&mut self, request: InitiateContactRequest) {
1074 self.server
1075 .with_notifier(&mut self.inner)
1076 .initiate_contact(request);
1077 }
1078
1079 async fn run(
1080 &mut self,
1081 mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyConnectionResponse>
1082 + Unpin,
1083 mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
1084 ) {
1085 loop {
1086 let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
1091 let mut external_requests = OptionFuture::from(
1092 running_not_resetting
1093 .then(|| {
1094 self.external_requests
1095 .as_mut()
1096 .map(|r| r.select_next_some())
1097 })
1098 .flatten(),
1099 );
1100
1101 let has_pending_messages = self.server.has_pending_messages();
1103 let message_port = self.inner.message_port.as_mut();
1104 let mut flush_pending_messages =
1105 OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
1106 poll_fn(|cx| {
1107 self.server.poll_flush_pending_messages(|msg| {
1108 message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
1109 })
1110 })
1111 .fuse()
1112 }));
1113
1114 let mut message_recv = OptionFuture::from(
1118 (running_not_resetting
1119 && !has_pending_messages
1120 && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
1121 .then(|| self.message_recv.select_next_some()),
1122 );
1123
1124 let mut channel_response = OptionFuture::from(
1126 (self.inner.running || !self.inner.reset_done.is_empty())
1127 .then(|| self.inner.channel_responses.select_next_some()),
1128 );
1129
1130 let mut hvsock_response =
1132 OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));
1133
1134 futures::select! { r = self.task_recv.recv().fuse() => {
1136 if let Ok(request) = r {
1137 self.handle_request(request).await;
1138 } else {
1139 break;
1140 }
1141 }
1142 r = self.offer_recv.select_next_some() => {
1143 match r {
1144 OfferRequest::Offer(rpc) => {
1145 rpc.handle_failable_sync(|request| { self.handle_offer(request) })
1146 },
1147 OfferRequest::ForceReset(rpc) => {
1148 self.handle_reset(rpc);
1149 }
1150 }
1151 }
1152 r = self.server_request_recv.select_next_some() => {
1153 match r {
1154 (id, Some(request)) => match request {
1155 ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
1156 self.handle_restore_channel(id, open)
1157 }),
1158 ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
1159 self.handle_revoke(id);
1160 })
1161 },
1162 (id, None) => self.handle_revoke(id),
1163 }
1164 }
1165 r = channel_response => {
1166 let (id, seq, response) = r.unwrap();
1167 self.handle_response(id, seq, response);
1168 }
1169 r = relay_response_recv.select_next_some() => {
1170 self.handle_relay_response(r);
1171 },
1172 r = hvsock_response => {
1173 self.handle_tl_connect_result(r.unwrap());
1174 }
1175 data = message_recv => {
1176 let data = data.unwrap();
1177 self.handle_synic_message(data);
1178 }
1179 r = external_requests => {
1180 let r = r.unwrap();
1181 self.handle_external_request(r);
1182 }
1183 _r = flush_pending_messages => {}
1184 complete => break,
1185 }
1186 }
1187 }
1188
1189 fn unstick_channels(&self, force: bool) {
1193 for channel in self.inner.channels.values() {
1194 if let Err(err) = self.unstick_channel(channel, force) {
1195 tracing::warn!(
1196 channel = %channel.key,
1197 error = err.as_ref() as &dyn std::error::Error,
1198 "could not unstick channel"
1199 );
1200 }
1201 }
1202 }
1203
1204 fn unstick_channel(&self, channel: &Channel, force: bool) -> anyhow::Result<()> {
1205 if let ChannelState::Open {
1206 open_params,
1207 host_to_guest_interrupt,
1208 guest_to_host_event,
1209 ..
1210 } = &channel.state
1211 {
1212 if force {
1213 tracing::info!(channel = %channel.key, "waking host and guest");
1214 guest_to_host_event.0.deliver();
1215 host_to_guest_interrupt.deliver();
1216 return Ok(());
1217 }
1218
1219 let gpadl = channel
1220 .gpadls
1221 .clone()
1222 .view()
1223 .map(open_params.open_data.ring_gpadl_id)
1224 .context("couldn't find ring gpadl")?;
1225
1226 let aligned = AlignedGpadlView::new(gpadl)
1227 .ok()
1228 .context("ring not aligned")?;
1229 let (in_gpadl, out_gpadl) = aligned
1230 .split(open_params.open_data.ring_offset)
1231 .ok()
1232 .context("couldn't split ring")?;
1233
1234 if let Err(err) = self.unstick_incoming_ring(
1235 channel,
1236 in_gpadl,
1237 guest_to_host_event,
1238 host_to_guest_interrupt,
1239 ) {
1240 tracing::warn!(
1241 channel = %channel.key,
1242 error = err.as_ref() as &dyn std::error::Error,
1243 "could not unstick incoming ring"
1244 );
1245 }
1246 if let Err(err) = self.unstick_outgoing_ring(
1247 channel,
1248 out_gpadl,
1249 guest_to_host_event,
1250 host_to_guest_interrupt,
1251 ) {
1252 tracing::warn!(
1253 channel = %channel.key,
1254 error = err.as_ref() as &dyn std::error::Error,
1255 "could not unstick outgoing ring"
1256 );
1257 }
1258 }
1259 Ok(())
1260 }
1261
1262 fn unstick_incoming_ring(
1263 &self,
1264 channel: &Channel,
1265 in_gpadl: AlignedGpadlView,
1266 guest_to_host_event: &ChannelEvent,
1267 host_to_guest_interrupt: &Interrupt,
1268 ) -> Result<(), anyhow::Error> {
1269 let incoming_mem = GpadlRingMem::new(in_gpadl, &self.inner.gm)?;
1270 if ring::reader_needs_signal(&incoming_mem) {
1271 tracing::info!(channel = %channel.key, "waking host for incoming ring");
1272 guest_to_host_event.0.deliver();
1273 }
1274 if ring::writer_needs_signal(&incoming_mem) {
1275 tracing::info!(channel = %channel.key, "waking guest for incoming ring");
1276 host_to_guest_interrupt.deliver();
1277 }
1278 Ok(())
1279 }
1280
1281 fn unstick_outgoing_ring(
1282 &self,
1283 channel: &Channel,
1284 out_gpadl: AlignedGpadlView,
1285 guest_to_host_event: &ChannelEvent,
1286 host_to_guest_interrupt: &Interrupt,
1287 ) -> Result<(), anyhow::Error> {
1288 let outgoing_mem = GpadlRingMem::new(out_gpadl, &self.inner.gm)?;
1289 if ring::reader_needs_signal(&outgoing_mem) {
1290 tracing::info!(channel = %channel.key, "waking guest for outgoing ring");
1291 host_to_guest_interrupt.deliver();
1292 }
1293 if ring::writer_needs_signal(&outgoing_mem) {
1294 tracing::info!(channel = %channel.key, "waking host for outgoing ring");
1295 guest_to_host_event.0.deliver();
1296 }
1297 Ok(())
1298 }
1299}
1300
1301impl Notifier for ServerTaskInner {
1302 fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
1303 let channel = self
1304 .channels
1305 .get_mut(&offer_id)
1306 .expect("channel does not exist");
1307
1308 fn handle<I: 'static + Send, R: 'static + Send>(
1309 offer_id: OfferId,
1310 channel: &Channel,
1311 req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
1312 input: I,
1313 f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
1314 ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
1315 {
1316 let recv = channel.send.call(req, input);
1317 let seq = channel.seq;
1318 Box::pin(async move {
1319 let r = recv.await.map(f);
1320 (offer_id, seq, r)
1321 })
1322 }
1323
1324 let response = match action {
1325 channels::Action::Open(open_params, version) => {
1326 let seq = channel.seq;
1327 match self.open_channel(offer_id, &open_params) {
1328 Ok((channel, interrupt)) => handle(
1329 offer_id,
1330 channel,
1331 ChannelRequest::Open,
1332 OpenRequest::new(
1333 open_params.open_data,
1334 interrupt,
1335 version.feature_flags,
1336 channel.flags,
1337 ),
1338 ChannelResponse::Open,
1339 ),
1340 Err(err) => {
1341 tracelimit::error_ratelimited!(
1342 err = err.as_ref() as &dyn std::error::Error,
1343 ?offer_id,
1344 "could not open channel",
1345 );
1346
1347 Box::pin(future::ready((
1350 offer_id,
1351 seq,
1352 Ok(ChannelResponse::Open(None)),
1353 )))
1354 }
1355 }
1356 }
1357 channels::Action::Close => {
1358 if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
1359 if let ChannelState::Open { open_params, .. } = channel.state {
1360 channel_bitmap.unregister_channel(open_params.event_flag);
1361 }
1362 }
1363
1364 channel.state = ChannelState::Closing;
1365 handle(offer_id, channel, ChannelRequest::Close, (), |()| {
1366 ChannelResponse::Close
1367 })
1368 }
1369 channels::Action::Gpadl(gpadl_id, count, buf) => {
1370 channel.gpadls.add(
1371 gpadl_id,
1372 MultiPagedRangeBuf::new(count.into(), buf.clone()).unwrap(),
1373 );
1374 handle(
1375 offer_id,
1376 channel,
1377 ChannelRequest::Gpadl,
1378 GpadlRequest {
1379 id: gpadl_id,
1380 count,
1381 buf,
1382 },
1383 move |r| ChannelResponse::Gpadl(gpadl_id, r),
1384 )
1385 }
1386 channels::Action::TeardownGpadl {
1387 gpadl_id,
1388 post_restore,
1389 } => {
1390 if !post_restore {
1391 channel.gpadls.remove(gpadl_id, Box::new(|| ()));
1392 }
1393
1394 handle(
1395 offer_id,
1396 channel,
1397 ChannelRequest::TeardownGpadl,
1398 gpadl_id,
1399 move |()| ChannelResponse::TeardownGpadl(gpadl_id),
1400 )
1401 }
1402 channels::Action::Modify { target_vp } => {
1403 if let ChannelState::Open {
1404 guest_event_port, ..
1405 } = &mut channel.state
1406 {
1407 if let Err(err) = guest_event_port.set_target_vp(target_vp) {
1408 tracelimit::error_ratelimited!(
1409 error = &err as &dyn std::error::Error,
1410 channel = %channel.key,
1411 "could not modify channel",
1412 );
1413 let seq = channel.seq;
1414 Box::pin(async move {
1415 (
1416 offer_id,
1417 seq,
1418 Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
1419 )
1420 })
1421 } else {
1422 handle(
1423 offer_id,
1424 channel,
1425 ChannelRequest::Modify,
1426 ModifyRequest::TargetVp { target_vp },
1427 ChannelResponse::Modify,
1428 )
1429 }
1430 } else {
1431 unreachable!();
1432 }
1433 }
1434 };
1435 self.channel_responses.push(response);
1436 }
1437
1438 fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
1439 self.map_interrupt_page(request.interrupt_page)
1440 .context("Failed to map interrupt page.")?;
1441
1442 self.set_monitor_page(request.monitor_page)
1443 .context("Failed to map monitor page.")?;
1444
1445 if let Some(vp) = request.target_message_vp {
1446 self.message_port.set_target_vp(vp)?;
1447 }
1448
1449 if request.notify_relay {
1450 if self.enable_mnf {
1455 request.monitor_page = Update::Unchanged;
1456 }
1457
1458 self.relay_send.send(request.into());
1459 }
1460
1461 Ok(())
1462 }
1463
1464 fn forward_unhandled(&mut self, request: InitiateContactRequest) {
1465 if let Some(external_server) = &self.external_server_send {
1466 external_server.send(request);
1467 } else {
1468 tracing::warn!(?request, "nowhere to forward unhandled request")
1469 }
1470 }
1471
1472 fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
1473 let channel = self.channels.get(&offer_id).expect("should exist");
1474 let mut resp = req.respond();
1475 if let ChannelState::Open { open_params, .. } = &channel.state {
1476 let mem = if self.private_gm.is_some()
1477 && channel.flags.confidential_ring_buffer()
1478 && version
1479 .expect("must be connected")
1480 .feature_flags
1481 .confidential_channels()
1482 {
1483 self.private_gm.as_ref().unwrap()
1484 } else {
1485 &self.gm
1486 };
1487
1488 inspect_rings(
1489 &mut resp,
1490 mem,
1491 channel.gpadls.clone(),
1492 &open_params.open_data,
1493 );
1494 }
1495 }
1496
1497 fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
1498 if !self.running && !self.send_messages_while_stopped {
1511 if !matches!(target, MessageTarget::Default) {
1512 tracelimit::error_ratelimited!(?target, "dropping message while paused");
1513 }
1514 return false;
1515 }
1516
1517 let mut port_storage;
1518 let port = match target {
1519 MessageTarget::Default => self.message_port.as_mut(),
1520 MessageTarget::ReservedChannel(offer_id, target) => {
1521 if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
1522 port.as_mut()
1523 } else {
1524 return true;
1526 }
1527 }
1528 MessageTarget::Custom(target) => {
1529 port_storage = match self.synic.new_guest_message_port(
1530 self.redirect_vtl,
1531 target.vp,
1532 target.sint,
1533 ) {
1534 Ok(port) => port,
1535 Err(err) => {
1536 tracing::error!(
1537 ?err,
1538 ?self.redirect_vtl,
1539 ?target,
1540 "could not create message port"
1541 );
1542
1543 return true;
1545 }
1546 };
1547 port_storage.as_mut()
1548 }
1549 };
1550
1551 matches!(
1554 port.poll_post_message(
1555 &mut std::task::Context::from_waker(std::task::Waker::noop()),
1556 VMBUS_MESSAGE_TYPE,
1557 message.data()
1558 ),
1559 Poll::Ready(())
1560 )
1561 }
1562
1563 fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
1564 self.hvsock_requests += 1;
1565 self.hvsock_send.send(*request);
1566 }
1567
1568 fn reset_complete(&mut self) {
1569 if let Some(monitor) = self.synic.monitor_support() {
1570 if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
1571 tracing::warn!(?err, "resetting monitor page failed")
1572 }
1573 }
1574
1575 self.unreserve_channels();
1576 for done in self.reset_done.drain(..) {
1577 done.complete(());
1578 }
1579 }
1580
1581 fn unload_complete(&mut self) {
1582 self.unreserve_channels();
1583 }
1584}
1585
impl ServerTaskInner {
    /// Prepares channel `offer_id` to be opened with the guest-supplied
    /// `open_params`.
    ///
    /// Steps:
    /// - creates the guest event port;
    /// - for channels with a reserved target, creates the reserved message port;
    /// - moves the channel to `ChannelState::Opening`.
    ///
    /// Returns the channel together with the host-to-guest interrupt to hand
    /// to the device.
    ///
    /// # Errors
    ///
    /// Returns an error if the synic fails to create either port. If the
    /// reserved message port cannot be created, the channel's previous reserved
    /// port has already been cleared and its state is left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `offer_id` does not name a known channel.
    fn open_channel(
        &mut self,
        offer_id: OfferId,
        open_params: &OpenParams,
    ) -> anyhow::Result<(&mut Channel, Interrupt)> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        // When a channel bitmap (legacy interrupt page) is active, the guest
        // event port targets VP 0 with event flag 0. The channel's own event
        // flag is passed to `ChannelBitmap::create_interrupt` below instead.
        let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
            (0, 0)
        } else {
            (open_params.open_data.target_vp, open_params.event_flag)
        };
        // Channels that request interrupt redirection signal the redirect
        // VTL/SINT instead of this server's VTL and `SINT`.
        let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
            (self.redirect_vtl, self.redirect_sint)
        } else {
            (self.vtl, SINT)
        };

        let guest_event_port = self.synic.new_guest_event_port(
            VmbusServer::get_child_event_port_id(open_params.channel_id, SINT, self.vtl),
            target_vtl,
            target_vp,
            target_sint,
            event_flag,
            open_params.monitor_info,
        )?;

        let interrupt = ChannelBitmap::create_interrupt(
            &self.channel_bitmap,
            guest_event_port.interrupt(),
            open_params.event_flag,
        );

        // Drop any reserved message port left over from a previous open.
        channel.reserved_state.message_port = None;

        if let Some(target) = open_params.reserved_target {
            // Reserved channels get their own message port, created in the
            // redirect VTL at the requested VP/SINT.
            channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
                self.redirect_vtl,
                target.vp,
                target.sint,
            )?);

            channel.reserved_state.target = target;
        }

        channel.state = ChannelState::Opening {
            open_params: *open_params,
            guest_event_port,
            host_to_guest_interrupt: interrupt.clone(),
        };
        Ok((channel, interrupt))
    }

    /// Finishes opening channel `offer_id` once the open result is known.
    ///
    /// With `Some(result)`, a channel in `ChannelState::Opening` becomes
    /// `ChannelState::Open`. Along the way:
    /// - its guest-to-host interrupt is registered with the channel bitmap, if any;
    /// - it is registered as a synic event port on the channel's connection ID.
    ///
    /// A channel in any other state is logged and left in that state.
    ///
    /// With `None` (the open failed), the channel moves to `ChannelState::Closed`.
    ///
    /// # Errors
    ///
    /// Returns an error if the synic event port cannot be added. The channel is
    /// then left in `ChannelState::FailedOpen`.
    ///
    /// # Panics
    ///
    /// Panics if `offer_id` does not name a known channel.
    fn complete_open(
        &mut self,
        offer_id: OfferId,
        result: Option<OpenResult>,
    ) -> anyhow::Result<&mut Channel> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        channel.state = if let Some(result) = result {
            // Temporarily park the channel in `FailedOpen`. An early return
            // via `?` below leaves it there.
            match std::mem::replace(&mut channel.state, ChannelState::FailedOpen) {
                ChannelState::Opening {
                    open_params,
                    guest_event_port,
                    host_to_guest_interrupt,
                } => {
                    let guest_to_host_event =
                        Arc::new(ChannelEvent(result.guest_to_host_interrupt));
                    // NOTE(review): this registration is not undone if
                    // `add_event_port` below fails, and `Action::Close` only
                    // unregisters `Open` channels; confirm this is acceptable.
                    if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
                        channel_bitmap.register_channel(
                            open_params.event_flag,
                            guest_to_host_event.0.clone(),
                        );
                    }
                    let event_port = self
                        .synic
                        .add_event_port(
                            open_params.connection_id,
                            self.vtl,
                            guest_to_host_event.clone(),
                            open_params.monitor_info,
                        )
                        .with_context(|| {
                            format!(
                                "failed to create event port for VTL {:?}, connection ID {:#x}",
                                self.vtl, open_params.connection_id
                            )
                        })?;

                    ChannelState::Open {
                        open_params,
                        _event_port: event_port,
                        guest_event_port,
                        host_to_guest_interrupt,
                        guest_to_host_event,
                    }
                }
                s => {
                    tracing::error!("attempting to complete open of open or closed channel");
                    // Restore the state that was swapped out above.
                    s
                }
            }
        } else {
            ChannelState::Closed
        };
        Ok(channel)
    }

    /// Applies a change to the guest's interrupt page (channel bitmap).
    ///
    /// - `Update::Unchanged` does nothing.
    /// - `Update::Reset` drops the channel bitmap and the shared event port.
    /// - `Update::Set(gpa)` locks the page at `gpa`, builds a new channel bitmap
    ///   over it, and registers a shared event port on
    ///   `SHARED_EVENT_CONNECTION_ID`. That port dispatches signals through
    ///   `ChannelBitmap::handle_shared_interrupt`.
    ///
    /// # Errors
    ///
    /// Returns an error in three cases:
    /// - the page is not page aligned;
    /// - the page cannot be locked;
    /// - the shared event port cannot be added. The new channel bitmap has
    ///   already been installed at that point.
    ///
    /// # Panics
    ///
    /// Panics if asked to set a page at GPA 0.
    fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
        let interrupt_page = match interrupt_page {
            Update::Unchanged => return Ok(()),
            Update::Reset => {
                self.channel_bitmap = None;
                self.shared_event_port = None;
                return Ok(());
            }
            Update::Set(interrupt_page) => interrupt_page,
        };

        // NOTE(review): a zero page is treated as a caller bug (panic), not as
        // a guest error; presumably callers map 0 to `Update::Reset`. Confirm.
        assert_ne!(interrupt_page, 0);

        if interrupt_page % PAGE_SIZE as u64 != 0 {
            anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
        }

        let interrupt_page = self
            .gm
            .lockable_subrange(interrupt_page, PAGE_SIZE as u64)?
            .lock_gpns(false, &[0])?;

        let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
        self.channel_bitmap = Some(channel_bitmap.clone());

        let interrupt = Interrupt::from_fn(move || {
            channel_bitmap.handle_shared_interrupt();
        });

        // Any previous shared event port is dropped only after this
        // assignment, i.e. after the new port has been registered.
        self.shared_event_port = Some(self.synic.add_event_port(
            SHARED_EVENT_CONNECTION_ID,
            self.vtl,
            Arc::new(ChannelEvent(interrupt)),
            None,
        )?);

        Ok(())
    }

    /// Applies a change to the guest's monitor pages.
    ///
    /// The change is applied to the synic only when MNF is enabled for this
    /// server and the synic supports monitors. Otherwise it is accepted and
    /// ignored.
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    /// - an open or opening channel uses monitor notifications (`monitor_info`
    ///   is set);
    /// - the synic rejects the new monitor page.
    fn set_monitor_page(&mut self, monitor_page: Update<MonitorPageGpas>) -> anyhow::Result<()> {
        let monitor_page = match monitor_page {
            Update::Unchanged => return Ok(()),
            Update::Reset => None,
            Update::Set(value) => Some(value),
        };

        // Refuse to move the monitor page out from under channels using it.
        if self.channels.iter().any(|(_, c)| {
            matches!(
                &c.state,
                ChannelState::Open {
                    open_params,
                    ..
                } | ChannelState::Opening {
                    open_params,
                    ..
                } if open_params.monitor_info.is_some()
            )
        }) {
            anyhow::bail!("attempt to change monitor page while open channels using mnf");
        }

        if self.enable_mnf {
            if let Some(monitor) = self.synic.monitor_support() {
                if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
                    anyhow::bail!(
                        "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
                    );
                }
            }
        }

        Ok(())
    }

    /// Returns the reserved message port of channel `offer_id`, retargeting it
    /// to `new_target` first if needed.
    ///
    /// - A SINT change drops the old port and creates a new one.
    /// - A VP-only change retargets the existing port. If that fails, the error
    ///   is logged and the existing port is still returned.
    ///
    /// Returns `None` only when creating the replacement port fails.
    ///
    /// # Panics
    ///
    /// Panics if the channel does not exist or has no reserved message port.
    fn get_reserved_channel_message_port(
        &mut self,
        offer_id: OfferId,
        new_target: ConnectionTarget,
    ) -> Option<&mut Box<dyn GuestMessagePort>> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        assert!(
            channel.reserved_state.message_port.is_some(),
            "channel is not reserved"
        );

        if channel.reserved_state.target.sint != new_target.sint {
            // The old port is released before the replacement is created.
            // NOTE(review): if creation fails, `message_port` stays `None` and
            // `target` is unchanged. The next call for this channel will then
            // trip the `channel is not reserved` assertion above. Confirm this
            // is intended.
            channel.reserved_state.message_port = None;
            let message_port = self
                .synic
                .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
                .inspect_err(|err| {
                    tracing::error!(
                        ?err,
                        ?self.redirect_vtl,
                        ?new_target,
                        "could not create reserved channel message port"
                    )
                })
                .ok()?;

            channel.reserved_state.message_port = Some(message_port);
            channel.reserved_state.target = new_target;
        } else if channel.reserved_state.target.vp != new_target.vp {
            let message_port = channel.reserved_state.message_port.as_mut().unwrap();

            // On failure, keep using the existing port.
            if let Err(err) = message_port.set_target_vp(new_target.vp) {
                tracing::error!(
                    ?err,
                    ?self.redirect_vtl,
                    ?new_target,
                    "could not update reserved channel message port"
                );
            }

            // NOTE(review): `target` is updated even when retargeting failed,
            // so a later call with the same target will not retry.
            channel.reserved_state.target = new_target;
            return Some(message_port);
        }

        Some(channel.reserved_state.message_port.as_mut().unwrap())
    }

    /// Releases the reserved message port of every channel that is
    /// `ChannelState::Closed`. Channels in other states keep theirs.
    fn unreserve_channels(&mut self) {
        for channel in self.channels.values_mut() {
            if let ChannelState::Closed = channel.state {
                channel.reserved_state.message_port = None;
            }
        }
    }
}
1856
/// Cloneable handle for offering channels to the VMBus server and requesting
/// a forced reset. It also implements [`ParentBus`].
#[derive(Clone)]
pub struct VmbusServerControl {
    /// Guest memory handed to every offered channel in its [`OfferResources`].
    mem: GuestMemory,
    /// Private guest memory. It is handed only to channels whose flags request
    /// a confidential ring buffer or confidential external memory.
    private_mem: Option<GuestMemory>,
    /// Carries offer and force-reset requests (see `offer_core` and `force_reset`).
    send: mesh::Sender<OfferRequest>,
    /// Value reported by [`ParentBus::use_event`].
    use_event: bool,
    /// When set, every channel offered through [`ParentBus::add_child`] has
    /// confidential external memory forced on.
    force_confidential_external_memory: bool,
}
1866
1867impl VmbusServerControl {
1868 pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
1871 let flags = offer_info.params.flags;
1872 self.send
1873 .call_failable(OfferRequest::Offer, offer_info)
1874 .await?;
1875 Ok(OfferResources::new(
1876 self.mem.clone(),
1877 if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
1878 self.private_mem.clone()
1879 } else {
1880 None
1881 },
1882 ))
1883 }
1884
1885 pub async fn force_reset(&self) -> anyhow::Result<()> {
1888 self.send
1889 .call(OfferRequest::ForceReset, ())
1890 .await
1891 .context("vmbus server is gone")
1892 }
1893
1894 async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
1895 let mut offer_info = OfferInfo {
1896 params: request.params.into(),
1897 request_send: request.request_send,
1898 server_request_recv: request.server_request_recv,
1899 };
1900
1901 if self.force_confidential_external_memory {
1902 tracing::warn!(
1903 key = %offer_info.params.key(),
1904 "forcing confidential external memory for channel"
1905 );
1906
1907 offer_info
1908 .params
1909 .flags
1910 .set_confidential_external_memory(true);
1911 }
1912
1913 self.offer_core(offer_info).await
1914 }
1915}
1916
1917fn inspect_rings(
1919 resp: &mut inspect::Response<'_>,
1920 gm: &GuestMemory,
1921 gpadl_map: Arc<GpadlMap>,
1922 open_data: &OpenData,
1923) -> Option<()> {
1924 let gpadl = gpadl_map
1925 .view()
1926 .map(GpadlId(open_data.ring_gpadl_id.0))
1927 .ok()?;
1928 let aligned = AlignedGpadlView::new(gpadl).ok()?;
1929 let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
1930 if let Ok(incoming_mem) = GpadlRingMem::new(in_gpadl, gm) {
1931 resp.child("incoming_ring", |req| ring::inspect_ring(incoming_mem, req));
1932 }
1933 if let Ok(outgoing_mem) = GpadlRingMem::new(out_gpadl, gm) {
1934 resp.child("outgoing_ring", |req| ring::inspect_ring(outgoing_mem, req));
1935 }
1936 Some(())
1937}
1938
/// Synic [`MessagePort`] that forwards each received message over an mpsc
/// channel as a [`SynicMessage`].
pub(crate) struct MessageSender {
    /// Destination for forwarded messages.
    send: mpsc::Sender<SynicMessage>,
    /// Copied into the `multiclient` field of every forwarded message.
    multiclient: bool,
}
1943
1944impl MessageSender {
1945 fn poll_handle_message(
1946 &self,
1947 cx: &mut std::task::Context<'_>,
1948 msg: &[u8],
1949 trusted: bool,
1950 ) -> Poll<Result<(), SendError>> {
1951 let mut send = self.send.clone();
1952 ready!(send.poll_ready(cx))?;
1953 send.start_send(SynicMessage {
1954 data: msg.to_vec(),
1955 multiclient: self.multiclient,
1956 trusted,
1957 })?;
1958
1959 Poll::Ready(Ok(()))
1960 }
1961}
1962
1963impl MessagePort for MessageSender {
1964 fn poll_handle_message(
1965 &self,
1966 cx: &mut std::task::Context<'_>,
1967 msg: &[u8],
1968 trusted: bool,
1969 ) -> Poll<()> {
1970 if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
1971 tracelimit::error_ratelimited!(
1972 error = &err as &dyn std::error::Error,
1973 "failed to send message"
1974 );
1975 }
1976
1977 Poll::Ready(())
1978 }
1979}
1980
1981#[async_trait]
1982impl ParentBus for VmbusServerControl {
1983 async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
1984 self.offer(request).await
1985 }
1986
1987 fn clone_bus(&self) -> Box<dyn ParentBus> {
1988 Box::new(self.clone())
1989 }
1990
1991 fn use_event(&self) -> bool {
1992 self.use_event
1993 }
1994}
1995
1996#[cfg(test)]
1997mod tests {
1998 use super::*;
1999 use inspect::InspectMut;
2000 use mesh::CancelReason;
2001 use pal_async::DefaultDriver;
2002 use pal_async::async_test;
2003 use pal_async::driver::SpawnDriver;
2004 use pal_async::timer::Instant;
2005 use pal_async::timer::PolledTimer;
2006 use parking_lot::Mutex;
2007 use protocol::UserDefinedData;
2008 use std::time::Duration;
2009 use test_with_tracing::test;
2010 use vmbus_channel::bus::OfferParams;
2011 use vmbus_channel::channel::ChannelOpenError;
2012 use vmbus_channel::channel::DeviceResources;
2013 use vmbus_channel::channel::SaveRestoreVmbusDevice;
2014 use vmbus_channel::channel::VmbusDevice;
2015 use vmbus_channel::channel::offer_channel;
2016 use vmbus_core::protocol::ChannelId;
2017 use vmbus_core::protocol::VmbusMessage;
2018 use vmcore::synic::MonitorInfo;
2019 use vmcore::synic::SynicPortAccess;
2020 use zerocopy::FromBytes;
2021 use zerocopy::Immutable;
2022 use zerocopy::IntoBytes;
2023 use zerocopy::KnownLayout;
2024
    /// Mutable state of [`MockSynic`].
    struct MockSynicInner {
        /// The most recently registered message port. `MockSynic::send_message_core`
        /// injects guest messages through it.
        message_port: Option<Arc<dyn MessagePort>>,
    }
2028
    /// Synic stand-in for tests.
    ///
    /// It captures the server's message port so tests can inject guest
    /// messages. Every guest message port it creates gets a clone of
    /// `message_send`, so messages the server posts to the guest arrive on the
    /// matching receiver.
    struct MockSynic {
        inner: Mutex<MockSynicInner>,
        /// Cloned into every `MockGuestMessagePort`; carries server-to-guest messages.
        message_send: mesh::Sender<Vec<u8>>,
        /// Given to guest message ports so they can create delay timers.
        spawner: Arc<dyn SpawnDriver>,
    }
2034
2035 impl MockSynic {
2036 fn new(message_send: mesh::Sender<Vec<u8>>, spawner: Arc<dyn SpawnDriver>) -> Self {
2037 Self {
2038 inner: Mutex::new(MockSynicInner { message_port: None }),
2039 message_send,
2040 spawner,
2041 }
2042 }
2043
2044 fn send_message(&self, msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout) {
2045 self.send_message_core(OutgoingMessage::new(&msg), false);
2046 }
2047
2048 fn send_message_trusted(
2049 &self,
2050 msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout,
2051 ) {
2052 self.send_message_core(OutgoingMessage::new(&msg), true);
2053 }
2054
2055 fn send_message_core(&self, msg: OutgoingMessage, trusted: bool) {
2056 assert_eq!(
2057 self.inner
2058 .lock()
2059 .message_port
2060 .as_ref()
2061 .unwrap()
2062 .poll_handle_message(
2063 &mut std::task::Context::from_waker(std::task::Waker::noop()),
2064 msg.data(),
2065 trusted,
2066 ),
2067 Poll::Ready(())
2068 );
2069 }
2070 }
2071
    /// Guest event port stub. It hands out a null interrupt and accepts any
    /// target VP.
    #[derive(Debug)]
    struct MockGuestPort {}
2074
2075 impl GuestEventPort for MockGuestPort {
2076 fn interrupt(&self) -> Interrupt {
2077 Interrupt::null()
2078 }
2079
2080 fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
2081 Ok(())
2082 }
2083 }
2084
    /// Guest message port that forwards posted payloads to `send`.
    ///
    /// It randomly delays some posts by returning `Pending`. NOTE(review):
    /// presumably this exercises the server's handling of busy ports.
    struct MockGuestMessagePort {
        /// Receives the payload of every delivered post.
        send: mesh::Sender<Vec<u8>>,
        /// Used to create delay timers.
        spawner: Arc<dyn SpawnDriver>,
        /// An in-progress delay (timer and deadline). It must elapse before
        /// the next post is delivered.
        timer: Option<(PolledTimer, Instant)>,
    }
2090
2091 impl GuestMessagePort for MockGuestMessagePort {
2092 fn poll_post_message(
2093 &mut self,
2094 cx: &mut std::task::Context<'_>,
2095 _typ: u32,
2096 payload: &[u8],
2097 ) -> Poll<()> {
2098 if let Some((timer, deadline)) = self.timer.as_mut() {
2099 ready!(timer.sleep_until(*deadline).poll_unpin(cx));
2100 self.timer = None;
2101 }
2102
2103 let mut pending_chance = [0; 1];
2105 getrandom::fill(&mut pending_chance).unwrap();
2106 if pending_chance[0] % 4 == 0 {
2107 let mut timer = PolledTimer::new(self.spawner.as_ref());
2108 let deadline = Instant::now() + Duration::from_millis(10);
2109 match timer.sleep_until(deadline).poll_unpin(cx) {
2110 Poll::Ready(_) => {}
2111 Poll::Pending => {
2112 self.timer = Some((timer, deadline));
2113 return Poll::Pending;
2114 }
2115 }
2116 }
2117
2118 self.send.send(payload.into());
2119 Poll::Ready(())
2120 }
2121
2122 fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
2123 Ok(())
2124 }
2125 }
2126
    // The mock port exposes nothing for inspection.
    impl Inspect for MockGuestMessagePort {
        fn inspect(&self, _req: inspect::Request<'_>) {}
    }
2130
2131 impl SynicPortAccess for MockSynic {
2132 fn add_message_port(
2133 &self,
2134 connection_id: u32,
2135 _minimum_vtl: Vtl,
2136 port: Arc<dyn MessagePort>,
2137 ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
2138 self.inner.lock().message_port = Some(port);
2139 Ok(Box::new(connection_id))
2140 }
2141
2142 fn add_event_port(
2143 &self,
2144 connection_id: u32,
2145 _minimum_vtl: Vtl,
2146 _port: Arc<dyn EventPort>,
2147 _monitor_info: Option<MonitorInfo>,
2148 ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
2149 Ok(Box::new(connection_id))
2150 }
2151
2152 fn new_guest_message_port(
2153 &self,
2154 _vtl: Vtl,
2155 _vp: u32,
2156 _sint: u8,
2157 ) -> Result<Box<(dyn GuestMessagePort)>, vmcore::synic::HypervisorError> {
2158 Ok(Box::new(MockGuestMessagePort {
2159 send: self.message_send.clone(),
2160 spawner: Arc::clone(&self.spawner),
2161 timer: None,
2162 }))
2163 }
2164
2165 fn new_guest_event_port(
2166 &self,
2167 _port_id: u32,
2168 _vtl: Vtl,
2169 _vp: u32,
2170 _sint: u8,
2171 _flag: u16,
2172 _monitor_info: Option<MonitorInfo>,
2173 ) -> Result<Box<(dyn GuestEventPort)>, vmcore::synic::HypervisorError> {
2174 Ok(Box::new(MockGuestPort {}))
2175 }
2176
2177 fn prefer_os_events(&self) -> bool {
2178 false
2179 }
2180 }
2181
    /// Device side of a channel offered by a test.
    struct TestChannel {
        /// Requests the server sends to this channel's device.
        request_recv: mesh::Receiver<ChannelRequest>,
        /// Used to issue server requests such as `ChannelServerRequest::Restore`.
        server_request_send: mesh::Sender<ChannelServerRequest>,
        /// Resources returned for the offer, held for the life of the test channel.
        _resources: OfferResources,
    }
2187
2188 impl TestChannel {
2189 async fn next_request(&mut self) -> ChannelRequest {
2190 self.request_recv.next().await.unwrap()
2191 }
2192
2193 async fn handle_gpadl(&mut self) {
2194 let ChannelRequest::Gpadl(rpc) = self.next_request().await else {
2195 panic!("Wrong request");
2196 };
2197
2198 rpc.complete(true);
2199 }
2200
2201 async fn handle_open(&mut self, f: fn(&OpenRequest)) {
2202 let ChannelRequest::Open(rpc) = self.next_request().await else {
2203 panic!("Wrong request");
2204 };
2205
2206 f(rpc.input());
2207 rpc.complete(Some(OpenResult {
2208 guest_to_host_interrupt: Interrupt::null(),
2209 }));
2210 }
2211
2212 async fn handle_gpadl_teardown(&mut self) {
2213 let rpc = self.get_gpadl_teardown().await;
2214 rpc.complete(());
2215 }
2216
2217 async fn get_gpadl_teardown(&mut self) -> Rpc<GpadlId, ()> {
2218 let ChannelRequest::TeardownGpadl(rpc) = self.next_request().await else {
2219 panic!("Wrong request");
2220 };
2221
2222 rpc
2223 }
2224
2225 async fn restore(&self) {
2226 self.server_request_send
2227 .call(ChannelServerRequest::Restore, None)
2228 .await
2229 .unwrap()
2230 .unwrap();
2231 }
2232 }
2233
    /// A VMBus server wired to a [`MockSynic`], plus the guest-side plumbing.
    struct TestEnv {
        vmbus: VmbusServer,
        synic: Arc<MockSynic>,
        /// Messages the server posted to the guest.
        message_recv: mesh::Receiver<Vec<u8>>,
        /// Whether guest messages are injected as trusted. Set by `initiate_contact`.
        trusted: bool,
    }
2240
2241 impl TestEnv {
2242 fn new(spawner: DefaultDriver) -> Self {
2243 let spawner: Arc<dyn SpawnDriver> = Arc::new(spawner);
2244 let (message_send, message_recv) = mesh::channel();
2245 let synic = Arc::new(MockSynic::new(message_send, Arc::clone(&spawner)));
2246 let gm = GuestMemory::empty();
2247 let vmbus = VmbusServerBuilder::new(&spawner, synic.clone(), gm)
2248 .build()
2249 .unwrap();
2250
2251 Self {
2252 vmbus,
2253 synic,
2254 message_recv,
2255 trusted: false,
2256 }
2257 }
2258
2259 async fn offer(&self, id: u32, allow_confidential_external_memory: bool) -> TestChannel {
2260 let guid = Guid {
2261 data1: id,
2262 ..Guid::ZERO
2263 };
2264 let (request_send, request_recv) = mesh::channel();
2265 let (server_request_send, server_request_recv) = mesh::channel();
2266 let offer = OfferInput {
2267 request_send,
2268 server_request_recv,
2269 params: OfferParams {
2270 interface_name: "test".into(),
2271 instance_id: guid,
2272 interface_id: guid,
2273 mmio_megabytes: 0,
2274 mmio_megabytes_optional: 0,
2275 channel_type: vmbus_channel::bus::ChannelType::Device {
2276 pipe_packets: false,
2277 },
2278 subchannel_index: 0,
2279 mnf_interrupt_latency: None,
2280 offer_order: None,
2281 allow_confidential_external_memory,
2282 },
2283 };
2284
2285 let control = self.vmbus.control();
2286 let _resources = control.add_child(offer).await.unwrap();
2287
2288 TestChannel {
2289 request_recv,
2290 server_request_send,
2291 _resources,
2292 }
2293 }
2294
2295 async fn gpadl(&mut self, channel_id: u32, gpadl_id: u32, channel: &mut TestChannel) {
2296 self.synic.send_message_core(
2297 OutgoingMessage::with_data(
2298 &protocol::GpadlHeader {
2299 channel_id: ChannelId(channel_id),
2300 gpadl_id: GpadlId(gpadl_id),
2301 count: 1,
2302 len: 16,
2303 },
2304 [1u64, 0u64].as_bytes(),
2305 ),
2306 self.trusted,
2307 );
2308
2309 channel.handle_gpadl().await;
2310 self.expect_response(protocol::MessageType::GPADL_CREATED)
2311 .await;
2312 }
2313
2314 async fn open_channel(
2315 &mut self,
2316 channel_id: u32,
2317 ring_gpadl_id: u32,
2318 channel: &mut TestChannel,
2319 f: fn(&OpenRequest),
2320 ) {
2321 self.gpadl(channel_id, ring_gpadl_id, channel).await;
2322 self.synic.send_message_core(
2323 OutgoingMessage::new(&protocol::OpenChannel {
2324 channel_id: ChannelId(channel_id),
2325 open_id: 0,
2326 ring_buffer_gpadl_id: GpadlId(ring_gpadl_id),
2327 target_vp: 0,
2328 downstream_ring_buffer_page_offset: 0,
2329 user_data: UserDefinedData::default(),
2330 }),
2331 self.trusted,
2332 );
2333
2334 channel.handle_open(f).await;
2335 self.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT)
2336 .await;
2337 }
2338
2339 async fn expect_response(&mut self, expected: protocol::MessageType) {
2340 let data = self.message_recv.next().await.unwrap();
2341 let header = protocol::MessageHeader::read_from_prefix(&data).unwrap().0; assert_eq!(expected, header.message_type())
2343 }
2344
2345 async fn get_response<T: VmbusMessage + FromBytes + Immutable + KnownLayout>(
2346 &mut self,
2347 ) -> T {
2348 let data = self.message_recv.next().await.unwrap();
2349 let (header, message) = protocol::MessageHeader::read_from_prefix(&data).unwrap(); assert_eq!(T::MESSAGE_TYPE, header.message_type());
2351 T::read_from_prefix(message).unwrap().0 }
2353
2354 fn initiate_contact(
2355 &mut self,
2356 version: protocol::Version,
2357 feature_flags: protocol::FeatureFlags,
2358 trusted: bool,
2359 ) {
2360 self.synic.send_message_core(
2361 OutgoingMessage::new(&protocol::InitiateContact {
2362 version_requested: version as u32,
2363 target_message_vp: 0,
2364 child_to_parent_monitor_page_gpa: 0,
2365 parent_to_child_monitor_page_gpa: 0,
2366 interrupt_page_or_target_info: protocol::TargetInfo::new()
2367 .with_sint(2)
2368 .with_vtl(0)
2369 .with_feature_flags(feature_flags.into())
2370 .into(),
2371 }),
2372 trusted,
2373 );
2374
2375 self.trusted = trusted;
2376 }
2377
2378 async fn connect(
2379 &mut self,
2380 offer_count: u32,
2381 feature_flags: protocol::FeatureFlags,
2382 trusted: bool,
2383 ) {
2384 self.initiate_contact(protocol::Version::Copper, feature_flags, trusted);
2385
2386 self.expect_response(protocol::MessageType::VERSION_RESPONSE)
2387 .await;
2388
2389 self.synic
2390 .send_message_core(OutgoingMessage::new(&protocol::RequestOffers {}), trusted);
2391
2392 for _ in 0..offer_count {
2393 self.expect_response(protocol::MessageType::OFFER_CHANNEL)
2394 .await;
2395 }
2396
2397 self.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
2398 .await;
2399 }
2400 }
2401
    /// Saves the server while a GPADL teardown is still pending at the device.
    /// After restoring that state, the device must receive the teardown
    /// request again.
    #[async_test]
    async fn test_save_restore(spawner: DefaultDriver) {
        let mut env = TestEnv::new(spawner);

        // Offer one channel and connect a guest that sees that single offer.
        let mut channel = env.offer(1, false).await;
        env.vmbus.start();
        env.connect(1, protocol::FeatureFlags::new(), false).await;

        // Create GPADL 10 on channel 1.
        env.gpadl(1, 10, &mut channel).await;

        // The guest requests teardown. The device receives the request below
        // but does not complete it yet.
        env.synic.send_message(protocol::GpadlTeardown {
            channel_id: ChannelId(1),
            gpadl_id: GpadlId(10),
        });

        let rpc = channel.get_gpadl_teardown().await;
        // Save while the teardown is outstanding, then resume.
        env.vmbus.stop().await;
        let saved_state = env.vmbus.save().await;
        env.vmbus.start();

        // Completing the teardown now produces the torn-down notification.
        rpc.complete(());
        env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
            .await;

        env.synic.send_message(protocol::RelIdReleased {
            channel_id: ChannelId(1),
        });

        env.vmbus.reset().await;
        env.vmbus.stop().await;

        // Roll back to the saved state, taken while the teardown was pending.
        env.vmbus.restore(saved_state).await.unwrap();
        channel.restore().await;
        env.vmbus.start();

        // The device is expected to see the teardown request again.
        channel.handle_gpadl_teardown().await;
        env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
            .await;

        env.synic.send_message(protocol::RelIdReleased {
            channel_id: ChannelId(1),
        });
    }
2456
    /// State of a [`TestDevice`], shared with the test through `Arc<Mutex<_>>`.
    struct TestDeviceState {
        /// Used as `data1` of the device's instance and interface GUIDs.
        id: u32,
        started: bool,
        /// Resources received through `VmbusDevice::install`.
        resources: Option<DeviceResources>,
        /// Open requests received by the device, keyed by channel index.
        open_requests: HashMap<u16, OpenRequest>,
        /// Last target VP recorded for each channel index.
        target_vps: HashMap<u16, u32>,
    }
2464
2465 impl TestDeviceState {
2466 pub fn id(this: &Arc<Mutex<Self>>) -> u32 {
2467 this.lock().id
2468 }
2469
2470 pub fn started(this: &Arc<Mutex<Self>>) -> bool {
2471 this.lock().started
2472 }
2473 pub fn set_started(this: &Arc<Mutex<Self>>, started: bool) {
2474 this.lock().started = started;
2475 }
2476
2477 pub fn open_request(this: &Arc<Mutex<Self>>, channel_idx: u16) -> Option<OpenRequest> {
2478 this.lock().open_requests.get(&channel_idx).cloned()
2479 }
2480 pub fn set_open_request(
2481 this: &Arc<Mutex<Self>>,
2482 channel_idx: u16,
2483 open_request: OpenRequest,
2484 ) {
2485 assert!(
2486 this.lock()
2487 .open_requests
2488 .insert(channel_idx, open_request)
2489 .is_none()
2490 );
2491 }
2492 pub fn remove_open_request(
2493 this: &Arc<Mutex<Self>>,
2494 channel_idx: u16,
2495 ) -> Option<OpenRequest> {
2496 this.lock().open_requests.remove(&channel_idx)
2497 }
2498
2499 pub fn target_vp(this: &Arc<Mutex<Self>>, channel_idx: u16) -> Option<u32> {
2500 this.lock().target_vps.get(&channel_idx).copied()
2501 }
2502 pub fn set_target_vp(this: &Arc<Mutex<Self>>, channel_idx: u16, target_vp: u32) {
2503 let _ = this.lock().target_vps.insert(channel_idx, target_vp);
2504 }
2505 }
2506
/// Minimal device implementation for tests; it records the callbacks it
/// receives into a [`TestDeviceState`] that the test can inspect.
#[derive(InspectMut)]
struct TestDevice {
    /// State shared with the test body; excluded from inspection.
    #[inspect(skip)]
    pub state: Arc<Mutex<TestDeviceState>>,
}
2512
2513 impl TestDevice {
2514 pub fn new_and_state(id: u32) -> (Self, Arc<Mutex<TestDeviceState>>) {
2515 let state = TestDeviceState {
2516 id,
2517 resources: None,
2518 open_requests: HashMap::new(),
2519 target_vps: HashMap::new(),
2520 started: false,
2521 };
2522 let state = Arc::new(Mutex::new(state));
2523 let this = Self {
2524 state: state.clone(),
2525 };
2526 (this, state)
2527 }
2528 }
2529
2530 #[async_trait]
2531 impl VmbusDevice for TestDevice {
2532 fn offer(&self) -> OfferParams {
2533 let guid = Guid {
2534 data1: TestDeviceState::id(&self.state),
2535 ..Guid::ZERO
2536 };
2537
2538 OfferParams {
2539 interface_name: "test".into(),
2540 instance_id: guid,
2541 interface_id: guid,
2542 channel_type: vmbus_channel::bus::ChannelType::Device {
2543 pipe_packets: false,
2544 },
2545 ..Default::default()
2546 }
2547 }
2548
2549 fn max_subchannels(&self) -> u16 {
2550 0
2551 }
2552
2553 fn install(&mut self, resources: DeviceResources) {
2554 self.state.lock().resources = Some(resources);
2555 }
2556
2557 async fn open(
2558 &mut self,
2559 channel_idx: u16,
2560 open_request: &OpenRequest,
2561 ) -> Result<(), ChannelOpenError> {
2562 tracing::info!("OPEN");
2563 TestDeviceState::set_open_request(&self.state, channel_idx, open_request.clone());
2564 Ok(())
2565 }
2566
2567 async fn close(&mut self, channel_idx: u16) {
2568 tracing::info!("CLOSE");
2569 assert!(TestDeviceState::remove_open_request(&self.state, channel_idx).is_some());
2570 }
2571
2572 async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
2573 TestDeviceState::set_target_vp(&self.state, channel_idx, target_vp);
2574 }
2575
2576 fn start(&mut self) {
2577 tracing::info!("START");
2578 TestDeviceState::set_started(&self.state, true);
2579 }
2580
2581 async fn stop(&mut self) {
2582 tracing::info!("STOP");
2583 TestDeviceState::set_started(&self.state, false);
2584 }
2585
2586 fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
2587 None
2588 }
2589 }
2590
2591 #[async_test]
2592 async fn test_stopped_child(spawner: DefaultDriver) {
2593 let mut env = TestEnv::new(spawner.clone());
2597 let (test_device, test_device_state) = TestDevice::new_and_state(1);
2598 let control = env.vmbus.control();
2599 let channel = offer_channel(&spawner, control.as_ref(), test_device)
2600 .await
2601 .expect("test device failed to offer");
2602
2603 env.vmbus.start();
2604 env.connect(1, protocol::FeatureFlags::new(), false).await;
2605
2606 channel.stop().await;
2608
2609 assert_eq!(TestDeviceState::started(&test_device_state), false);
2610
2611 env.synic.send_message_core(
2614 OutgoingMessage::with_data(
2615 &protocol::GpadlHeader {
2616 channel_id: ChannelId(1),
2617 gpadl_id: GpadlId(1),
2618 count: 1,
2619 len: 16,
2620 },
2621 [1u64, 0u64].as_bytes(),
2622 ),
2623 false,
2624 );
2625 env.expect_response(protocol::MessageType::GPADL_CREATED)
2626 .await;
2627
2628 env.synic.send_message_core(
2630 OutgoingMessage::new(&protocol::OpenChannel {
2631 channel_id: ChannelId(1),
2632 open_id: 0,
2633 ring_buffer_gpadl_id: GpadlId(1),
2634 target_vp: 0,
2635 downstream_ring_buffer_page_offset: 0,
2636 user_data: UserDefinedData::default(),
2637 }),
2638 false,
2639 );
2640 let wait_for_response = mesh::CancelContext::new()
2641 .with_timeout(Duration::from_millis(150))
2642 .until_cancelled(env.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT))
2643 .await;
2644 assert!(matches!(
2645 wait_for_response,
2646 Err(CancelReason::DeadlineExceeded)
2647 ));
2648 assert!(TestDeviceState::open_request(&test_device_state, 0).is_none());
2649
2650 channel.start();
2652 env.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT)
2653 .await;
2654 assert!(TestDeviceState::open_request(&test_device_state, 0).is_some());
2655
2656 assert!(TestDeviceState::target_vp(&test_device_state, 0).is_none());
2658 channel.stop().await;
2659 env.synic.send_message_core(
2660 OutgoingMessage::new(&protocol::ModifyChannel {
2661 channel_id: ChannelId(1),
2662 target_vp: 2,
2663 }),
2664 false,
2665 );
2666 let wait_for_response = mesh::CancelContext::new()
2667 .with_timeout(Duration::from_millis(150))
2668 .until_cancelled(env.expect_response(protocol::MessageType::MODIFY_CHANNEL_RESPONSE))
2669 .await;
2670 assert!(matches!(
2671 wait_for_response,
2672 Err(CancelReason::DeadlineExceeded)
2673 ));
2674
2675 channel.start();
2677 env.expect_response(protocol::MessageType::MODIFY_CHANNEL_RESPONSE)
2678 .await;
2679 assert_eq!(
2680 TestDeviceState::target_vp(&test_device_state, 0)
2681 .expect("Modify channel request received"),
2682 2
2683 );
2684
2685 channel.stop().await;
2689 env.vmbus.reset().await;
2690 assert!(TestDeviceState::open_request(&test_device_state, 0).is_none());
2691
2692 env.vmbus.stop().await;
2693 }
2694
/// Checks the confidential-memory flags that each kind of channel gets when a
/// guest connecting through the trusted path enables confidential channels.
///
/// Three channels are offered:
/// - channel 1 via `env.offer(1, false)`;
/// - channel 2 via `env.offer(2, true)`;
/// - channel 3 directly via `offer_core`, named as a relay channel.
///
/// Channels 1 and 2 get a confidential ring buffer, and only channel 2 also
/// gets confidential external memory. Channel 3 gets neither. This holds both
/// in the offer flags and in the open requests delivered to the devices.
#[async_test]
async fn test_confidential_connection(spawner: DefaultDriver) {
    let mut env = TestEnv::new(spawner);
    let mut channel = env.offer(1, false).await;
    // NOTE(review): judging by the flag assertions below, the second `offer`
    // argument selects confidential external memory — confirm in `TestEnv::offer`.
    let mut channel2 = env.offer(2, true).await;

    // Offer a third channel through `offer_core`. Its parameters set no
    // confidential options explicitly; the rest come from `Default::default()`.
    let (request_send, request_recv) = mesh::channel();
    let (server_request_send, server_request_recv) = mesh::channel();
    let id = Guid {
        data1: 3,
        ..Guid::ZERO
    };
    let control = env.vmbus.control();
    let relay_resources = control
        .offer_core(OfferInfo {
            params: OfferParamsInternal {
                interface_name: "test".into(),
                instance_id: id,
                interface_id: id,
                mmio_megabytes: 0,
                mmio_megabytes_optional: 0,
                subchannel_index: 0,
                use_mnf: MnfUsage::Disabled,
                offer_order: None,
                flags: protocol::OfferFlags::new().with_enumerate_device_interface(true),
                ..Default::default()
            },
            request_send,
            server_request_recv,
        })
        .await
        .unwrap();

    let mut relay_channel = TestChannel {
        request_recv,
        server_request_send,
        _resources: relay_resources,
    };

    // Negotiate Copper with confidential channels requested. The final `true`
    // presumably marks this contact as trusted — confirm in `initiate_contact`.
    env.vmbus.start();
    env.initiate_contact(
        protocol::Version::Copper,
        protocol::FeatureFlags::new().with_confidential_channels(true),
        true,
    );

    env.expect_response(protocol::MessageType::VERSION_RESPONSE)
        .await;

    env.synic.send_message_trusted(protocol::RequestOffers {});

    // The assertions below assume offers arrive in the order the channels were
    // offered: channel 1, channel 2, then the relay channel.
    let offer = env.get_response::<protocol::OfferChannel>().await;
    assert!(offer.flags.confidential_ring_buffer());
    assert!(!offer.flags.confidential_external_memory());
    let offer = env.get_response::<protocol::OfferChannel>().await;
    assert!(offer.flags.confidential_ring_buffer());
    assert!(offer.flags.confidential_external_memory());

    // The relay channel is offered with no confidential flags.
    let offer = env.get_response::<protocol::OfferChannel>().await;
    assert!(!offer.flags.confidential_ring_buffer());
    assert!(!offer.flags.confidential_external_memory());

    env.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
        .await;

    // The open request each device receives must match its offered flags.
    env.open_channel(1, 1, &mut channel, |request| {
        assert!(request.use_confidential_ring);
        assert!(!request.use_confidential_external_memory);
    })
    .await;

    env.open_channel(2, 2, &mut channel2, |request| {
        assert!(request.use_confidential_ring);
        assert!(request.use_confidential_external_memory);
    })
    .await;

    env.open_channel(3, 3, &mut relay_channel, |request| {
        assert!(!request.use_confidential_ring);
        assert!(!request.use_confidential_external_memory);
    })
    .await;
}
2784
/// Checks that no confidential memory is used when the guest does not request
/// the `confidential_channels` feature, even for channels offered with
/// confidential options.
#[async_test]
async fn test_confidential_channels_unsupported(spawner: DefaultDriver) {
    let mut env = TestEnv::new(spawner);
    let mut channel = env.offer(1, false).await;
    let mut channel2 = env.offer(2, true).await;

    // Connect without the confidential_channels feature flag.
    // NOTE(review): the final `true` presumably marks the connection as
    // trusted — confirm in `TestEnv::connect`.
    env.vmbus.start();
    env.connect(2, protocol::FeatureFlags::new(), true).await;

    // Neither channel may use confidential memory.
    env.open_channel(1, 1, &mut channel, |request| {
        assert!(!request.use_confidential_ring);
        assert!(!request.use_confidential_external_memory);
    })
    .await;

    env.open_channel(2, 2, &mut channel2, |request| {
        assert!(!request.use_confidential_ring);
        assert!(!request.use_confidential_external_memory);
    })
    .await;
}
2808
/// Checks that no confidential memory is used when the guest requests the
/// `confidential_channels` feature over a connection that is not trusted.
#[async_test]
async fn test_confidential_channels_untrusted(spawner: DefaultDriver) {
    let mut env = TestEnv::new(spawner);
    let mut channel = env.offer(1, false).await;
    let mut channel2 = env.offer(2, true).await;

    // Request confidential channels, but connect with the final argument set
    // to `false` (presumably the untrusted path — confirm in `TestEnv::connect`).
    env.vmbus.start();
    env.connect(
        2,
        protocol::FeatureFlags::new().with_confidential_channels(true),
        false,
    )
    .await;

    // The feature request is not honored, so neither channel may use
    // confidential memory.
    env.open_channel(1, 1, &mut channel, |request| {
        assert!(!request.use_confidential_ring);
        assert!(!request.use_confidential_external_memory);
    })
    .await;

    env.open_channel(2, 2, &mut channel2, |request| {
        assert!(!request.use_confidential_ring);
        assert!(!request.use_confidential_external_memory);
    })
    .await;
}
2839}