1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod channel_bitmap;
8mod channels;
9pub mod event;
10pub mod hvsock;
11mod monitor;
12mod proxyintegration;
13
/// Alias for `guid::Guid`, so callers can name the GUID type through this crate.
pub type Guid = guid::Guid;
16
17use anyhow::Context;
18use async_trait::async_trait;
19use channel_bitmap::ChannelBitmap;
20use channels::ConnectionTarget;
21pub use channels::InitiateContactRequest;
22use channels::MessageTarget;
23pub use channels::MnfUsage;
24use channels::ModifyConnectionRequest;
25pub use channels::ModifyConnectionResponse;
26use channels::Notifier;
27use channels::OfferId;
28pub use channels::OfferParamsInternal;
29use channels::OpenParams;
30use channels::RestoreError;
31pub use channels::Update;
32use futures::FutureExt;
33use futures::StreamExt;
34use futures::channel::mpsc;
35use futures::channel::mpsc::SendError;
36use futures::future::OptionFuture;
37use futures::future::poll_fn;
38use futures::stream::SelectAll;
39use guestmem::GuestMemory;
40use hvdef::Vtl;
41use inspect::Inspect;
42use mesh::payload::Protobuf;
43use mesh::rpc::FailableRpc;
44use mesh::rpc::Rpc;
45use mesh::rpc::RpcError;
46use mesh::rpc::RpcSend;
47use pal_async::task::Spawn;
48use pal_async::task::Task;
49use pal_event::Event;
50#[cfg(windows)]
51pub use proxyintegration::ProxyIntegration;
52#[cfg(windows)]
53pub use proxyintegration::ProxyServerInfo;
54use ring::PAGE_SIZE;
55use std::collections::HashMap;
56use std::future;
57use std::future::Future;
58use std::pin::Pin;
59use std::sync::Arc;
60use std::task::Poll;
61use std::task::ready;
62use std::time::Duration;
63use unicycle::FuturesUnordered;
64use vmbus_channel::bus::ChannelRequest;
65use vmbus_channel::bus::ChannelServerRequest;
66use vmbus_channel::bus::GpadlRequest;
67use vmbus_channel::bus::ModifyRequest;
68use vmbus_channel::bus::OfferInput;
69use vmbus_channel::bus::OfferKey;
70use vmbus_channel::bus::OfferResources;
71use vmbus_channel::bus::OpenData;
72use vmbus_channel::bus::OpenRequest;
73use vmbus_channel::bus::OpenResult;
74use vmbus_channel::bus::ParentBus;
75use vmbus_channel::bus::RestoreResult;
76use vmbus_channel::gpadl::GpadlMap;
77use vmbus_channel::gpadl_ring::AlignedGpadlView;
78use vmbus_channel::gpadl_ring::GpadlRingMem;
79use vmbus_core::HvsockConnectRequest;
80use vmbus_core::HvsockConnectResult;
81use vmbus_core::MaxVersionInfo;
82use vmbus_core::OutgoingMessage;
83use vmbus_core::TaggedStream;
84use vmbus_core::VersionInfo;
85use vmbus_core::protocol;
86pub use vmbus_core::protocol::GpadlId;
87#[cfg(windows)]
88use vmbus_proxy::ProxyHandle;
89use vmbus_ring as ring;
90use vmbus_ring::gparange::MultiPagedRangeBuf;
91use vmcore::interrupt::Interrupt;
92use vmcore::save_restore::SavedStateRoot;
93use vmcore::synic::EventPort;
94use vmcore::synic::GuestEventPort;
95use vmcore::synic::GuestMessagePort;
96use vmcore::synic::MessagePort;
97use vmcore::synic::MonitorPageGpas;
98use vmcore::synic::SynicPortAccess;
99
/// SINT used for the vmbus message port when message redirection is disabled.
const SINT: u8 = 2;
/// SINT used for the vmbus message port when message redirection is enabled.
pub const REDIRECT_SINT: u8 = 7;
/// VTL used for the vmbus message port when message redirection is enabled.
pub const REDIRECT_VTL: Vtl = Vtl::Vtl2;
/// Connection ID for the shared event port.
/// NOTE(review): not referenced in this part of the file; presumably used when
/// creating `ServerTaskInner::shared_event_port` — confirm.
const SHARED_EVENT_CONNECTION_ID: u32 = 2;
/// Base value combined into per-channel child event port IDs by
/// `VmbusServer::get_child_event_port_id`.
const EVENT_PORT_ID: u32 = 2;
/// Hypervisor message type used when posting vmbus messages to the guest.
const VMBUS_MESSAGE_TYPE: u32 = 1;

/// Limit on outstanding hvsocket connect requests. While `hvsock_requests`
/// is at this limit, the server task stops reading new synic messages.
const MAX_CONCURRENT_HVSOCK_REQUESTS: usize = 16;
108
/// The vmbus server.
///
/// Owns the background task that runs the channel state machine and the
/// synic message port registrations that feed it. Create one with
/// `VmbusServer::builder`.
pub struct VmbusServer {
    /// Sends control requests (save, restore, start, stop, reset, inspect)
    /// to the server task.
    task_send: mesh::Sender<VmbusRequest>,
    /// Handle for offering channels, returned by `VmbusServer::control`.
    control: Arc<VmbusServerControl>,
    /// Held only to keep the primary synic message port registration alive.
    _message_port: Box<dyn Sync + Send>,
    /// Held only to keep the multiclient message port registration alive, if
    /// one was created (VTL0 without message redirection).
    _multiclient_message_port: Option<Box<dyn Sync + Send>>,
    /// The spawned server task; awaited by `VmbusServer::shutdown`.
    task: Task<ServerTask>,
}
116
/// Builder for `VmbusServer`; obtain one from `VmbusServer::builder`.
pub struct VmbusServerBuilder<'a, T: Spawn> {
    /// Spawner used to run the server task.
    spawner: &'a T,
    /// Source of synic message ports and guest message ports.
    synic: Arc<dyn SynicPortAccess>,
    /// Guest memory given to the server control handle and the server task.
    gm: GuestMemory,
    /// Optional separate guest memory, also given to the control handle and
    /// the server task.
    private_gm: Option<GuestMemory>,
    /// VTL the server runs at (defaults to VTL0).
    vtl: Vtl,
    /// Where hvsocket connect requests are forwarded; when `None`, requests
    /// are answered with `HvsockConnectResult::from_request(.., false)`.
    hvsock_notify: Option<HvsockServerChannelHalf>,
    /// Where connection modification requests are relayed; when `None`, they
    /// are answered as supported with all feature flag bits set.
    server_relay: Option<VmbusServerChannelHalf>,
    /// Optional sender for `InitiateContactRequest`s, stored in the server task.
    external_server: Option<mesh::Sender<InitiateContactRequest>>,
    /// Optional source of externally initiated `InitiateContactRequest`s.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Whether messages use `REDIRECT_VTL`/`REDIRECT_SINT` instead of the
    /// server's own VTL and `SINT`.
    use_message_redirect: bool,
    /// Channel ID offset passed to the channel state machine (0 or 1024).
    channel_id_offset: u16,
    /// Optional protocol version cap passed to the channel state machine.
    max_version: Option<MaxVersionInfo>,
    /// Passed along with `max_version`; ignored when no cap is set.
    delay_max_version: bool,
    /// Whether this server handles MNF itself (if the synic supports it).
    enable_mnf: bool,
    /// Passed to the server control handle.
    force_confidential_external_memory: bool,
    /// Stored in the server task state.
    send_messages_while_stopped: bool,
}
135
/// The vmbus-server side of a `RelayChannel`: sends requests to the relay
/// and receives its responses.
pub struct ServerChannelHalf<Request, Response> {
    request_send: mesh::Sender<Request>,
    response_receive: mesh::Receiver<Response>,
}
141
/// The relay side of a `RelayChannel`: receives requests from the vmbus
/// server and sends back responses.
pub struct RelayChannelHalf<Request, Response> {
    pub request_receive: mesh::Receiver<Request>,
    pub response_send: mesh::Sender<Response>,
}
147
/// A connected request/response channel between the vmbus server and a
/// relay, split into the two halves each side holds.
pub struct RelayChannel<Request, Response> {
    /// Half handed to the relay.
    pub relay_half: RelayChannelHalf<Request, Response>,
    /// Half handed to the vmbus server builder.
    pub server_half: ServerChannelHalf<Request, Response>,
}
153
154impl<Request: 'static + Send, Response: 'static + Send> RelayChannel<Request, Response> {
155 pub fn new() -> Self {
157 let (request_send, request_receive) = mesh::channel();
158 let (response_send, response_receive) = mesh::channel();
159 Self {
160 relay_half: RelayChannelHalf {
161 request_receive,
162 response_send,
163 },
164 server_half: ServerChannelHalf {
165 request_send,
166 response_receive,
167 },
168 }
169 }
170}
171
/// Server half of the connection-modification relay channel.
pub type VmbusServerChannelHalf = ServerChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
/// Relay half of the connection-modification relay channel.
pub type VmbusRelayChannelHalf = RelayChannelHalf<ModifyRelayRequest, ModifyConnectionResponse>;
/// Connection-modification relay channel.
pub type VmbusRelayChannel = RelayChannel<ModifyRelayRequest, ModifyConnectionResponse>;
/// Server half of the hvsocket connect-request channel.
pub type HvsockServerChannelHalf = ServerChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Relay half of the hvsocket connect-request channel.
pub type HvsockRelayChannelHalf = RelayChannelHalf<HvsockConnectRequest, HvsockConnectResult>;
/// Hvsocket connect-request channel.
pub type HvsockRelayChannel = RelayChannel<HvsockConnectRequest, HvsockConnectResult>;
178
/// A connection modification forwarded from the vmbus server to the relay.
#[derive(Debug, Copy, Clone)]
pub struct ModifyRelayRequest {
    /// Newly negotiated protocol version, if it changed.
    pub version: Option<u32>,
    /// Change to the monitor page GPAs.
    pub monitor_page: Update<MonitorPageGpas>,
    /// Whether an interrupt page is in use: `None` if unchanged, `Some(true)`
    /// if one was set, `Some(false)` if it was reset.
    pub use_interrupt_page: Option<bool>,
}
189
190impl From<ModifyConnectionRequest> for ModifyRelayRequest {
191 fn from(value: ModifyConnectionRequest) -> Self {
192 Self {
193 version: value.version,
194 monitor_page: value.monitor_page,
195 use_interrupt_page: match value.interrupt_page {
196 Update::Unchanged => None,
197 Update::Reset => Some(false),
198 Update::Set(_) => Some(true),
199 },
200 }
201 }
202}
203
/// Control requests sent from `VmbusServer` to its server task.
#[derive(Debug)]
enum VmbusRequest {
    /// Reset all channel state.
    Reset(Rpc<(), ()>),
    /// Answer a deferred inspect request.
    Inspect(inspect::Deferred),
    /// Capture saved state.
    Save(Rpc<(), SavedState>),
    /// Load saved state.
    Restore(Rpc<SavedState, Result<(), RestoreError>>),
    /// Finish a restore once channels have been restored.
    PostRestore(Rpc<(), Result<(), RestoreError>>),
    /// Mark the server as running.
    Start,
    /// Mark the server as stopped.
    Stop(Rpc<(), ()>),
}
214
/// Everything needed to offer a channel to the server task.
#[derive(mesh::MeshPayload, Debug)]
pub struct OfferInfo {
    /// Offer parameters passed to the channel state machine.
    pub params: OfferParamsInternal,
    /// Channel for sending open/close/GPADL/modify requests to the device.
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Channel on which the device sends restore and revoke requests.
    pub server_request_recv: mesh::Receiver<ChannelServerRequest>,
}
221
/// Requests sent to the server task through `VmbusServerControl`.
#[derive(mesh::MeshPayload)]
pub(crate) enum OfferRequest {
    /// Offer a new channel.
    Offer(FailableRpc<OfferInfo, ()>),
    /// Reset channel state, handled the same way as `VmbusRequest::Reset`.
    ForceReset(Rpc<(), ()>),
}
227
228impl Inspect for VmbusServer {
229 fn inspect(&self, req: inspect::Request<'_>) {
230 self.task_send.send(VmbusRequest::Inspect(req.defer()));
231 }
232}
233
234struct ChannelEvent(Interrupt);
235
236impl EventPort for ChannelEvent {
237 fn handle_event(&self, _flag: u16) {
238 self.0.deliver();
239 }
240
241 fn os_event(&self) -> Option<&Event> {
242 self.0.event()
243 }
244}
245
/// Saved state of the vmbus server.
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.server")]
pub struct SavedState {
    /// Channel state machine state.
    #[mesh(1)]
    server: channels::SavedState,
    /// Always `true` when saved by this code. If a restored state has it
    /// `false`, channels are "unstuck" on the next start.
    /// NOTE(review): presumably states from older versions lack this field
    /// and decode as `false` — confirm against mesh protobuf defaults.
    #[mesh(2)]
    lost_synic_bug_fixed: bool,
}
257
/// Message connection ID used by a VTL0 server without message redirection.
const MESSAGE_CONNECTION_ID: u32 = 1;
/// Connection ID of the multiclient message port; also the base value of
/// child message connection IDs.
const MULTICLIENT_MESSAGE_CONNECTION_ID: u32 = 4;
260
261impl<'a, T: Spawn> VmbusServerBuilder<'a, T> {
262 pub fn new(spawner: &'a T, synic: Arc<dyn SynicPortAccess>, gm: GuestMemory) -> Self {
264 Self {
265 spawner,
266 synic,
267 gm,
268 private_gm: None,
269 vtl: Vtl::Vtl0,
270 hvsock_notify: None,
271 server_relay: None,
272 external_server: None,
273 external_requests: None,
274 use_message_redirect: false,
275 channel_id_offset: 0,
276 max_version: None,
277 delay_max_version: false,
278 enable_mnf: false,
279 force_confidential_external_memory: false,
280 send_messages_while_stopped: false,
281 }
282 }
283
284 pub fn private_gm(mut self, private_gm: Option<GuestMemory>) -> Self {
288 self.private_gm = private_gm;
289 self
290 }
291
292 pub fn vtl(mut self, vtl: Vtl) -> Self {
294 self.vtl = vtl;
295 self
296 }
297
298 pub fn hvsock_notify(mut self, hvsock_notify: Option<HvsockServerChannelHalf>) -> Self {
300 self.hvsock_notify = hvsock_notify;
301 self
302 }
303
304 pub fn server_relay(mut self, server_relay: Option<VmbusServerChannelHalf>) -> Self {
307 self.server_relay = server_relay;
308 self
309 }
310
311 pub fn external_requests(
313 mut self,
314 external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
315 ) -> Self {
316 self.external_requests = external_requests;
317 self
318 }
319
320 pub fn external_server(
323 mut self,
324 external_server: Option<mesh::Sender<InitiateContactRequest>>,
325 ) -> Self {
326 self.external_server = external_server;
327 self
328 }
329
330 pub fn use_message_redirect(mut self, use_message_redirect: bool) -> Self {
332 self.use_message_redirect = use_message_redirect;
333 self
334 }
335
336 pub fn enable_channel_id_offset(mut self, enable: bool) -> Self {
341 self.channel_id_offset = if enable { 1024 } else { 0 };
342 self
343 }
344
345 pub fn max_version(mut self, max_version: Option<MaxVersionInfo>) -> Self {
349 self.max_version = max_version;
350 self
351 }
352
353 pub fn delay_max_version(mut self, delay: bool) -> Self {
358 self.delay_max_version = delay;
359 self
360 }
361
362 pub fn enable_mnf(mut self, enable: bool) -> Self {
366 self.enable_mnf = enable;
367 self
368 }
369
370 pub fn force_confidential_external_memory(mut self, force: bool) -> Self {
373 self.force_confidential_external_memory = force;
374 self
375 }
376
377 pub fn send_messages_while_stopped(mut self, send: bool) -> Self {
384 self.send_messages_while_stopped = send;
385 self
386 }
387
388 pub fn build(self) -> anyhow::Result<VmbusServer> {
393 #[expect(clippy::disallowed_methods)] let (message_send, message_recv) = mpsc::channel(64);
395 let message_sender = Arc::new(MessageSender {
396 send: message_send.clone(),
397 multiclient: self.use_message_redirect,
398 });
399
400 let (redirect_vtl, redirect_sint) = if self.use_message_redirect {
401 (REDIRECT_VTL, REDIRECT_SINT)
402 } else {
403 (self.vtl, SINT)
404 };
405
406 let connection_id = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
409 MESSAGE_CONNECTION_ID
410 } else {
411 VmbusServer::get_child_message_connection_id(0, redirect_sint, redirect_vtl)
414 };
415
416 let _message_port = self
417 .synic
418 .add_message_port(connection_id, redirect_vtl, message_sender)
419 .context("failed to create vmbus synic ports")?;
420
421 let _multiclient_message_port = if self.vtl == Vtl::Vtl0 && !self.use_message_redirect {
425 let multiclient_message_sender = Arc::new(MessageSender {
426 send: message_send,
427 multiclient: true,
428 });
429
430 Some(
431 self.synic
432 .add_message_port(
433 MULTICLIENT_MESSAGE_CONNECTION_ID,
434 self.vtl,
435 multiclient_message_sender,
436 )
437 .context("failed to create vmbus synic ports")?,
438 )
439 } else {
440 None
441 };
442
443 let (offer_send, offer_recv) = mesh::mpsc_channel();
444 let control = Arc::new(VmbusServerControl {
445 mem: self.gm.clone(),
446 private_mem: self.private_gm.clone(),
447 send: offer_send,
448 use_event: self.synic.prefer_os_events(),
449 force_confidential_external_memory: self.force_confidential_external_memory,
450 });
451
452 let mut server = channels::Server::new(self.vtl, connection_id, self.channel_id_offset);
453
454 if let Some(version) = self.max_version {
456 server.set_compatibility_version(version, self.delay_max_version);
457 }
458 let (relay_request_send, relay_response_recv) =
459 if let Some(server_relay) = self.server_relay {
460 let r = server_relay.response_receive.boxed().fuse();
461 (server_relay.request_send, r)
462 } else {
463 let (req_send, req_recv) = mesh::channel();
464 let resp_recv = req_recv
465 .map(|_| {
466 ModifyConnectionResponse::Supported(
467 protocol::ConnectionState::SUCCESSFUL,
468 protocol::FeatureFlags::from_bits(u32::MAX),
469 )
470 })
471 .boxed()
472 .fuse();
473 (req_send, resp_recv)
474 };
475
476 let (hvsock_send, hvsock_recv) = if let Some(hvsock_notify) = self.hvsock_notify {
478 let r = hvsock_notify.response_receive.boxed().fuse();
479 (hvsock_notify.request_send, r)
480 } else {
481 let (req_send, req_recv) = mesh::channel();
482 let resp_recv = req_recv
483 .map(|r: HvsockConnectRequest| HvsockConnectResult::from_request(&r, false))
484 .boxed()
485 .fuse();
486 (req_send, resp_recv)
487 };
488
489 let inner = ServerTaskInner {
490 running: false,
491 send_messages_while_stopped: self.send_messages_while_stopped,
492 gm: self.gm,
493 private_gm: self.private_gm,
494 vtl: self.vtl,
495 redirect_vtl,
496 redirect_sint,
497 message_port: self
498 .synic
499 .new_guest_message_port(redirect_vtl, 0, redirect_sint)?,
500 synic: self.synic,
501 hvsock_requests: 0,
502 hvsock_send,
503 channels: HashMap::new(),
504 channel_responses: FuturesUnordered::new(),
505 relay_send: relay_request_send,
506 external_server_send: self.external_server,
507 channel_bitmap: None,
508 shared_event_port: None,
509 reset_done: Vec::new(),
510 enable_mnf: self.enable_mnf,
511 };
512
513 let (task_send, task_recv) = mesh::channel();
514 let mut server_task = ServerTask {
515 server,
516 task_recv,
517 offer_recv,
518 message_recv,
519 server_request_recv: SelectAll::new(),
520 inner,
521 external_requests: self.external_requests,
522 next_seq: 0,
523 unstick_on_start: false,
524 };
525
526 let task = self.spawner.spawn("vmbus server", async move {
527 server_task.run(relay_response_recv, hvsock_recv).await;
528 server_task
529 });
530
531 Ok(VmbusServer {
532 task_send,
533 control,
534 _message_port,
535 _multiclient_message_port,
536 task,
537 })
538 }
539}
540
541impl VmbusServer {
542 pub fn builder<T: Spawn>(
544 spawner: &T,
545 synic: Arc<dyn SynicPortAccess>,
546 gm: GuestMemory,
547 ) -> VmbusServerBuilder<'_, T> {
548 VmbusServerBuilder::new(spawner, synic, gm)
549 }
550
551 pub async fn save(&self) -> SavedState {
552 self.task_send.call(VmbusRequest::Save, ()).await.unwrap()
553 }
554
555 pub async fn restore(&self, state: SavedState) -> Result<(), RestoreError> {
556 self.task_send
557 .call(VmbusRequest::Restore, state)
558 .await
559 .unwrap()
560 }
561
562 pub async fn post_restore(&self) -> Result<(), RestoreError> {
563 self.task_send
564 .call(VmbusRequest::PostRestore, ())
565 .await
566 .unwrap()
567 }
568
569 pub async fn stop(&self) {
571 self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
572 }
573
574 pub fn start(&self) {
576 self.task_send.send(VmbusRequest::Start);
577 }
578
579 pub async fn reset(&self) {
581 tracing::debug!("resetting channel state");
582 self.task_send.call(VmbusRequest::Reset, ()).await.unwrap()
583 }
584
585 pub async fn shutdown(self) {
587 drop(self.task_send);
588 let _ = self.task.await;
589 }
590
591 pub fn control(&self) -> Arc<VmbusServerControl> {
593 self.control.clone()
594 }
595
596 fn get_child_message_connection_id(vp_index: u32, sint_index: u8, vtl: Vtl) -> u32 {
599 MULTICLIENT_MESSAGE_CONNECTION_ID
600 | (vtl as u32) << 22
601 | vp_index << 8
602 | (sint_index as u32) << 4
603 }
604
605 fn get_child_event_port_id(channel_id: protocol::ChannelId, sint_index: u8, vtl: Vtl) -> u32 {
606 EVENT_PORT_ID | (vtl as u32) << 22 | channel_id.0 << 8 | (sint_index as u32) << 4
607 }
608}
609
/// Per-channel restore information.
/// NOTE(review): not constructed or read in this part of the file; the
/// field meanings below are inferred from their types — confirm.
#[derive(mesh::MeshPayload)]
pub struct RestoreInfo {
    /// Open parameters, if the channel was open.
    open_data: Option<OpenData>,
    /// GPADLs as (id, range count, raw range buffer) tuples.
    gpadls: Vec<(GpadlId, u16, Vec<u64>)>,
    /// Interrupt for the channel, if any.
    interrupt: Option<Interrupt>,
}
616
/// A message received on one of the server's synic message ports.
#[derive(Default)]
pub struct SynicMessage {
    /// Raw message payload.
    data: Vec<u8>,
    /// Whether the message arrived via a multiclient-capable port (copied
    /// from the receiving `MessageSender`'s `multiclient` flag — NOTE(review):
    /// confirm in `MessageSender`).
    multiclient: bool,
    /// Whether the sender is trusted.
    /// NOTE(review): set outside this view; confirm semantics.
    trusted: bool,
}
623
/// State owned by the spawned vmbus server task.
struct ServerTask {
    /// Channel protocol state machine.
    server: channels::Server,
    /// Control requests from `VmbusServer`; the run loop exits when this closes.
    task_recv: mesh::Receiver<VmbusRequest>,
    /// Offer and force-reset requests from `VmbusServerControl`.
    offer_recv: mesh::Receiver<OfferRequest>,
    /// Messages received on the synic message port(s).
    message_recv: mpsc::Receiver<SynicMessage>,
    /// Restore/revoke requests from every offered channel, tagged by offer ID.
    server_request_recv: SelectAll<TaggedStream<OfferId, mesh::Receiver<ChannelServerRequest>>>,
    /// State the channel state machine reaches through the `Notifier` trait.
    inner: ServerTaskInner,
    /// Optional source of externally initiated contact requests.
    external_requests: Option<mesh::Receiver<InitiateContactRequest>>,
    /// Sequence number assigned to the next offered channel.
    next_seq: u64,
    /// Set when restoring a saved state without the lost-synic fix; channels
    /// are unstuck on the next start.
    unstick_on_start: bool,
}
636
/// Server task state that the channel state machine can reach (via
/// `Notifier`) while `ServerTask::server` is borrowed.
struct ServerTaskInner {
    /// Set by `Start` and cleared by `Stop`.
    running: bool,
    /// Configured via `VmbusServerBuilder::send_messages_while_stopped`.
    send_messages_while_stopped: bool,
    gm: GuestMemory,
    private_gm: Option<GuestMemory>,
    synic: Arc<dyn SynicPortAccess>,
    /// VTL the server runs at.
    vtl: Vtl,
    /// VTL used for the message port (differs from `vtl` with redirection).
    redirect_vtl: Vtl,
    /// SINT used for the message port.
    redirect_sint: u8,
    /// Port used to post vmbus messages to the guest.
    message_port: Box<dyn GuestMessagePort>,
    /// Number of hvsocket connect requests awaiting a result.
    hvsock_requests: usize,
    /// Where hvsocket connect requests are sent.
    hvsock_send: mesh::Sender<HvsockConnectRequest>,
    /// Offered channels that have not been revoked.
    channels: HashMap<OfferId, Channel>,
    /// In-flight device responses, tagged with offer ID and channel sequence.
    channel_responses: FuturesUnordered<
        Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>,
    >,
    external_server_send: Option<mesh::Sender<InitiateContactRequest>>,
    /// Where connection modification requests are relayed.
    relay_send: mesh::Sender<ModifyRelayRequest>,
    /// `None` at construction; managed outside this view.
    channel_bitmap: Option<Arc<ChannelBitmap>>,
    /// `None` at construction; managed outside this view.
    shared_event_port: Option<Box<dyn Send>>,
    /// Pending reset RPCs; non-empty while a reset is in progress.
    reset_done: Vec<Rpc<(), ()>>,
    /// Whether this server handles MNF itself.
    enable_mnf: bool,
}
660
/// A device's response to a `ChannelRequest`.
#[derive(Debug)]
enum ChannelResponse {
    /// Open result; `None` means the open failed.
    Open(Option<OpenResult>),
    /// Close completed.
    Close,
    /// GPADL creation completed; the flag reports success.
    Gpadl(GpadlId, bool),
    /// GPADL teardown completed.
    TeardownGpadl(GpadlId),
    /// Modify completed with the given status.
    Modify(i32),
}
669
/// Server task bookkeeping for one offered channel.
struct Channel {
    key: OfferKey,
    /// Sends requests to the device.
    send: mesh::Sender<ChannelRequest>,
    /// Sequence number assigned at offer time; responses carrying a different
    /// value are ignored as stale.
    seq: u64,
    state: ChannelState,
    /// GPADLs registered for this channel.
    gpadls: Arc<GpadlMap>,
    flags: protocol::OfferFlags,
    reserved_state: ReservedState,
}
683
/// Message port and target for the channel's reserved state. Initialized to
/// no port and VP 0 / SINT 0 when the channel is offered.
/// NOTE(review): presumably used for reserved channels; usage is outside
/// this view — confirm.
struct ReservedState {
    message_port: Option<Box<dyn GuestMessagePort>>,
    target: ConnectionTarget,
}
688
/// Open/close state of a channel as tracked by the server task.
enum ChannelState {
    /// Not open (initial state, and after a close completes).
    Closed,
    /// An open request has been sent to the device.
    Opening {
        open_params: OpenParams,
        guest_event_port: Box<dyn GuestEventPort>,
        host_to_guest_interrupt: Interrupt,
    },
    /// Open. `_event_port` is held only to keep the guest-to-host event port
    /// alive.
    Open {
        open_params: OpenParams,
        _event_port: Box<dyn Send>,
        guest_event_port: Box<dyn GuestEventPort>,
        host_to_guest_interrupt: Interrupt,
        guest_to_host_event: Arc<ChannelEvent>,
    },
    /// A close request has been sent; a close response moves to `Closed`.
    Closing,
    /// Open failed and the device is being closed. A close response moves to
    /// `Closed` and reports the failed open to the guest.
    FailedOpen,
}
706
707impl ServerTask {
708 fn handle_offer(&mut self, mut info: OfferInfo) -> anyhow::Result<()> {
709 let key = info.params.key();
710 let flags = info.params.flags;
711
712 if self.inner.enable_mnf && self.inner.synic.monitor_support().is_some() {
713 if info.params.use_mnf.is_relayed() {
718 info.params.use_mnf = MnfUsage::Enabled {
719 latency: Duration::ZERO,
720 }
721 }
722 } else if info.params.use_mnf.is_enabled() {
723 info.params.use_mnf = MnfUsage::Disabled;
726 }
727
728 let offer_id = self
729 .server
730 .with_notifier(&mut self.inner)
731 .offer_channel(info.params)
732 .context("channel offer failed")?;
733
734 tracing::debug!(?offer_id, %key, "offered channel");
735
736 let id = self.next_seq;
737 self.next_seq += 1;
738 self.inner.channels.insert(
739 offer_id,
740 Channel {
741 key,
742 send: info.request_send,
743 state: ChannelState::Closed,
744 gpadls: GpadlMap::new(),
745 seq: id,
746 flags,
747 reserved_state: ReservedState {
748 message_port: None,
749 target: ConnectionTarget { vp: 0, sint: 0 },
750 },
751 },
752 );
753
754 self.server_request_recv
755 .push(TaggedStream::new(offer_id, info.server_request_recv));
756
757 Ok(())
758 }
759
760 fn handle_revoke(&mut self, offer_id: OfferId) {
761 if self.inner.channels.remove(&offer_id).is_some() {
764 tracing::info!(?offer_id, "revoking channel");
765 self.server
766 .with_notifier(&mut self.inner)
767 .revoke_channel(offer_id);
768 }
769 }
770
771 fn handle_response(
772 &mut self,
773 offer_id: OfferId,
774 seq: u64,
775 response: Result<ChannelResponse, RpcError>,
776 ) {
777 let channel = self
779 .inner
780 .channels
781 .get(&offer_id)
782 .filter(|channel| channel.seq == seq);
783
784 if let Some(channel) = channel {
785 match response {
786 Ok(response) => match response {
787 ChannelResponse::Open(result) => self.handle_open(offer_id, result),
788 ChannelResponse::Close => self.handle_close(offer_id),
789 ChannelResponse::Gpadl(gpadl_id, ok) => {
790 self.handle_gpadl_create(offer_id, gpadl_id, ok)
791 }
792 ChannelResponse::TeardownGpadl(gpadl_id) => {
793 self.handle_gpadl_teardown(offer_id, gpadl_id)
794 }
795 ChannelResponse::Modify(status) => self.handle_modify_channel(offer_id, status),
796 },
797 Err(err) => {
798 tracing::error!(
799 key = %channel.key,
800 error = &err as &dyn std::error::Error,
801 "channel response failure, channel is in inconsistent state until revoked"
802 );
803 }
804 }
805 } else {
806 tracing::debug!(offer_id = ?offer_id, seq, ?response, "received response after revoke");
807 }
808 }
809
810 fn handle_open(&mut self, offer_id: OfferId, result: Option<OpenResult>) {
811 let status = if result.is_some() {
812 0
813 } else {
814 protocol::STATUS_UNSUCCESSFUL
815 };
816 if let Err(err) = self.inner.complete_open(offer_id, result) {
817 tracelimit::error_ratelimited!(
818 error = err.as_ref() as &dyn std::error::Error,
819 "failed to complete open"
820 );
821 self.inner.notify(offer_id, channels::Action::Close);
825 } else {
826 self.server
827 .with_notifier(&mut self.inner)
828 .open_complete(offer_id, status);
829 }
830 }
831
832 fn handle_close(&mut self, offer_id: OfferId) {
833 let channel = self
834 .inner
835 .channels
836 .get_mut(&offer_id)
837 .expect("channel still exists");
838
839 match &mut channel.state {
840 ChannelState::Closing => {
841 channel.state = ChannelState::Closed;
842 self.server
843 .with_notifier(&mut self.inner)
844 .close_complete(offer_id);
845 }
846 ChannelState::FailedOpen => {
847 channel.state = ChannelState::Closed;
850 self.server
851 .with_notifier(&mut self.inner)
852 .open_complete(offer_id, protocol::STATUS_UNSUCCESSFUL);
853 }
854 _ => {
855 tracing::error!(?offer_id, "invalid close channel response");
856 }
857 };
858 }
859
860 fn handle_gpadl_create(&mut self, offer_id: OfferId, gpadl_id: GpadlId, ok: bool) {
861 let status = if ok { 0 } else { protocol::STATUS_UNSUCCESSFUL };
862 self.server
863 .with_notifier(&mut self.inner)
864 .gpadl_create_complete(offer_id, gpadl_id, status);
865 }
866
867 fn handle_gpadl_teardown(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
868 self.server
869 .with_notifier(&mut self.inner)
870 .gpadl_teardown_complete(offer_id, gpadl_id);
871 }
872
873 fn handle_modify_channel(&mut self, offer_id: OfferId, status: i32) {
874 self.server
875 .with_notifier(&mut self.inner)
876 .modify_channel_complete(offer_id, status);
877 }
878
879 fn handle_restore_channel(
880 &mut self,
881 offer_id: OfferId,
882 open: Option<OpenResult>,
883 ) -> anyhow::Result<RestoreResult> {
884 let gpadls = self.server.channel_gpadls(offer_id);
885
886 let open_request = open
889 .map(|result| -> anyhow::Result<_> {
890 let params = self.server.get_restore_open_params(offer_id)?;
891 let (_, interrupt) = self.inner.open_channel(offer_id, ¶ms)?;
892 let channel = self.inner.complete_open(offer_id, Some(result))?;
893 Ok(OpenRequest::new(
894 params.open_data,
895 interrupt,
896 self.server
897 .get_version()
898 .expect("must be connected")
899 .feature_flags,
900 channel.flags,
901 ))
902 })
903 .transpose()?;
904
905 self.server
906 .with_notifier(&mut self.inner)
907 .restore_channel(offer_id, open_request.is_some())?;
908
909 let channel = self.inner.channels.get_mut(&offer_id).unwrap();
910 for gpadl in &gpadls {
911 if let Ok(buf) =
912 MultiPagedRangeBuf::new(gpadl.request.count.into(), gpadl.request.buf.clone())
913 {
914 channel.gpadls.add(gpadl.request.id, buf);
915 }
916 }
917
918 let result = RestoreResult {
919 open_request,
920 gpadls,
921 };
922 Ok(result)
923 }
924
925 fn handle_request(&mut self, request: VmbusRequest) {
926 tracing::debug!(?request, "handle_request");
927 match request {
928 VmbusRequest::Reset(rpc) => self.handle_reset(rpc),
929 VmbusRequest::Inspect(deferred) => {
930 deferred.respond(|resp| {
931 resp.field("message_port", &self.inner.message_port)
932 .field("running", self.inner.running)
933 .field("hvsock_requests", self.inner.hvsock_requests)
934 .field_mut_with("unstick_channels", |v| {
935 let v: inspect::Value = if let Some(v) = v {
936 if v == "force" {
937 self.unstick_channels(true);
938 v.into()
939 } else {
940 let v =
941 v.parse().ok().context("expected false, true, or force")?;
942 if v {
943 self.unstick_channels(false);
944 }
945 v.into()
946 }
947 } else {
948 false.into()
949 };
950 anyhow::Ok(v)
951 })
952 .merge(&self.server.with_notifier(&mut self.inner));
953 });
954 }
955 VmbusRequest::Save(rpc) => rpc.handle_sync(|()| SavedState {
956 server: self.server.save(),
957 lost_synic_bug_fixed: true,
958 }),
959 VmbusRequest::Restore(rpc) => rpc.handle_sync(|state| {
960 self.unstick_on_start = !state.lost_synic_bug_fixed;
961 self.server.restore(state.server)
962 }),
963 VmbusRequest::PostRestore(rpc) => {
964 rpc.handle_sync(|()| self.server.with_notifier(&mut self.inner).post_restore())
965 }
966 VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
967 if self.inner.running {
968 self.inner.running = false;
969 }
970 }),
971 VmbusRequest::Start => {
972 if !self.inner.running {
973 self.inner.running = true;
974 if self.unstick_on_start {
975 tracing::info!(
976 "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
977 );
978 self.unstick_channels(false);
979 self.unstick_on_start = false;
980 }
981 }
982 }
983 }
984 }
985
986 fn handle_reset(&mut self, rpc: Rpc<(), ()>) {
987 let needs_reset = self.inner.reset_done.is_empty();
988 self.inner.reset_done.push(rpc);
989 if needs_reset {
990 self.server.with_notifier(&mut self.inner).reset();
991 }
992 }
993
994 fn handle_relay_response(&mut self, response: ModifyConnectionResponse) {
995 self.server
996 .with_notifier(&mut self.inner)
997 .complete_modify_connection(response);
998 }
999
1000 fn handle_tl_connect_result(&mut self, result: HvsockConnectResult) {
1001 assert_ne!(self.inner.hvsock_requests, 0);
1002 self.inner.hvsock_requests -= 1;
1003
1004 self.server
1005 .with_notifier(&mut self.inner)
1006 .send_tl_connect_result(result);
1007 }
1008
1009 fn handle_synic_message(&mut self, message: SynicMessage) {
1010 match self
1011 .server
1012 .with_notifier(&mut self.inner)
1013 .handle_synic_message(message)
1014 {
1015 Ok(()) => {}
1016 Err(err) => {
1017 tracing::warn!(
1018 error = &err as &dyn std::error::Error,
1019 "synic message error"
1020 );
1021 }
1022 }
1023 }
1024
1025 fn handle_external_request(&mut self, request: InitiateContactRequest) {
1032 self.server
1033 .with_notifier(&mut self.inner)
1034 .initiate_contact(request);
1035 }
1036
1037 async fn run(
1038 &mut self,
1039 mut relay_response_recv: impl futures::stream::FusedStream<Item = ModifyConnectionResponse>
1040 + Unpin,
1041 mut hvsock_recv: impl futures::stream::FusedStream<Item = HvsockConnectResult> + Unpin,
1042 ) {
1043 loop {
1044 let running_not_resetting = self.inner.running && self.inner.reset_done.is_empty();
1049 let mut external_requests = OptionFuture::from(
1050 running_not_resetting
1051 .then(|| {
1052 self.external_requests
1053 .as_mut()
1054 .map(|r| r.select_next_some())
1055 })
1056 .flatten(),
1057 );
1058
1059 let has_pending_messages = self.server.has_pending_messages();
1061 let message_port = self.inner.message_port.as_mut();
1062 let mut flush_pending_messages =
1063 OptionFuture::from((running_not_resetting && has_pending_messages).then(|| {
1064 poll_fn(|cx| {
1065 self.server.poll_flush_pending_messages(|msg| {
1066 message_port.poll_post_message(cx, VMBUS_MESSAGE_TYPE, msg.data())
1067 })
1068 })
1069 .fuse()
1070 }));
1071
1072 let mut message_recv = OptionFuture::from(
1076 (running_not_resetting
1077 && !has_pending_messages
1078 && self.inner.hvsock_requests < MAX_CONCURRENT_HVSOCK_REQUESTS)
1079 .then(|| self.message_recv.select_next_some()),
1080 );
1081
1082 let mut channel_response = OptionFuture::from(
1084 (self.inner.running || !self.inner.reset_done.is_empty())
1085 .then(|| self.inner.channel_responses.select_next_some()),
1086 );
1087
1088 let mut hvsock_response =
1090 OptionFuture::from(running_not_resetting.then(|| hvsock_recv.select_next_some()));
1091
1092 futures::select! { r = self.task_recv.recv().fuse() => {
1094 if let Ok(request) = r {
1095 self.handle_request(request);
1096 } else {
1097 break;
1098 }
1099 }
1100 r = self.offer_recv.select_next_some() => {
1101 match r {
1102 OfferRequest::Offer(rpc) => {
1103 rpc.handle_failable_sync(|request| { self.handle_offer(request) })
1104 },
1105 OfferRequest::ForceReset(rpc) => {
1106 self.handle_reset(rpc);
1107 }
1108 }
1109 }
1110 r = self.server_request_recv.select_next_some() => {
1111 match r {
1112 (id, Some(request)) => match request {
1113 ChannelServerRequest::Restore(rpc) => rpc.handle_failable_sync(|open| {
1114 self.handle_restore_channel(id, open)
1115 }),
1116 ChannelServerRequest::Revoke(rpc) => rpc.handle_sync(|_| {
1117 self.handle_revoke(id);
1118 })
1119 },
1120 (id, None) => self.handle_revoke(id),
1121 }
1122 }
1123 r = channel_response => {
1124 let (id, seq, response) = r.unwrap();
1125 self.handle_response(id, seq, response);
1126 }
1127 r = relay_response_recv.select_next_some() => {
1128 self.handle_relay_response(r);
1129 },
1130 r = hvsock_response => {
1131 self.handle_tl_connect_result(r.unwrap());
1132 }
1133 data = message_recv => {
1134 let data = data.unwrap();
1135 self.handle_synic_message(data);
1136 }
1137 r = external_requests => {
1138 let r = r.unwrap();
1139 self.handle_external_request(r);
1140 }
1141 _r = flush_pending_messages => {}
1142 complete => break,
1143 }
1144 }
1145 }
1146
1147 fn unstick_channels(&self, force: bool) {
1151 for channel in self.inner.channels.values() {
1152 if let Err(err) = self.unstick_channel(channel, force) {
1153 tracing::warn!(
1154 channel = %channel.key,
1155 error = err.as_ref() as &dyn std::error::Error,
1156 "could not unstick channel"
1157 );
1158 }
1159 }
1160 }
1161
1162 fn unstick_channel(&self, channel: &Channel, force: bool) -> anyhow::Result<()> {
1163 if let ChannelState::Open {
1164 open_params,
1165 host_to_guest_interrupt,
1166 guest_to_host_event,
1167 ..
1168 } = &channel.state
1169 {
1170 if force {
1171 tracing::info!(channel = %channel.key, "waking host and guest");
1172 guest_to_host_event.0.deliver();
1173 host_to_guest_interrupt.deliver();
1174 return Ok(());
1175 }
1176
1177 let gpadl = channel
1178 .gpadls
1179 .clone()
1180 .view()
1181 .map(open_params.open_data.ring_gpadl_id)
1182 .context("couldn't find ring gpadl")?;
1183
1184 let aligned = AlignedGpadlView::new(gpadl)
1185 .ok()
1186 .context("ring not aligned")?;
1187 let (in_gpadl, out_gpadl) = aligned
1188 .split(open_params.open_data.ring_offset)
1189 .ok()
1190 .context("couldn't split ring")?;
1191
1192 if let Err(err) = self.unstick_incoming_ring(
1193 channel,
1194 in_gpadl,
1195 guest_to_host_event,
1196 host_to_guest_interrupt,
1197 ) {
1198 tracing::warn!(
1199 channel = %channel.key,
1200 error = err.as_ref() as &dyn std::error::Error,
1201 "could not unstick incoming ring"
1202 );
1203 }
1204 if let Err(err) = self.unstick_outgoing_ring(
1205 channel,
1206 out_gpadl,
1207 guest_to_host_event,
1208 host_to_guest_interrupt,
1209 ) {
1210 tracing::warn!(
1211 channel = %channel.key,
1212 error = err.as_ref() as &dyn std::error::Error,
1213 "could not unstick outgoing ring"
1214 );
1215 }
1216 }
1217 Ok(())
1218 }
1219
1220 fn unstick_incoming_ring(
1221 &self,
1222 channel: &Channel,
1223 in_gpadl: AlignedGpadlView,
1224 guest_to_host_event: &ChannelEvent,
1225 host_to_guest_interrupt: &Interrupt,
1226 ) -> Result<(), anyhow::Error> {
1227 let incoming_mem = GpadlRingMem::new(in_gpadl, &self.inner.gm)?;
1228 if ring::reader_needs_signal(&incoming_mem) {
1229 tracing::info!(channel = %channel.key, "waking host for incoming ring");
1230 guest_to_host_event.0.deliver();
1231 }
1232 if ring::writer_needs_signal(&incoming_mem) {
1233 tracing::info!(channel = %channel.key, "waking guest for incoming ring");
1234 host_to_guest_interrupt.deliver();
1235 }
1236 Ok(())
1237 }
1238
1239 fn unstick_outgoing_ring(
1240 &self,
1241 channel: &Channel,
1242 out_gpadl: AlignedGpadlView,
1243 guest_to_host_event: &ChannelEvent,
1244 host_to_guest_interrupt: &Interrupt,
1245 ) -> Result<(), anyhow::Error> {
1246 let outgoing_mem = GpadlRingMem::new(out_gpadl, &self.inner.gm)?;
1247 if ring::reader_needs_signal(&outgoing_mem) {
1248 tracing::info!(channel = %channel.key, "waking guest for outgoing ring");
1249 host_to_guest_interrupt.deliver();
1250 }
1251 if ring::writer_needs_signal(&outgoing_mem) {
1252 tracing::info!(channel = %channel.key, "waking host for outgoing ring");
1253 guest_to_host_event.0.deliver();
1254 }
1255 Ok(())
1256 }
1257}
1258
impl Notifier for ServerTaskInner {
    /// Carries out a channel state-machine `action` for `offer_id`.
    ///
    /// Most actions are forwarded to the device as a [`ChannelRequest`]. Each
    /// action yields a future resolving to `(offer_id, seq, response)`, and that
    /// future is pushed onto `self.channel_responses`.
    fn notify(&mut self, offer_id: OfferId, action: channels::Action) {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        // Sends `req(input)` to the device over `channel.send` and returns a future
        // that resolves to the device's reply mapped through `f`. The reply is tagged
        // with `offer_id` and with the channel's sequence number as of the time the
        // request was sent.
        fn handle<I: 'static + Send, R: 'static + Send>(
            offer_id: OfferId,
            channel: &Channel,
            req: impl FnOnce(Rpc<I, R>) -> ChannelRequest,
            input: I,
            f: impl 'static + Send + FnOnce(R) -> ChannelResponse,
        ) -> Pin<Box<dyn Send + Future<Output = (OfferId, u64, Result<ChannelResponse, RpcError>)>>>
        {
            let recv = channel.send.call(req, input);
            let seq = channel.seq;
            Box::pin(async move {
                let r = recv.await.map(f);
                (offer_id, seq, r)
            })
        }

        let response = match action {
            channels::Action::Open(open_params, version) => {
                // Read the sequence number now. `open_channel` takes `&mut self`, so
                // `channel` cannot be used once it is called.
                let seq = channel.seq;
                match self.open_channel(offer_id, &open_params) {
                    Ok((channel, interrupt)) => handle(
                        offer_id,
                        channel,
                        ChannelRequest::Open,
                        OpenRequest::new(
                            open_params.open_data,
                            interrupt,
                            version.feature_flags,
                            channel.flags,
                        ),
                        ChannelResponse::Open,
                    ),
                    Err(err) => {
                        tracelimit::error_ratelimited!(
                            err = err.as_ref() as &dyn std::error::Error,
                            ?offer_id,
                            "could not open channel",
                        );

                        // Local setup failed: synthesize an `Open(None)` response
                        // without contacting the device.
                        Box::pin(future::ready((
                            offer_id,
                            seq,
                            Ok(ChannelResponse::Open(None)),
                        )))
                    }
                }
            }
            channels::Action::Close => {
                // When an interrupt-page bitmap is in use, stop routing this channel's
                // event flag through it.
                if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
                    if let ChannelState::Open { open_params, .. } = channel.state {
                        channel_bitmap.unregister_channel(open_params.event_flag);
                    }
                }

                channel.state = ChannelState::Closing;
                handle(offer_id, channel, ChannelRequest::Close, (), |()| {
                    ChannelResponse::Close
                })
            }
            channels::Action::Gpadl(gpadl_id, count, buf) => {
                // Keep a local record of the GPADL. It is read later for ring
                // inspection and for unsticking rings.
                // NOTE(review): `unwrap` panics if `buf` does not describe `count`
                // ranges; presumably the channels module has already validated the
                // guest's GPADL. Confirm.
                channel.gpadls.add(
                    gpadl_id,
                    MultiPagedRangeBuf::new(count.into(), buf.clone()).unwrap(),
                );
                handle(
                    offer_id,
                    channel,
                    ChannelRequest::Gpadl,
                    GpadlRequest {
                        id: gpadl_id,
                        count,
                        buf,
                    },
                    move |r| ChannelResponse::Gpadl(gpadl_id, r),
                )
            }
            channels::Action::TeardownGpadl {
                gpadl_id,
                post_restore,
            } => {
                // Drop the local record (with a no-op completion callback) only for
                // teardowns that are not being replayed after a restore.
                if !post_restore {
                    channel.gpadls.remove(gpadl_id, Box::new(|| ()));
                }

                handle(
                    offer_id,
                    channel,
                    ChannelRequest::TeardownGpadl,
                    gpadl_id,
                    move |()| ChannelResponse::TeardownGpadl(gpadl_id),
                )
            }
            channels::Action::Modify { target_vp } => {
                if let ChannelState::Open {
                    guest_event_port, ..
                } = &mut channel.state
                {
                    // Retarget the guest interrupt locally first. The device is only
                    // notified if that succeeds; otherwise the guest gets
                    // STATUS_UNSUCCESSFUL directly.
                    if let Err(err) = guest_event_port.set_target_vp(target_vp) {
                        tracelimit::error_ratelimited!(
                            error = &err as &dyn std::error::Error,
                            channel = %channel.key,
                            "could not modify channel",
                        );
                        let seq = channel.seq;
                        Box::pin(async move {
                            (
                                offer_id,
                                seq,
                                Ok(ChannelResponse::Modify(protocol::STATUS_UNSUCCESSFUL)),
                            )
                        })
                    } else {
                        handle(
                            offer_id,
                            channel,
                            ChannelRequest::Modify,
                            ModifyRequest::TargetVp { target_vp },
                            ChannelResponse::Modify,
                        )
                    }
                } else {
                    // A modify is only expected for an open channel.
                    unreachable!();
                }
            }
        };
        self.channel_responses.push(response);
    }

    /// Applies a connection-level change: interrupt page, monitor pages, and the
    /// message target VP. If `notify_relay` is set, the request is also forwarded
    /// to the relay.
    fn modify_connection(&mut self, mut request: ModifyConnectionRequest) -> anyhow::Result<()> {
        self.map_interrupt_page(request.interrupt_page)
            .context("Failed to map interrupt page.")?;

        self.set_monitor_page(request.monitor_page, request.force)
            .context("Failed to map monitor page.")?;

        if let Some(vp) = request.target_message_vp {
            self.message_port.set_target_vp(vp)?;
        }

        if request.notify_relay {
            // With `enable_mnf`, this server programs the monitor page itself (see
            // `set_monitor_page`), so the relay is not told about monitor page changes.
            if self.enable_mnf {
                request.monitor_page = Update::Unchanged;
            }

            self.relay_send.send(request.into());
        }

        Ok(())
    }

    /// Forwards `request` to the external server if one is configured. Otherwise
    /// logs a warning and drops it.
    fn forward_unhandled(&mut self, request: InitiateContactRequest) {
        if let Some(external_server) = &self.external_server_send {
            external_server.send(request);
        } else {
            tracing::warn!(?request, "nowhere to forward unhandled request")
        }
    }

    /// Reports the incoming/outgoing ring state of `offer_id` if it is open.
    ///
    /// Panics if the channel does not exist. It also panics if private memory is
    /// configured and the channel uses a confidential ring while `version` is
    /// `None`.
    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
        let channel = self.channels.get(&offer_id).expect("should exist");
        let mut resp = req.respond();
        if let ChannelState::Open { open_params, .. } = &channel.state {
            // Read the ring from private memory only when the channel has a
            // confidential ring buffer and the connection negotiated confidential
            // channels.
            let mem = if self.private_gm.is_some()
                && channel.flags.confidential_ring_buffer()
                && version
                    .expect("must be connected")
                    .feature_flags
                    .confidential_channels()
            {
                self.private_gm.as_ref().unwrap()
            } else {
                &self.gm
            };

            inspect_rings(
                &mut resp,
                mem,
                channel.gpadls.clone(),
                &open_params.open_data,
            );
        }
    }

    /// Tries to post `message` to the guest on the port selected by `target`,
    /// without waiting.
    ///
    /// Returns `false` in two cases:
    /// - the server is stopped and `send_messages_while_stopped` is not set;
    /// - the port returned `Pending`.
    ///
    /// Returns `true` if the message was posted, or if no port could be obtained
    /// for the target. In the latter case the message is dropped; presumably the
    /// caller does not retry on `true`.
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
        if !self.running && !self.send_messages_while_stopped {
            if !matches!(target, MessageTarget::Default) {
                tracelimit::error_ratelimited!(?target, "dropping message while paused");
            }
            return false;
        }

        // Owns the temporary port created for `MessageTarget::Custom` so that `port`
        // can borrow it.
        let mut port_storage;
        let port = match target {
            MessageTarget::Default => self.message_port.as_mut(),
            MessageTarget::ReservedChannel(offer_id, target) => {
                if let Some(port) = self.get_reserved_channel_message_port(offer_id, target) {
                    port.as_mut()
                } else {
                    return true;
                }
            }
            MessageTarget::Custom(target) => {
                port_storage = match self.synic.new_guest_message_port(
                    self.redirect_vtl,
                    target.vp,
                    target.sint,
                ) {
                    Ok(port) => port,
                    Err(err) => {
                        tracing::error!(
                            ?err,
                            ?self.redirect_vtl,
                            ?target,
                            "could not create message port"
                        );

                        return true;
                    }
                };
                port_storage.as_mut()
            }
        };

        // Poll exactly once with a no-op waker. This only reports whether the port
        // accepted the message right now; it never registers for a wakeup.
        matches!(
            port.poll_post_message(
                &mut std::task::Context::from_waker(std::task::Waker::noop()),
                VMBUS_MESSAGE_TYPE,
                message.data()
            ),
            Poll::Ready(())
        )
    }

    /// Counts a guest hvsocket connect request and forwards a copy on `hvsock_send`.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
        self.hvsock_requests += 1;
        self.hvsock_send.send(*request);
    }

    /// Finishes a reset, in three steps:
    /// 1. Clears the monitor page (if monitor support exists); a failure is only
    ///    logged.
    /// 2. Releases the reserved message ports of closed channels.
    /// 3. Completes every pending reset waiter.
    fn reset_complete(&mut self) {
        if let Some(monitor) = self.synic.monitor_support() {
            if let Err(err) = monitor.set_monitor_page(self.vtl, None) {
                tracing::warn!(?err, "resetting monitor page failed")
            }
        }

        self.unreserve_channels();
        for done in self.reset_done.drain(..) {
            done.complete(());
        }
    }

    /// Finishes an unload by releasing the reserved message ports of closed channels.
    fn unload_complete(&mut self) {
        self.unreserve_channels();
    }
}
1543
impl ServerTaskInner {
    /// Prepares `offer_id` for opening.
    ///
    /// Creates the guest event port used to interrupt the guest and, for reserved
    /// channels, a dedicated guest message port. On success the channel is moved
    /// to `ChannelState::Opening`, and it is returned along with the host-to-guest
    /// interrupt to hand to the device.
    ///
    /// # Errors
    /// Fails if either port cannot be created. The channel state is then left
    /// untouched, except that `reserved_state.message_port` may already have been
    /// cleared.
    fn open_channel(
        &mut self,
        offer_id: OfferId,
        open_params: &OpenParams,
    ) -> anyhow::Result<(&mut Channel, Interrupt)> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        // With an interrupt-page bitmap in use, the guest event port is created
        // with VP 0 and flag 0. The channel's own `event_flag` is passed to
        // `ChannelBitmap::create_interrupt` below instead (presumably it sets the
        // bit in the shared page).
        let (target_vp, event_flag) = if self.channel_bitmap.is_some() {
            (0, 0)
        } else {
            (open_params.open_data.target_vp, open_params.event_flag)
        };
        // Channels with `redirect_interrupt` signal the redirect VTL/SINT instead
        // of this server's VTL and `SINT`.
        let (target_vtl, target_sint) = if open_params.flags.redirect_interrupt() {
            (self.redirect_vtl, self.redirect_sint)
        } else {
            (self.vtl, SINT)
        };

        let guest_event_port = self.synic.new_guest_event_port(
            VmbusServer::get_child_event_port_id(open_params.channel_id, SINT, self.vtl),
            target_vtl,
            target_vp,
            target_sint,
            event_flag,
            open_params.monitor_info,
        )?;

        let interrupt = ChannelBitmap::create_interrupt(
            &self.channel_bitmap,
            guest_event_port.interrupt(),
            open_params.event_flag,
        );

        // Discard any reserved message port left from a previous open.
        channel.reserved_state.message_port = None;

        if let Some(target) = open_params.reserved_target {
            // Reserved channels get their own message port on the redirect VTL.
            channel.reserved_state.message_port = Some(self.synic.new_guest_message_port(
                self.redirect_vtl,
                target.vp,
                target.sint,
            )?);

            channel.reserved_state.target = target;
        }

        channel.state = ChannelState::Opening {
            open_params: *open_params,
            guest_event_port,
            host_to_guest_interrupt: interrupt.clone(),
        };
        Ok((channel, interrupt))
    }

    /// Completes an open started by `open_channel`, using the device's `result`.
    ///
    /// - `Some(result)` moves an `Opening` channel to `Open`. It registers the
    ///   device's guest-to-host event with the interrupt-page bitmap (if any) and
    ///   adds a synic event port on the channel's connection ID. A channel in any
    ///   other state is logged and left as it was.
    /// - `None` moves the channel to `Closed`.
    ///
    /// # Errors
    /// Fails if the synic event port cannot be created. The channel is then left
    /// in `ChannelState::FailedOpen`, and any bitmap registration already made
    /// remains.
    fn complete_open(
        &mut self,
        offer_id: OfferId,
        result: Option<OpenResult>,
    ) -> anyhow::Result<&mut Channel> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        channel.state = if let Some(result) = result {
            // Move the current state out, leaving `FailedOpen` in its place. It stays
            // there if the `?` below returns early.
            match std::mem::replace(&mut channel.state, ChannelState::FailedOpen) {
                ChannelState::Opening {
                    open_params,
                    guest_event_port,
                    host_to_guest_interrupt,
                } => {
                    let guest_to_host_event =
                        Arc::new(ChannelEvent(result.guest_to_host_interrupt));
                    if let Some(channel_bitmap) = self.channel_bitmap.as_ref() {
                        // Associate this channel's event flag in the bitmap with the
                        // device's guest-to-host interrupt.
                        channel_bitmap.register_channel(
                            open_params.event_flag,
                            guest_to_host_event.0.clone(),
                        );
                    }
                    // Register the per-channel event port on the channel's connection
                    // ID, so that the device's guest-to-host event can be signaled
                    // through it.
                    let event_port = self
                        .synic
                        .add_event_port(
                            open_params.connection_id,
                            self.vtl,
                            guest_to_host_event.clone(),
                            open_params.monitor_info,
                        )
                        .with_context(|| {
                            format!(
                                "failed to create event port for VTL {:?}, connection ID {:#x}",
                                self.vtl, open_params.connection_id
                            )
                        })?;

                    ChannelState::Open {
                        open_params,
                        _event_port: event_port,
                        guest_event_port,
                        host_to_guest_interrupt,
                        guest_to_host_event,
                    }
                }
                s => {
                    tracing::error!("attempting to complete open of open or closed channel");
                    // Put the unexpected state back unchanged.
                    s
                }
            }
        } else {
            ChannelState::Closed
        };
        Ok(channel)
    }

    /// Applies a change to the guest's interrupt page.
    ///
    /// - `Unchanged` does nothing.
    /// - `Reset` drops the channel bitmap and the shared event port.
    /// - `Set(gpa)` locks the page, builds a new `ChannelBitmap` over it, and
    ///   registers an event port on `SHARED_EVENT_CONNECTION_ID` that dispatches
    ///   shared interrupts through the bitmap.
    ///
    /// # Errors
    /// Fails if the page is not page aligned, cannot be locked, or the event port
    /// cannot be added.
    fn map_interrupt_page(&mut self, interrupt_page: Update<u64>) -> anyhow::Result<()> {
        let interrupt_page = match interrupt_page {
            Update::Unchanged => return Ok(()),
            Update::Reset => {
                self.channel_bitmap = None;
                self.shared_event_port = None;
                return Ok(());
            }
            Update::Set(interrupt_page) => interrupt_page,
        };

        // NOTE(review): a zero GPA panics here; presumably callers never send
        // `Set(0)`. Confirm.
        assert_ne!(interrupt_page, 0);

        if interrupt_page % PAGE_SIZE as u64 != 0 {
            anyhow::bail!("interrupt page {:#x} is not page aligned", interrupt_page);
        }

        // Access the page through a one-page lockable subrange of guest memory.
        let interrupt_page = self
            .gm
            .lockable_subrange(interrupt_page, PAGE_SIZE as u64)?
            .lock_gpns(false, &[0])?;

        let channel_bitmap = Arc::new(ChannelBitmap::new(interrupt_page));
        self.channel_bitmap = Some(channel_bitmap.clone());

        let interrupt = Interrupt::from_fn(move || {
            channel_bitmap.handle_shared_interrupt();
        });

        // NOTE(review): any previous `shared_event_port` is dropped only after this
        // `add_event_port` succeeds. Setting the page twice without a `Reset` in
        // between may therefore fail if the synic rejects a duplicate connection ID;
        // confirm callers always reset first. On failure, `channel_bitmap` has
        // already been replaced.
        self.shared_event_port = Some(self.synic.add_event_port(
            SHARED_EVENT_CONNECTION_ID,
            self.vtl,
            Arc::new(ChannelEvent(interrupt)),
            None,
        )?);

        Ok(())
    }

    /// Applies a change to the guest's monitor pages.
    ///
    /// Unless `force` is set, the change is refused while any open or opening
    /// channel has `monitor_info` (uses MNF). The pages are only programmed into
    /// the synic's monitor support when `enable_mnf` is set.
    fn set_monitor_page(
        &mut self,
        monitor_page: Update<MonitorPageGpas>,
        force: bool,
    ) -> anyhow::Result<()> {
        let monitor_page = match monitor_page {
            Update::Unchanged => return Ok(()),
            Update::Reset => None,
            Update::Set(value) => Some(value),
        };

        if !force
            && self.channels.iter().any(|(_, c)| {
                matches!(
                    &c.state,
                    ChannelState::Open {
                        open_params,
                        ..
                    } | ChannelState::Opening {
                        open_params,
                        ..
                    } if open_params.monitor_info.is_some()
                )
            })
        {
            anyhow::bail!("attempt to change monitor page while open channels using mnf");
        }

        if self.enable_mnf {
            if let Some(monitor) = self.synic.monitor_support() {
                if let Err(err) = monitor.set_monitor_page(self.vtl, monitor_page) {
                    anyhow::bail!(
                        "setting monitor page failed, err = {err:?}, monitor_page = {monitor_page:?}"
                    );
                }
            }
        }

        Ok(())
    }

    /// Returns the message port of reserved channel `offer_id`, updating it for
    /// `new_target` if needed.
    ///
    /// - A SINT change replaces the port: the old one is dropped first, then a new
    ///   one is created.
    /// - A VP-only change retargets the existing port. A failure there is logged,
    ///   and the port is still returned.
    ///
    /// Returns `None` only when a replacement port cannot be created.
    ///
    /// # Panics
    /// Panics if the channel does not exist or has no reserved message port.
    ///
    /// NOTE(review): after a failed replacement, `message_port` is left as `None`.
    /// The next call for this channel then trips the "channel is not reserved"
    /// assertion. Consider restoring a port on failure.
    fn get_reserved_channel_message_port(
        &mut self,
        offer_id: OfferId,
        new_target: ConnectionTarget,
    ) -> Option<&mut Box<dyn GuestMessagePort>> {
        let channel = self
            .channels
            .get_mut(&offer_id)
            .expect("channel does not exist");

        assert!(
            channel.reserved_state.message_port.is_some(),
            "channel is not reserved"
        );

        if channel.reserved_state.target.sint != new_target.sint {
            // The SINT changed: drop the old port before creating its replacement.
            channel.reserved_state.message_port = None;
            let message_port = self
                .synic
                .new_guest_message_port(self.redirect_vtl, new_target.vp, new_target.sint)
                .inspect_err(|err| {
                    tracing::error!(
                        ?err,
                        ?self.redirect_vtl,
                        ?new_target,
                        "could not create reserved channel message port"
                    )
                })
                .ok()?;

            channel.reserved_state.message_port = Some(message_port);
            channel.reserved_state.target = new_target;
        } else if channel.reserved_state.target.vp != new_target.vp {
            let message_port = channel.reserved_state.message_port.as_mut().unwrap();

            // Same SINT, different VP: retarget in place. A failure is only logged;
            // the new target is still recorded and the port is still returned.
            if let Err(err) = message_port.set_target_vp(new_target.vp) {
                tracing::error!(
                    ?err,
                    ?self.redirect_vtl,
                    ?new_target,
                    "could not update reserved channel message port"
                );
            }

            channel.reserved_state.target = new_target;
            return Some(message_port);
        }

        Some(channel.reserved_state.message_port.as_mut().unwrap())
    }

    /// Drops the reserved message port of every channel currently in
    /// `ChannelState::Closed`. Channels in any other state keep theirs.
    fn unreserve_channels(&mut self) {
        for channel in self.channels.values_mut() {
            if let ChannelState::Closed = channel.state {
                channel.reserved_state.message_port = None;
            }
        }
    }
}
1821
/// Cloneable handle for offering channels to the VMBus server task.
///
/// It also serves as the [`ParentBus`] that devices are added to.
#[derive(Clone)]
pub struct VmbusServerControl {
    /// Guest memory handed to every offered device.
    mem: GuestMemory,
    /// Private guest memory. It is handed only to devices whose offer flags
    /// request a confidential ring buffer or confidential external memory.
    private_mem: Option<GuestMemory>,
    /// Request channel to the server task.
    send: mesh::Sender<OfferRequest>,
    /// Value reported by `ParentBus::use_event`.
    use_event: bool,
    /// When set, every offer made through `offer` has confidential external memory
    /// forced on.
    force_confidential_external_memory: bool,
}
1831
1832impl VmbusServerControl {
1833 pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result<OfferResources> {
1836 let flags = offer_info.params.flags;
1837 self.send
1838 .call_failable(OfferRequest::Offer, offer_info)
1839 .await?;
1840 Ok(OfferResources::new(
1841 self.mem.clone(),
1842 if flags.confidential_ring_buffer() || flags.confidential_external_memory() {
1843 self.private_mem.clone()
1844 } else {
1845 None
1846 },
1847 ))
1848 }
1849
1850 pub async fn force_reset(&self) -> anyhow::Result<()> {
1853 self.send
1854 .call(OfferRequest::ForceReset, ())
1855 .await
1856 .context("vmbus server is gone")
1857 }
1858
1859 async fn offer(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
1860 let mut offer_info = OfferInfo {
1861 params: request.params.into(),
1862 request_send: request.request_send,
1863 server_request_recv: request.server_request_recv,
1864 };
1865
1866 if self.force_confidential_external_memory {
1867 tracing::warn!(
1868 key = %offer_info.params.key(),
1869 "forcing confidential external memory for channel"
1870 );
1871
1872 offer_info
1873 .params
1874 .flags
1875 .set_confidential_external_memory(true);
1876 }
1877
1878 self.offer_core(offer_info).await
1879 }
1880}
1881
1882fn inspect_rings(
1884 resp: &mut inspect::Response<'_>,
1885 gm: &GuestMemory,
1886 gpadl_map: Arc<GpadlMap>,
1887 open_data: &OpenData,
1888) -> Option<()> {
1889 let gpadl = gpadl_map
1890 .view()
1891 .map(GpadlId(open_data.ring_gpadl_id.0))
1892 .ok()?;
1893 let aligned = AlignedGpadlView::new(gpadl).ok()?;
1894 let (in_gpadl, out_gpadl) = aligned.split(open_data.ring_offset).ok()?;
1895 if let Ok(incoming_mem) = GpadlRingMem::new(in_gpadl, gm) {
1896 resp.child("incoming_ring", |req| ring::inspect_ring(incoming_mem, req));
1897 }
1898 if let Ok(outgoing_mem) = GpadlRingMem::new(out_gpadl, gm) {
1899 resp.child("outgoing_ring", |req| ring::inspect_ring(outgoing_mem, req));
1900 }
1901 Some(())
1902}
1903
/// [`MessagePort`] implementation that copies each message it is handed into a
/// `SynicMessage` and queues it on `send`.
pub(crate) struct MessageSender {
    /// Queue to the consumer of `SynicMessage`s.
    send: mpsc::Sender<SynicMessage>,
    /// Copied into every queued `SynicMessage`.
    multiclient: bool,
}
1908
1909impl MessageSender {
1910 fn poll_handle_message(
1911 &self,
1912 cx: &mut std::task::Context<'_>,
1913 msg: &[u8],
1914 trusted: bool,
1915 ) -> Poll<Result<(), SendError>> {
1916 let mut send = self.send.clone();
1917 ready!(send.poll_ready(cx))?;
1918 send.start_send(SynicMessage {
1919 data: msg.to_vec(),
1920 multiclient: self.multiclient,
1921 trusted,
1922 })?;
1923
1924 Poll::Ready(Ok(()))
1925 }
1926}
1927
1928impl MessagePort for MessageSender {
1929 fn poll_handle_message(
1930 &self,
1931 cx: &mut std::task::Context<'_>,
1932 msg: &[u8],
1933 trusted: bool,
1934 ) -> Poll<()> {
1935 if let Err(err) = ready!(self.poll_handle_message(cx, msg, trusted)) {
1936 tracelimit::error_ratelimited!(
1937 error = &err as &dyn std::error::Error,
1938 "failed to send message"
1939 );
1940 }
1941
1942 Poll::Ready(())
1943 }
1944}
1945
1946#[async_trait]
1947impl ParentBus for VmbusServerControl {
1948 async fn add_child(&self, request: OfferInput) -> anyhow::Result<OfferResources> {
1949 self.offer(request).await
1950 }
1951
1952 fn clone_bus(&self) -> Box<dyn ParentBus> {
1953 Box::new(self.clone())
1954 }
1955
1956 fn use_event(&self) -> bool {
1957 self.use_event
1958 }
1959}
1960
#[cfg(test)]
mod tests {
    use super::*;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::driver::SpawnDriver;
    use pal_async::timer::Instant;
    use pal_async::timer::PolledTimer;
    use parking_lot::Mutex;
    use protocol::UserDefinedData;
    use std::time::Duration;
    use test_with_tracing::test;
    use vmbus_channel::bus::OfferParams;
    use vmbus_core::protocol::ChannelId;
    use vmbus_core::protocol::VmbusMessage;
    use vmcore::synic::MonitorInfo;
    use vmcore::synic::SynicPortAccess;
    use zerocopy::FromBytes;
    use zerocopy::Immutable;
    use zerocopy::IntoBytes;
    use zerocopy::KnownLayout;

    /// Mutable state of `MockSynic`.
    struct MockSynicInner {
        /// The port most recently registered through `add_message_port`. Tests
        /// inject guest messages through it.
        message_port: Option<Arc<dyn MessagePort>>,
    }

    /// Fake synic.
    ///
    /// It records the server's message port so tests can play the guest, and it
    /// hands out guest message ports that forward whatever the server posts to
    /// `message_send`.
    struct MockSynic {
        inner: Mutex<MockSynicInner>,
        message_send: mesh::Sender<Vec<u8>>,
        spawner: Arc<dyn SpawnDriver>,
    }

    impl MockSynic {
        fn new(message_send: mesh::Sender<Vec<u8>>, spawner: Arc<dyn SpawnDriver>) -> Self {
            Self {
                inner: Mutex::new(MockSynicInner { message_port: None }),
                message_send,
                spawner,
            }
        }

        /// Sends `msg` to the server as an untrusted guest message.
        fn send_message(&self, msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout) {
            self.send_message_core(OutgoingMessage::new(&msg), false);
        }

        /// Sends `msg` to the server as a trusted guest message.
        fn send_message_trusted(
            &self,
            msg: impl VmbusMessage + IntoBytes + Immutable + KnownLayout,
        ) {
            self.send_message_core(OutgoingMessage::new(&msg), true);
        }

        /// Hands `msg` to the server's registered message port and asserts that it
        /// is accepted on the first poll.
        ///
        /// Panics if the server has not registered a message port.
        fn send_message_core(&self, msg: OutgoingMessage, trusted: bool) {
            assert_eq!(
                self.inner
                    .lock()
                    .message_port
                    .as_ref()
                    .unwrap()
                    .poll_handle_message(
                        &mut std::task::Context::from_waker(std::task::Waker::noop()),
                        msg.data(),
                        trusted,
                    ),
                Poll::Ready(())
            );
        }
    }

    /// Guest event port that does nothing; its interrupt is `Interrupt::null()`.
    #[derive(Debug)]
    struct MockGuestPort {}

    impl GuestEventPort for MockGuestPort {
        fn interrupt(&self) -> Interrupt {
            Interrupt::null()
        }

        fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
            Ok(())
        }
    }

    /// Guest message port that forwards posted payloads to the test.
    ///
    /// It randomly delays some posts, presumably to exercise the server's
    /// handling of `Pending` from guest message ports.
    struct MockGuestMessagePort {
        send: mesh::Sender<Vec<u8>>,
        spawner: Arc<dyn SpawnDriver>,
        /// Delay in progress: the timer and the deadline it must reach before the
        /// next post can complete.
        timer: Option<(PolledTimer, Instant)>,
    }

    impl GuestMessagePort for MockGuestMessagePort {
        fn poll_post_message(
            &mut self,
            cx: &mut std::task::Context<'_>,
            _typ: u32,
            payload: &[u8],
        ) -> Poll<()> {
            // Finish any delay started by an earlier call before posting.
            if let Some((timer, deadline)) = self.timer.as_mut() {
                ready!(timer.sleep_until(*deadline).poll_unpin(cx));
                self.timer = None;
            }

            // With probability 1/4 (a random byte divisible by 4), start a 10ms
            // delay. If it has not already elapsed, report `Pending` without
            // posting.
            let mut pending_chance = [0; 1];
            getrandom::fill(&mut pending_chance).unwrap();
            if pending_chance[0] % 4 == 0 {
                let mut timer = PolledTimer::new(self.spawner.as_ref());
                let deadline = Instant::now() + Duration::from_millis(10);
                match timer.sleep_until(deadline).poll_unpin(cx) {
                    Poll::Ready(_) => {}
                    Poll::Pending => {
                        self.timer = Some((timer, deadline));
                        return Poll::Pending;
                    }
                }
            }

            self.send.send(payload.into());
            Poll::Ready(())
        }

        fn set_target_vp(&mut self, _vp: u32) -> Result<(), vmcore::synic::HypervisorError> {
            Ok(())
        }
    }

    impl Inspect for MockGuestMessagePort {
        fn inspect(&self, _req: inspect::Request<'_>) {}
    }

    impl SynicPortAccess for MockSynic {
        fn add_message_port(
            &self,
            connection_id: u32,
            _minimum_vtl: Vtl,
            port: Arc<dyn MessagePort>,
        ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
            // Keep the port so `send_message_core` can deliver guest messages.
            self.inner.lock().message_port = Some(port);
            Ok(Box::new(connection_id))
        }

        fn add_event_port(
            &self,
            connection_id: u32,
            _minimum_vtl: Vtl,
            _port: Arc<dyn EventPort>,
            _monitor_info: Option<MonitorInfo>,
        ) -> Result<Box<dyn Sync + Send>, vmcore::synic::Error> {
            Ok(Box::new(connection_id))
        }

        fn new_guest_message_port(
            &self,
            _vtl: Vtl,
            _vp: u32,
            _sint: u8,
        ) -> Result<Box<(dyn GuestMessagePort)>, vmcore::synic::HypervisorError> {
            Ok(Box::new(MockGuestMessagePort {
                send: self.message_send.clone(),
                spawner: Arc::clone(&self.spawner),
                timer: None,
            }))
        }

        fn new_guest_event_port(
            &self,
            _port_id: u32,
            _vtl: Vtl,
            _vp: u32,
            _sint: u8,
            _flag: u16,
            _monitor_info: Option<MonitorInfo>,
        ) -> Result<Box<(dyn GuestEventPort)>, vmcore::synic::HypervisorError> {
            Ok(Box::new(MockGuestPort {}))
        }

        fn prefer_os_events(&self) -> bool {
            false
        }
    }

    /// Device side of an offered test channel.
    struct TestChannel {
        request_recv: mesh::Receiver<ChannelRequest>,
        server_request_send: mesh::Sender<ChannelServerRequest>,
        _resources: OfferResources,
    }

    impl TestChannel {
        async fn next_request(&mut self) -> ChannelRequest {
            self.request_recv.next().await.unwrap()
        }

        /// Expects a GPADL request and accepts it.
        async fn handle_gpadl(&mut self) {
            let ChannelRequest::Gpadl(rpc) = self.next_request().await else {
                panic!("Wrong request");
            };

            rpc.complete(true);
        }

        /// Expects an open request, runs `f` on it, and completes it successfully.
        async fn handle_open(&mut self, f: fn(&OpenRequest)) {
            let ChannelRequest::Open(rpc) = self.next_request().await else {
                panic!("Wrong request");
            };

            f(rpc.input());
            rpc.complete(Some(OpenResult {
                guest_to_host_interrupt: Interrupt::null(),
            }));
        }

        /// Expects a GPADL teardown request and completes it immediately.
        async fn handle_gpadl_teardown(&mut self) {
            let rpc = self.get_gpadl_teardown().await;
            rpc.complete(());
        }

        /// Expects a GPADL teardown request and returns it uncompleted.
        async fn get_gpadl_teardown(&mut self) -> Rpc<GpadlId, ()> {
            let ChannelRequest::TeardownGpadl(rpc) = self.next_request().await else {
                panic!("Wrong request");
            };

            rpc
        }

        /// Asks the server to restore this channel and asserts that it succeeded.
        async fn restore(&self) {
            self.server_request_send
                .call(ChannelServerRequest::Restore, None)
                .await
                .unwrap()
                .unwrap();
        }
    }

    /// A VMBus server wired to a `MockSynic`, plus the receiver of messages the
    /// server posts to the guest.
    struct TestEnv {
        vmbus: VmbusServer,
        synic: Arc<MockSynic>,
        message_recv: mesh::Receiver<Vec<u8>>,
        /// Whether guest messages are sent as trusted; set by `initiate_contact`.
        trusted: bool,
    }

    impl TestEnv {
        fn new(spawner: DefaultDriver) -> Self {
            let spawner: Arc<dyn SpawnDriver> = Arc::new(spawner);
            let (message_send, message_recv) = mesh::channel();
            let synic = Arc::new(MockSynic::new(message_send, Arc::clone(&spawner)));
            let gm = GuestMemory::empty();
            let vmbus = VmbusServerBuilder::new(&spawner, synic.clone(), gm)
                .build()
                .unwrap();

            Self {
                vmbus,
                synic,
                message_recv,
                trusted: false,
            }
        }

        /// Offers a device channel whose instance and interface GUIDs have
        /// `data1 == id`.
        async fn offer(&self, id: u32, allow_confidential_external_memory: bool) -> TestChannel {
            let guid = Guid {
                data1: id,
                ..Guid::ZERO
            };
            let (request_send, request_recv) = mesh::channel();
            let (server_request_send, server_request_recv) = mesh::channel();
            let offer = OfferInput {
                request_send,
                server_request_recv,
                params: OfferParams {
                    interface_name: "test".into(),
                    instance_id: guid,
                    interface_id: guid,
                    mmio_megabytes: 0,
                    mmio_megabytes_optional: 0,
                    channel_type: vmbus_channel::bus::ChannelType::Device {
                        pipe_packets: false,
                    },
                    subchannel_index: 0,
                    mnf_interrupt_latency: None,
                    offer_order: None,
                    allow_confidential_external_memory,
                },
            };

            let control = self.vmbus.control();
            let _resources = control.add_child(offer).await.unwrap();

            TestChannel {
                request_recv,
                server_request_send,
                _resources,
            }
        }

        /// Has the guest create GPADL `gpadl_id` (one range, 16 bytes of range
        /// data), lets the device accept it, and expects `GPADL_CREATED`.
        async fn gpadl(&mut self, channel_id: u32, gpadl_id: u32, channel: &mut TestChannel) {
            self.synic.send_message_core(
                OutgoingMessage::with_data(
                    &protocol::GpadlHeader {
                        channel_id: ChannelId(channel_id),
                        gpadl_id: GpadlId(gpadl_id),
                        count: 1,
                        len: 16,
                    },
                    [1u64, 0u64].as_bytes(),
                ),
                self.trusted,
            );

            channel.handle_gpadl().await;
            self.expect_response(protocol::MessageType::GPADL_CREATED)
                .await;
        }

        /// Creates the ring GPADL and opens the channel. The device side checks
        /// the open request with `f`, and the guest expects `OPEN_CHANNEL_RESULT`.
        async fn open_channel(
            &mut self,
            channel_id: u32,
            ring_gpadl_id: u32,
            channel: &mut TestChannel,
            f: fn(&OpenRequest),
        ) {
            self.gpadl(channel_id, ring_gpadl_id, channel).await;
            self.synic.send_message_core(
                OutgoingMessage::new(&protocol::OpenChannel {
                    channel_id: ChannelId(channel_id),
                    open_id: 0,
                    ring_buffer_gpadl_id: GpadlId(ring_gpadl_id),
                    target_vp: 0,
                    downstream_ring_buffer_page_offset: 0,
                    user_data: UserDefinedData::default(),
                }),
                self.trusted,
            );

            channel.handle_open(f).await;
            self.expect_response(protocol::MessageType::OPEN_CHANNEL_RESULT)
                .await;
        }

        /// Receives the next message posted to the guest and asserts its type.
        async fn expect_response(&mut self, expected: protocol::MessageType) {
            let data = self.message_recv.next().await.unwrap();
            let header = protocol::MessageHeader::read_from_prefix(&data).unwrap().0;
            assert_eq!(expected, header.message_type())
        }

        /// Receives the next message posted to the guest, asserts that it is a
        /// `T`, and parses it.
        async fn get_response<T: VmbusMessage + FromBytes + Immutable + KnownLayout>(
            &mut self,
        ) -> T {
            let data = self.message_recv.next().await.unwrap();
            let (header, message) = protocol::MessageHeader::read_from_prefix(&data).unwrap();
            assert_eq!(T::MESSAGE_TYPE, header.message_type());
            T::read_from_prefix(message).unwrap().0
        }

        /// Sends `InitiateContact` and records `trusted` for later guest messages.
        fn initiate_contact(
            &mut self,
            version: protocol::Version,
            feature_flags: protocol::FeatureFlags,
            trusted: bool,
        ) {
            self.synic.send_message_core(
                OutgoingMessage::new(&protocol::InitiateContact {
                    version_requested: version as u32,
                    target_message_vp: 0,
                    child_to_parent_monitor_page_gpa: 0,
                    parent_to_child_monitor_page_gpa: 0,
                    interrupt_page_or_target_info: protocol::TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(feature_flags.into())
                        .into(),
                }),
                trusted,
            );

            self.trusted = trusted;
        }

        /// Connects at `Version::Copper`, requests offers, and expects
        /// `offer_count` offers followed by `ALL_OFFERS_DELIVERED`.
        async fn connect(
            &mut self,
            offer_count: u32,
            feature_flags: protocol::FeatureFlags,
            trusted: bool,
        ) {
            self.initiate_contact(protocol::Version::Copper, feature_flags, trusted);

            self.expect_response(protocol::MessageType::VERSION_RESPONSE)
                .await;

            self.synic
                .send_message_core(OutgoingMessage::new(&protocol::RequestOffers {}), trusted);

            for _ in 0..offer_count {
                self.expect_response(protocol::MessageType::OFFER_CHANNEL)
                    .await;
            }

            self.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
                .await;
        }
    }

    /// Saves while a GPADL teardown is still in flight at the device. After
    /// restore, the teardown is sent to the device again.
    #[async_test]
    async fn test_save_restore(spawner: DefaultDriver) {
        let mut env = TestEnv::new(spawner);
        let mut channel = env.offer(1, false).await;
        env.vmbus.start();
        env.connect(1, protocol::FeatureFlags::new(), false).await;

        env.gpadl(1, 10, &mut channel).await;

        // The guest tears down the GPADL; the device receives the request but
        // does not answer yet.
        env.synic.send_message(protocol::GpadlTeardown {
            channel_id: ChannelId(1),
            gpadl_id: GpadlId(10),
        });

        let rpc = channel.get_gpadl_teardown().await;
        // Capture state while the teardown is outstanding.
        env.vmbus.stop().await;
        let saved_state = env.vmbus.save().await;
        env.vmbus.start();

        // Now let the teardown finish in the live (pre-restore) server.
        rpc.complete(());
        env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
            .await;

        env.synic.send_message(protocol::RelIdReleased {
            channel_id: ChannelId(1),
        });

        env.vmbus.reset().await;
        env.vmbus.stop().await;

        // Roll back to the saved state, where the teardown was still in flight.
        env.vmbus.restore(saved_state).await.unwrap();
        channel.restore().await;
        env.vmbus.post_restore().await.unwrap();
        env.vmbus.start();

        // The device sees the teardown again and the guest gets the torn-down
        // notification.
        channel.handle_gpadl_teardown().await;
        env.expect_response(protocol::MessageType::GPADL_TORNDOWN)
            .await;

        env.synic.send_message(protocol::RelIdReleased {
            channel_id: ChannelId(1),
        });
    }

    /// With a trusted connection that negotiates confidential channels:
    /// - regular offers get confidential rings;
    /// - offers that allow it also get confidential external memory;
    /// - an offer made without either flag gets neither.
    #[async_test]
    async fn test_confidential_connection(spawner: DefaultDriver) {
        let mut env = TestEnv::new(spawner);
        let mut channel = env.offer(1, false).await;
        // This one also allows confidential external memory.
        let mut channel2 = env.offer(2, true).await;

        // Offer a third channel directly through `offer_core`, with explicit
        // internal params whose flags request no confidentiality.
        let (request_send, request_recv) = mesh::channel();
        let (server_request_send, server_request_recv) = mesh::channel();
        let id = Guid {
            data1: 3,
            ..Guid::ZERO
        };
        let control = env.vmbus.control();
        let relay_resources = control
            .offer_core(OfferInfo {
                params: OfferParamsInternal {
                    interface_name: "test".into(),
                    instance_id: id,
                    interface_id: id,
                    mmio_megabytes: 0,
                    mmio_megabytes_optional: 0,
                    subchannel_index: 0,
                    use_mnf: MnfUsage::Disabled,
                    offer_order: None,
                    flags: protocol::OfferFlags::new().with_enumerate_device_interface(true),
                    ..Default::default()
                },
                request_send,
                server_request_recv,
            })
            .await
            .unwrap();

        let mut relay_channel = TestChannel {
            request_recv,
            server_request_send,
            _resources: relay_resources,
        };

        env.vmbus.start();
        env.initiate_contact(
            protocol::Version::Copper,
            protocol::FeatureFlags::new().with_confidential_channels(true),
            true,
        );

        env.expect_response(protocol::MessageType::VERSION_RESPONSE)
            .await;

        env.synic.send_message_trusted(protocol::RequestOffers {});

        // Channel 1: confidential ring only.
        let offer = env.get_response::<protocol::OfferChannel>().await;
        assert!(offer.flags.confidential_ring_buffer());
        assert!(!offer.flags.confidential_external_memory());
        // Channel 2: confidential ring and external memory.
        let offer = env.get_response::<protocol::OfferChannel>().await;
        assert!(offer.flags.confidential_ring_buffer());
        assert!(offer.flags.confidential_external_memory());

        // Channel 3 (offered via `offer_core` without confidential flags): neither.
        let offer = env.get_response::<protocol::OfferChannel>().await;
        assert!(!offer.flags.confidential_ring_buffer());
        assert!(!offer.flags.confidential_external_memory());

        env.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED)
            .await;

        env.open_channel(1, 1, &mut channel, |request| {
            // The device is told to use a confidential ring, but not external memory.
            assert!(request.use_confidential_ring);
            assert!(!request.use_confidential_external_memory);
        })
        .await;

        env.open_channel(2, 2, &mut channel2, |request| {
            assert!(request.use_confidential_ring);
            assert!(request.use_confidential_external_memory);
        })
        .await;

        env.open_channel(3, 3, &mut relay_channel, |request| {
            assert!(!request.use_confidential_ring);
            assert!(!request.use_confidential_external_memory);
        })
        .await;
    }

    /// A trusted connection that does not negotiate confidential channels gets no
    /// confidential resources.
    #[async_test]
    async fn test_confidential_channels_unsupported(spawner: DefaultDriver) {
        let mut env = TestEnv::new(spawner);
        let mut channel = env.offer(1, false).await;
        let mut channel2 = env.offer(2, true).await;

        env.vmbus.start();
        env.connect(2, protocol::FeatureFlags::new(), true).await;

        env.open_channel(1, 1, &mut channel, |request| {
            // Without the negotiated feature flag, nothing is confidential.
            assert!(!request.use_confidential_ring);
            assert!(!request.use_confidential_external_memory);
        })
        .await;

        env.open_channel(2, 2, &mut channel2, |request| {
            assert!(!request.use_confidential_ring);
            assert!(!request.use_confidential_external_memory);
        })
        .await;
    }

    /// An untrusted connection gets no confidential resources, even if it asks
    /// for confidential channels.
    #[async_test]
    async fn test_confidential_channels_untrusted(spawner: DefaultDriver) {
        let mut env = TestEnv::new(spawner);
        let mut channel = env.offer(1, false).await;
        let mut channel2 = env.offer(2, true).await;

        env.vmbus.start();
        env.connect(
            // Request confidential channels, but connect as untrusted.
            2,
            protocol::FeatureFlags::new().with_confidential_channels(true),
            false,
        )
        .await;

        env.open_channel(1, 1, &mut channel, |request| {
            // The untrusted connection does not get confidential rings.
            assert!(!request.use_confidential_ring);
            assert!(!request.use_confidential_external_memory);
        })
        .await;

        env.open_channel(2, 2, &mut channel2, |request| {
            assert!(!request.use_confidential_ring);
            assert!(!request.use_confidential_external_memory);
        })
        .await;
    }
}