1#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9pub mod driver;
10pub mod filter;
11mod hvsock;
12pub mod saved_state;
13
14pub use self::saved_state::SavedState;
15use anyhow::Context as _;
16use anyhow::Result;
17use futures::FutureExt;
18use futures::StreamExt;
19use futures::future::OptionFuture;
20use futures::stream::SelectAll;
21use futures_concurrency::future::Race;
22use guid::Guid;
23use inspect::Inspect;
24use mesh::rpc::FailableRpc;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use pal_async::task::Spawn;
28use pal_async::task::Task;
29use pal_event::Event;
30use std::collections::HashMap;
31use std::collections::VecDeque;
32use std::collections::hash_map;
33use std::convert::TryInto;
34use std::future::Future;
35use std::future::poll_fn;
36use std::ops::Deref;
37use std::ops::DerefMut;
38use std::pin::pin;
39use std::sync::Arc;
40use std::sync::atomic::AtomicU32;
41use std::sync::atomic::Ordering;
42use std::task::Context;
43use std::task::Poll;
44use thiserror::Error;
45use vmbus_async::async_dgram::AsyncRecv;
46use vmbus_async::async_dgram::AsyncRecvExt;
47use vmbus_channel::bus::GpadlRequest;
48use vmbus_channel::bus::ModifyRequest;
49use vmbus_channel::bus::OfferKey;
50use vmbus_channel::bus::OpenData;
51use vmbus_channel::gpadl::GpadlId;
52use vmbus_core::HvsockConnectRequest;
53use vmbus_core::OutgoingMessage;
54use vmbus_core::TaggedStream;
55use vmbus_core::VersionInfo;
56use vmbus_core::protocol;
57use vmbus_core::protocol::ChannelId;
58use vmbus_core::protocol::ConnectionState;
59use vmbus_core::protocol::FeatureFlags;
60use vmbus_core::protocol::Message;
61use vmbus_core::protocol::OpenChannelFlags;
62use vmbus_core::protocol::Version;
63use vmcore::interrupt::Interrupt;
64use vmcore::synic::MonitorPageGpas;
65use zerocopy::Immutable;
66use zerocopy::IntoBytes;
67use zerocopy::KnownLayout;
68
69const SINT: u8 = 2;
70const VTL: u8 = 0;
71const SUPPORTED_VERSIONS: &[Version] = &[Version::Iron, Version::Copper];
72const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
73 .with_guest_specified_signal_parameters(true)
74 .with_channel_interrupt_redirection(true)
75 .with_modify_connection(true)
76 .with_client_id(true)
77 .with_pause_resume(true);
78
/// Access to the synic event flags used by the VMBus client.
///
/// The builder stores implementations behind an `Arc`, so they must be
/// `Send + Sync`.
pub trait SynicEventClient: Send + Sync {
    /// Associates synic event flag `event_flag` with `event`.
    ///
    /// NOTE(review): presumably `event` is signaled whenever the flag is set —
    /// confirm with implementations.
    fn map_event(&self, event_flag: u16, event: &Event) -> std::io::Result<()>;

    /// Presumably undoes a previous `map_event` for `event_flag` — confirm
    /// with implementations.
    fn unmap_event(&self, event_flag: u16);

    /// Signals `event_flag` on `connection_id` (presumably a guest-to-host
    /// signal — confirm with implementations).
    fn signal_event(&self, connection_id: u32, event_flag: u16) -> std::io::Result<()>;
}
90
/// Source of incoming VMBus messages consumed by the client task.
pub trait VmbusMessageSource: AsyncRecv + Send {
    /// Hook to pause the incoming message stream. The default does nothing.
    fn pause_message_stream(&mut self) {}

    /// Hook to resume the incoming message stream. The default does nothing.
    fn resume_message_stream(&mut self) {}
}
100
/// Posts outgoing VMBus messages to the host.
pub trait PollPostMessage: Send {
    /// Attempts to post message `msg` of type `typ` on `connection_id`.
    ///
    /// NOTE(review): presumably returns `Poll::Pending` (after registering
    /// `cx`) when the message cannot be posted yet, and `Poll::Ready(())` once
    /// it has been posted — confirm with implementations.
    fn poll_post_message(
        &mut self,
        cx: &mut Context<'_>,
        connection_id: u32,
        typ: u32,
        msg: &[u8],
    ) -> Poll<()>;
}
110
/// Handle to a running VMBus client task.
///
/// Holds the sender for task-level requests (start/stop/save/restore/inspect),
/// a [`VmbusClientAccess`] for connection-level requests, and the spawned task
/// itself.
#[derive(Inspect)]
pub struct VmbusClient {
    // Inspect requests are forwarded to the task as `TaskRequest::Inspect`.
    #[inspect(flatten, send = "TaskRequest::Inspect")]
    task_send: mesh::Sender<TaskRequest>,
    #[inspect(skip)]
    access: VmbusClientAccess,
    // The task yields its `ClientTask` when it finishes so `sever` can recover
    // the transport objects.
    #[inspect(skip)]
    task: Task<ClientTask>,
}
120
/// Error returned by [`VmbusClient::connect`].
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// A connect was requested while the client was not disconnected.
    #[error("invalid state to connect to the server")]
    InvalidState,
    /// The server rejected every version in the client's supported list.
    #[error("no supported protocol versions")]
    NoSupportedVersions,
    /// The server accepted the version but reported a non-success connection state.
    #[error("failed to connect to the server: {0:?}")]
    FailedToConnect(ConnectionState),
}
130
/// Cloneable handle for sending connection-level requests (modify connection,
/// hvsock connect) to the client task.
#[derive(Clone)]
pub struct VmbusClientAccess {
    client_request_send: mesh::Sender<ClientRequest>,
}
135
/// Transport objects needed to build a [`VmbusClient`].
///
/// `VmbusClient::sever` also produces one of these from a finished task, so
/// the objects can be reused.
pub struct VmbusClientBuilder {
    // Synic event flag access, shared with the client task.
    event_client: Arc<dyn SynicEventClient>,
    // Incoming message stream.
    msg_source: Box<dyn VmbusMessageSource>,
    // Outgoing message poster.
    msg_client: Box<dyn PollPostMessage>,
}
142
143impl VmbusClientBuilder {
144 pub fn new(
146 event_client: impl SynicEventClient + 'static,
147 msg_source: impl VmbusMessageSource + 'static,
148 msg_client: impl PollPostMessage + 'static,
149 ) -> Self {
150 Self {
151 event_client: Arc::new(event_client),
152 msg_source: Box::new(msg_source),
153 msg_client: Box::new(msg_client),
154 }
155 }
156
157 pub fn build(self, spawner: &impl Spawn) -> VmbusClient {
159 let (task_send, task_recv) = mesh::channel();
160 let (client_request_send, client_request_recv) = mesh::channel();
161
162 let inner = ClientTaskInner {
163 messages: OutgoingMessages {
164 poster: self.msg_client,
165 queued: VecDeque::new(),
166 state: OutgoingMessageState::Paused,
167 },
168 teardown_gpadls: HashMap::new(),
169 channel_requests: SelectAll::new(),
170 synic: SynicState {
171 event_flag_state: Vec::new(),
172 event_client: self.event_client,
173 },
174 };
175
176 let mut task = ClientTask {
177 inner,
178 channels: ChannelList::default(),
179 task_recv,
180 running: false,
181 msg_source: self.msg_source,
182 client_request_recv,
183 state: ClientState::Disconnected,
184 modify_request: None,
185 hvsock_tracker: hvsock::HvsockRequestTracker::new(),
186 };
187
188 let task = spawner.spawn("vmbus client", async move {
189 task.run().await;
190 task
191 });
192
193 VmbusClient {
194 access: VmbusClientAccess {
195 client_request_send,
196 },
197 task_send,
198 task,
199 }
200 }
201}
202
203impl VmbusClient {
204 pub async fn connect(
207 &mut self,
208 target_message_vp: u32,
209 monitor_page: Option<MonitorPageGpas>,
210 client_id: Guid,
211 ) -> Result<ConnectResult, ConnectError> {
212 let request = ConnectRequest {
213 target_message_vp,
214 monitor_page,
215 client_id,
216 };
217
218 self.access
219 .client_request_send
220 .call(ClientRequest::Connect, request)
221 .await
222 .unwrap()
223 }
224
225 pub async fn unload(self) {
226 self.access
227 .client_request_send
228 .call(ClientRequest::Unload, ())
229 .await
230 .unwrap();
231
232 self.sever().await;
233 }
234
235 pub fn access(&self) -> &VmbusClientAccess {
236 &self.access
237 }
238
239 pub fn start(&mut self) {
240 self.task_send.send(TaskRequest::Start);
241 }
242
243 pub async fn stop(&mut self) {
244 self.task_send
245 .call(TaskRequest::Stop, ())
246 .await
247 .expect("Failed to send stop request");
248 }
249
250 pub async fn save(&self) -> SavedState {
251 self.task_send
252 .call(TaskRequest::Save, ())
253 .await
254 .expect("Failed to send save request")
255 }
256
257 pub async fn restore(
258 &mut self,
259 state: SavedState,
260 ) -> Result<Option<ConnectResult>, RestoreError> {
261 self.task_send
262 .call(TaskRequest::Restore, state)
263 .await
264 .expect("Failed to send restore request")
265 }
266
267 pub async fn post_restore(&mut self) {
268 self.task_send
269 .call(TaskRequest::PostRestore, ())
270 .await
271 .expect("Failed to send post-restore request");
272 }
273
274 async fn sever(self) -> VmbusClientBuilder {
275 drop(self.task_send);
276 let task = self.task.await;
277 VmbusClientBuilder {
278 event_client: task.inner.synic.event_client,
279 msg_source: task.msg_source,
280 msg_client: task.inner.messages.poster,
281 }
282 }
283}
284
/// Result of a successful connection.
#[derive(Debug)]
pub struct ConnectResult {
    /// Negotiated protocol version and feature flags.
    pub version: VersionInfo,
    /// Offers received before the server sent AllOffersDelivered.
    pub offers: Vec<OfferInfo>,
    /// Receives offers that arrive after the initial set.
    pub offer_recv: mesh::Receiver<OfferInfo>,
}
291
292impl VmbusClientAccess {
293 pub async fn modify(&self, request: ModifyConnectionRequest) -> ConnectionState {
294 self.client_request_send
295 .call(ClientRequest::Modify, request)
296 .await
297 .expect("Failed to send modify request")
298 }
299
300 pub fn connect_hvsock(
301 &self,
302 request: HvsockConnectRequest,
303 ) -> impl Future<Output = Option<OfferInfo>> + use<> {
304 self.client_request_send
305 .call(ClientRequest::HvsockConnect, request)
306 .map(|r| r.ok().flatten())
307 }
308}
309
/// Parameters for [`ChannelRequest::Open`].
#[derive(Debug)]
pub struct OpenRequest {
    /// Ring buffer, target VP, connection ID, event flag and user data for the open.
    pub open_data: OpenData,
    /// If set, channel interrupts are redirected to this event through a newly
    /// allocated event flag. This requires host support for interrupt redirection.
    pub incoming_event: Option<Event>,
    /// If set, the connection ID is built with
    /// `protocol::ConnectionId::new(channel_id, 2, 7)` instead of taken from
    /// `open_data.connection_id`. NOTE(review): presumably VTL 2 / SINT 7 —
    /// confirm.
    pub use_vtl2_connection_id: bool,
}
316
/// Parameters for [`ChannelRequest::Restore`].
#[derive(Debug)]
pub struct RestoreRequest {
    /// Event for redirected interrupts. Must be set if and only if
    /// `redirected_event_flag` is set.
    pub incoming_event: Option<Event>,
    /// Event flag previously allocated for redirected interrupts.
    pub redirected_event_flag: Option<u16>,
    /// Connection ID used by the channel's guest-to-host interrupt.
    pub connection_id: u32,
}
325
/// Per-channel requests sent to the client task through
/// [`OfferInfo::request_send`].
pub enum ChannelRequest {
    /// Open the channel. Completed when the host's OpenResult arrives.
    Open(FailableRpc<OpenRequest, OpenOutput>),
    /// Reattach to a channel restored from saved state.
    Restore(FailableRpc<RestoreRequest, OpenOutput>),
    /// Close the channel.
    Close(Rpc<(), ()>),
    /// Create a GPADL. Completed when the host's GpadlCreated arrives.
    Gpadl(FailableRpc<GpadlRequest, ()>),
    /// Tear down a GPADL. Completed when the host's GpadlTorndown arrives.
    TeardownGpadl(Rpc<GpadlId, ()>),
    /// Modify the channel. Completed with the status from ModifyChannelResponse.
    Modify(Rpc<ModifyRequest, i32>),
}
335
/// Result of a successful open or restore.
#[derive(Debug)]
pub struct OpenOutput {
    /// Event flag used for redirected interrupts. `None` when interrupts are
    /// not redirected.
    pub redirected_event_flag: Option<u16>,
}
341
342impl std::fmt::Display for ChannelRequest {
343 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344 let s = match self {
345 ChannelRequest::Open(_) => "Open",
346 ChannelRequest::Close(_) => "Close",
347 ChannelRequest::Restore(_) => "Restore",
348 ChannelRequest::Gpadl(_) => "Gpadl",
349 ChannelRequest::TeardownGpadl(_) => "TeardownGpadl",
350 ChannelRequest::Modify(_) => "Modify",
351 };
352 fmt.pad(s)
353 }
354}
355
/// Error returned by [`VmbusClient::restore`].
#[derive(Debug, Error)]
pub enum RestoreError {
    /// The saved protocol version is not supported.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// The saved feature flags are not supported.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    /// The saved state lists the same channel ID more than once.
    #[error("duplicate channel id {0}")]
    DuplicateChannelId(u32),

    /// The saved state lists the same GPADL ID more than once.
    #[error("duplicate gpadl id {0}")]
    DuplicateGpadlId(u32),

    /// A saved GPADL refers to a channel ID that is not in the saved state.
    #[error("gpadl for unknown channel id {0}")]
    GpadlForUnknownChannelId(u32),

    /// A saved pending message could not be re-queued.
    #[error("invalid pending message")]
    InvalidPendingMessage(#[source] vmbus_core::MessageTooLarge),

    /// A restored channel could not be offered.
    #[error("failed to offer restored channel")]
    OfferFailed(#[source] anyhow::Error),
}
379
/// A channel offer handed to the client's user, together with the handles
/// needed to drive the channel.
#[derive(Debug, Inspect)]
pub struct OfferInfo {
    /// The offer as received from the host.
    pub offer: protocol::OfferChannel,
    /// Guest-to-host interrupt built from the channel's shared connection ID.
    /// The ID is updated when the channel is opened or restored.
    #[inspect(skip)]
    pub guest_to_host_interrupt: Interrupt,
    /// Sender for [`ChannelRequest`]s on this channel.
    #[inspect(skip)]
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Signaled when the host rescinds the offer.
    #[inspect(skip)]
    pub revoke_recv: mesh::OneshotReceiver<()>,
}
392
/// Connection-level requests to the client task, sent by [`VmbusClient`] and
/// [`VmbusClientAccess`].
#[derive(Debug)]
enum ClientRequest {
    /// Negotiate a version and fetch offers.
    Connect(Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>),
    /// Send Unload and wait for UnloadComplete.
    Unload(Rpc<(), ()>),
    /// Send ModifyConnection and wait for its response.
    Modify(Rpc<ModifyConnectionRequest, ConnectionState>),
    /// Send an hvsock connect request.
    HvsockConnect(Rpc<HvsockConnectRequest, Option<OfferInfo>>),
}
400
401impl std::fmt::Display for ClientRequest {
402 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403 let s = match self {
404 ClientRequest::Connect(..) => "Connect",
405 ClientRequest::Unload { .. } => "Unload",
406 ClientRequest::Modify(..) => "Modify",
407 ClientRequest::HvsockConnect(..) => "HvsockConnect",
408 };
409 fmt.pad(s)
410 }
411}
412
/// Task-level requests sent from [`VmbusClient`] to the client task.
enum TaskRequest {
    /// Forwarded from `VmbusClient`'s `Inspect` implementation.
    Inspect(inspect::Deferred),
    /// Sent by [`VmbusClient::save`].
    Save(Rpc<(), SavedState>),
    /// Sent by [`VmbusClient::restore`].
    Restore(Rpc<SavedState, Result<Option<ConnectResult>, RestoreError>>),
    /// Sent by [`VmbusClient::post_restore`].
    PostRestore(Rpc<(), ()>),
    /// Sent by [`VmbusClient::start`]. No reply is expected.
    Start,
    /// Sent by [`VmbusClient::stop`].
    Stop(Rpc<(), ()>),
}
421
/// Connection state machine of the client task.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ClientState {
    /// Not connected. This is the only state that accepts a connect request.
    Disconnected,
    /// InitiateContact has been sent; waiting for a VersionResponse.
    Connecting {
        /// Version requested in the outstanding InitiateContact.
        version: Version,
        #[inspect(skip)]
        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
    },
    /// All initial offers have been delivered. Later offers are forwarded
    /// through `offer_send`.
    Connected {
        /// Negotiated version and feature flags.
        version: VersionInfo,
        #[inspect(skip)]
        offer_send: mesh::Sender<OfferInfo>,
    },
    /// The version was accepted and RequestOffers was sent. Offers are
    /// collected here until AllOffersDelivered arrives.
    RequestingOffers {
        /// Negotiated version and feature flags.
        version: VersionInfo,
        #[inspect(skip)]
        rpc: Rpc<(), Result<ConnectResult, ConnectError>>,
        #[inspect(skip)]
        offers: Vec<OfferInfo>,
    },
    /// Unload has been sent; waiting for UnloadComplete.
    Disconnecting {
        /// Negotiated version, kept so incoming messages still parse with it.
        version: VersionInfo,
        #[inspect(skip)]
        rpc: Rpc<(), ()>,
    },
}
457
458impl ClientState {
459 fn get_version(&self) -> Option<VersionInfo> {
460 match self {
461 ClientState::Connected { version, .. } => Some(*version),
462 ClientState::RequestingOffers { version, .. } => Some(*version),
463 ClientState::Disconnecting { version, .. } => Some(*version),
464 ClientState::Disconnected | ClientState::Connecting { .. } => None,
465 }
466 }
467}
468
469impl std::fmt::Display for ClientState {
470 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471 let s = match self {
472 ClientState::Disconnected => "Disconnected",
473 ClientState::Connecting { .. } => "Connecting",
474 ClientState::Connected { .. } => "Connected",
475 ClientState::RequestingOffers { .. } => "RequestingOffers",
476 ClientState::Disconnecting { .. } => "Disconnecting",
477 };
478 fmt.pad(s)
479 }
480}
481
/// Parameters of a connect request. They are carried into the InitiateContact
/// message.
#[derive(Copy, Clone, Debug, Default)]
struct ConnectRequest {
    /// Sent as `target_message_vp` in InitiateContact.
    target_message_vp: u32,
    /// Monitor page GPAs. `None` is sent as the default value.
    monitor_page: Option<MonitorPageGpas>,
    /// Sent only in InitiateContact2 (Copper and later).
    client_id: Guid,
}
488
/// Parameters for [`VmbusClientAccess::modify`].
#[derive(Copy, Clone, Debug, Default)]
pub struct ModifyConnectionRequest {
    /// New monitor page GPAs. `None` is sent as the default value.
    pub monitor_page: Option<MonitorPageGpas>,
}
493
494impl From<ModifyConnectionRequest> for protocol::ModifyConnection {
495 fn from(value: ModifyConnectionRequest) -> Self {
496 let monitor_page = value.monitor_page.unwrap_or_default();
497
498 Self {
499 parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
500 child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
501 }
502 }
503}
504
/// Per-channel state tracked by the client task.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum ChannelState {
    /// Offered by the host and not open. Open requests are accepted only here.
    Offered,
    /// OpenChannel has been sent; waiting for OpenResult.
    Opening {
        /// Event flag allocated for redirected interrupts, if any.
        redirected_event_flag: Option<u16>,
        #[inspect(skip)]
        redirected_event: Option<Event>,
        #[inspect(skip)]
        rpc: FailableRpc<(), OpenOutput>,
    },
    /// Waiting for a `ChannelRequest::Restore`, which moves the channel to
    /// `Opened`. Presumably set when channels are recreated from saved state.
    Restored,
    /// Open on the host.
    Opened {
        /// Event flag used for redirected interrupts, if any.
        redirected_event_flag: Option<u16>,
        #[inspect(skip)]
        redirected_event: Option<Event>,
    },
    /// Rescinded by the host.
    Revoked,
}
532
533impl std::fmt::Display for ChannelState {
534 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535 let s = match self {
536 ChannelState::Opening { .. } => "Opening",
537 ChannelState::Offered => "Offered",
538 ChannelState::Opened { .. } => "Opened",
539 ChannelState::Restored => "Restored",
540 ChannelState::Revoked => "Revoked",
541 };
542 fmt.pad(s)
543 }
544}
545
/// Client-side bookkeeping for one offered channel.
#[derive(Debug, Inspect)]
struct Channel {
    /// The offer as received from the host.
    offer: protocol::OfferChannel,
    /// Taken and signaled when the host rescinds the offer.
    #[inspect(skip)]
    revoke_send: Option<mesh::OneshotSender<()>>,
    state: ChannelState,
    /// Pending ModifyChannel request, completed by ModifyChannelResponse.
    #[inspect(with = "|x| x.is_some()")]
    modify_response_send: Option<Rpc<(), i32>>,
    /// GPADLs registered on this channel, keyed by GPADL ID.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    gpadls: HashMap<GpadlId, GpadlState>,
    /// Initially `false`. NOTE(review): presumably set once the channel's
    /// owner has dropped its handles — confirm in `try_release`.
    is_client_released: bool,
    /// Shared with the channel's guest-to-host interrupt. Updated on open and
    /// restore.
    connection_id: Arc<AtomicU32>,
}
560
561impl Channel {
562 fn pending_request(&self) -> Option<&'static str> {
563 if self.modify_response_send.is_some() {
564 return Some("modify");
565 }
566 self.gpadls.iter().find_map(|(_, gpadl)| match gpadl {
567 GpadlState::Offered(_) => Some("creating gpadl"),
568 GpadlState::Created => None,
569 GpadlState::TearingDown { .. } => Some("tearing down gpadl"),
570 })
571 }
572}
573
/// State owned by the spawned VMBus client task.
#[derive(Inspect)]
struct ClientTask {
    #[inspect(flatten)]
    inner: ClientTaskInner,
    /// Channels known to the client, keyed by channel ID.
    channels: ChannelList,
    /// Connection state machine.
    state: ClientState,
    /// Outstanding hvsock connect requests.
    hvsock_tracker: hvsock::HvsockRequestTracker,
    /// Initially `false`. NOTE(review): presumably toggled by
    /// `TaskRequest::Start` / `Stop` — confirm in `run`.
    running: bool,
    /// The in-flight ModifyConnection request. At most one is allowed at a time.
    #[inspect(with = "|x| x.is_some()")]
    modify_request: Option<Rpc<ModifyConnectionRequest, ConnectionState>>,
    /// Incoming messages from the host.
    #[inspect(skip)]
    msg_source: Box<dyn VmbusMessageSource>,
    /// Task-level requests from [`VmbusClient`].
    #[inspect(skip)]
    task_recv: mesh::Receiver<TaskRequest>,
    /// Connection-level requests from [`VmbusClient`] / [`VmbusClientAccess`].
    #[inspect(skip)]
    client_request_recv: mesh::Receiver<ClientRequest>,
}
591
592impl ClientTask {
593 fn handle_initiate_contact(
594 &mut self,
595 rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
596 version: Version,
597 ) {
598 let ClientState::Disconnected = self.state else {
599 tracing::warn!(client_state = %self.state, "invalid client state for InitiateContact");
600 rpc.complete(Err(ConnectError::InvalidState));
601 return;
602 };
603 let feature_flags = if version >= Version::Copper {
604 SUPPORTED_FEATURE_FLAGS
605 } else {
606 FeatureFlags::new()
607 };
608
609 let request = rpc.input();
610
611 tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting");
612 let target_info = protocol::TargetInfo::new()
613 .with_sint(SINT)
614 .with_vtl(VTL)
615 .with_feature_flags(feature_flags.into());
616 let monitor_page = request.monitor_page.unwrap_or_default();
617 let msg = protocol::InitiateContact2 {
618 initiate_contact: protocol::InitiateContact {
619 version_requested: version as u32,
620 target_message_vp: request.target_message_vp,
621 interrupt_page_or_target_info: target_info.into(),
622 parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
623 child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
624 },
625 client_id: request.client_id,
626 };
627
628 self.state = ClientState::Connecting { version, rpc };
629 if version < Version::Copper {
630 self.inner.messages.send(&msg.initiate_contact)
631 } else {
632 self.inner.messages.send(&msg);
633 }
634 }
635
636 fn handle_unload(&mut self, rpc: Rpc<(), ()>) {
637 tracing::debug!(%self.state, "VmBus client disconnecting");
638 self.state = ClientState::Disconnecting {
639 version: self.state.get_version().expect("invalid state for unload"),
640 rpc,
641 };
642
643 self.inner.messages.send(&protocol::Unload {});
644 }
645
646 fn handle_modify(&mut self, request: Rpc<ModifyConnectionRequest, ConnectionState>) {
647 if !matches!(self.state, ClientState::Connected { version, .. }
648 if version.feature_flags.modify_connection())
649 {
650 tracing::warn!("ModifyConnection not supported");
651 request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
652 return;
653 }
654
655 if self.modify_request.is_some() {
656 tracing::warn!("Duplicate ModifyConnection request");
657 request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
658 return;
659 }
660
661 let message = protocol::ModifyConnection::from(*request.input());
662 self.modify_request = Some(request);
663 self.inner.messages.send(&message);
664 }
665
666 fn handle_tl_connect(&mut self, rpc: Rpc<HvsockConnectRequest, Option<OfferInfo>>) {
667 let message = protocol::TlConnectRequest2::from(*rpc.input());
671 self.hvsock_tracker.add_request(rpc);
672 self.inner.messages.send(&message);
673 }
674
675 fn handle_client_request(&mut self, request: ClientRequest) {
676 match request {
677 ClientRequest::Connect(rpc) => {
678 self.handle_initiate_contact(rpc, *SUPPORTED_VERSIONS.last().unwrap());
679 }
680 ClientRequest::Unload(rpc) => {
681 self.handle_unload(rpc);
682 }
683 ClientRequest::Modify(request) => self.handle_modify(request),
684 ClientRequest::HvsockConnect(request) => self.handle_tl_connect(request),
685 }
686 }
687
688 fn handle_version_response(&mut self, msg: protocol::VersionResponse2) {
689 let old_state = std::mem::replace(&mut self.state, ClientState::Disconnected);
690 let ClientState::Connecting { version, rpc } = old_state else {
691 self.state = old_state;
692 tracing::warn!(
693 client_state = %self.state,
694 "invalid client state to handle VersionResponse"
695 );
696 return;
697 };
698 if msg.version_response.version_supported > 0 {
699 if msg.version_response.connection_state != ConnectionState::SUCCESSFUL {
700 rpc.complete(Err(ConnectError::FailedToConnect(
701 msg.version_response.connection_state,
702 )));
703 return;
704 }
705
706 let feature_flags = if version >= Version::Copper {
707 FeatureFlags::from(msg.supported_features)
708 } else {
709 FeatureFlags::new()
710 };
711
712 let version = VersionInfo {
713 version,
714 feature_flags,
715 };
716
717 self.inner.messages.send(&protocol::RequestOffers {});
718 self.state = ClientState::RequestingOffers {
719 version,
720 rpc: rpc.split().1,
721 offers: Vec::new(),
722 };
723 tracing::info!(?version, "VmBus client connected, requesting offers");
724 } else {
725 let index = SUPPORTED_VERSIONS
726 .iter()
727 .position(|v| *v == version)
728 .unwrap();
729
730 if index == 0 {
731 rpc.complete(Err(ConnectError::NoSupportedVersions));
732 return;
733 }
734 let next_version = SUPPORTED_VERSIONS[index - 1];
735 tracing::debug!(
736 version = version as u32,
737 next_version = next_version as u32,
738 "Unsupported version, retrying"
739 );
740 self.handle_initiate_contact(rpc, next_version);
741 }
742 }
743
744 fn create_channel(&mut self, offer: protocol::OfferChannel) -> Result<OfferInfo> {
745 self.create_channel_core(offer, ChannelState::Offered)
746 }
747
748 fn create_channel_core(
749 &mut self,
750 offer: protocol::OfferChannel,
751 state: ChannelState,
752 ) -> Result<OfferInfo> {
753 if self.channels.0.contains_key(&offer.channel_id) {
754 anyhow::bail!("channel {:?} exists", offer.channel_id);
755 }
756 let (request_send, request_recv) = mesh::channel();
757 let (revoke_send, revoke_recv) = mesh::oneshot();
758
759 let connection_id = Arc::new(AtomicU32::new(0));
760 self.channels.0.insert(
761 offer.channel_id,
762 Channel {
763 revoke_send: Some(revoke_send),
764 offer,
765 state,
766 modify_response_send: None,
767 gpadls: HashMap::new(),
768 is_client_released: false,
769 connection_id: connection_id.clone(),
770 },
771 );
772
773 self.inner
774 .channel_requests
775 .push(TaggedStream::new(offer.channel_id, request_recv));
776
777 Ok(OfferInfo {
778 offer,
779 guest_to_host_interrupt: self.inner.synic.guest_to_host_interrupt(connection_id),
780 revoke_recv,
781 request_send,
782 })
783 }
784
785 fn handle_offer(&mut self, offer: protocol::OfferChannel) {
786 let offer_info = self
787 .create_channel(offer)
788 .expect("channel should not exist");
789
790 tracing::info!(
791 state = %self.state,
792 channel_id = offer.channel_id.0,
793 interface_id = %offer.interface_id,
794 instance_id = %offer.instance_id,
795 subchannel_index = offer.subchannel_index,
796 "received offer");
797
798 if let Some(offer) = self.hvsock_tracker.check_offer(&offer_info.offer) {
799 offer.complete(Some(offer_info));
800 } else {
801 match &mut self.state {
802 ClientState::Connected { offer_send, .. } => {
803 offer_send.send(offer_info);
804 }
805 ClientState::RequestingOffers { offers, .. } => {
806 offers.push(offer_info);
807 }
808 state => unreachable!("invalid client state for OfferChannel: {state}"),
809 }
810 }
811 }
812
813 fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) -> TriedRelease {
814 let mut channel = self.channels.get_mut(rescind.channel_id);
815 tracing::info!(
816 state = %self.state,
817 channel_id = rescind.channel_id.0,
818 key = %OfferKey::from(&channel.offer),
819 "received rescind"
820 );
821 let event_flag = match std::mem::replace(&mut channel.state, ChannelState::Revoked) {
822 ChannelState::Offered => None,
823 ChannelState::Opening {
824 redirected_event_flag,
825 redirected_event: _,
826 rpc,
827 } => {
828 rpc.fail(anyhow::anyhow!("channel revoked"));
829 redirected_event_flag
830 }
831 ChannelState::Restored => None,
832 ChannelState::Opened {
833 redirected_event_flag,
834 redirected_event: _,
835 } => redirected_event_flag,
836 ChannelState::Revoked => {
837 panic!("channel id {:?} already revoked", rescind.channel_id);
838 }
839 };
840 if let Some(event_flag) = event_flag {
841 self.inner.synic.free_event_flag(event_flag);
842 }
843
844 channel.revoke_send.take().unwrap().send(());
846
847 channel.try_release(&mut self.inner.messages)
848 }
849
850 fn handle_offers_delivered(&mut self) {
851 match std::mem::replace(&mut self.state, ClientState::Disconnected) {
852 ClientState::RequestingOffers {
853 version,
854 rpc,
855 offers,
856 } => {
857 tracing::info!(version = ?version, "VmBus client connected, offers delivered");
858 let (offer_send, offer_recv) = mesh::channel();
859 self.state = ClientState::Connected {
860 version,
861 offer_send,
862 };
863 rpc.complete(Ok(ConnectResult {
864 version,
865 offers,
866 offer_recv,
867 }));
868 }
869 state => {
870 tracing::warn!(client_state = %state, "invalid client state for OffersDelivered");
871 self.state = state;
872 }
873 }
874 }
875
876 fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) -> TriedRelease {
877 let mut channel = self.channels.get_mut(request.channel_id);
878 let Some(gpadl_state) = channel.gpadls.get_mut(&request.gpadl_id) else {
879 panic!("GpadlCreated for unknown gpadl {:#x}", request.gpadl_id.0);
880 };
881
882 let rpc = match std::mem::replace(gpadl_state, GpadlState::Created) {
883 GpadlState::Offered(rpc) => rpc,
884 old_state => {
885 panic!(
886 "invalid state {old_state:?} for gpadl {:#x}:{:#x}",
887 request.channel_id.0, request.gpadl_id.0
888 );
889 }
890 };
891
892 let gpadl_created = request.status == protocol::STATUS_SUCCESS;
893 if gpadl_created {
894 rpc.complete(Ok(()));
895 } else {
896 channel.gpadls.remove(&request.gpadl_id).unwrap();
897 rpc.fail(anyhow::anyhow!(
898 "gpadl creation failed: {:#x}",
899 request.status
900 ));
901 };
902 channel.try_release(&mut self.inner.messages)
903 }
904
905 fn handle_open_result(&mut self, result: protocol::OpenResult) {
906 let mut channel = self.channels.get_mut(result.channel_id);
907 tracing::debug!(
908 channel_id = result.channel_id.0,
909 key = %OfferKey::from(&channel.offer),
910 result = result.status,
911 "received open result"
912 );
913
914 let channel_opened = result.status == protocol::STATUS_SUCCESS as u32;
915 let old_state = std::mem::replace(&mut channel.state, ChannelState::Offered);
916 let ChannelState::Opening {
917 redirected_event_flag,
918 redirected_event,
919 rpc,
920 } = old_state
921 else {
922 tracing::warn!(
923 key = %OfferKey::from(&channel.offer),
924 old_state = ?channel.state,
925 channel_opened,
926 "invalid state for open result"
927 );
928 channel.state = old_state;
929 return;
930 };
931
932 if !channel_opened {
933 if let Some(event_flag) = redirected_event_flag {
934 self.inner.synic.free_event_flag(event_flag);
935 }
936 rpc.fail(anyhow::anyhow!("open failed: {:#x}", result.status));
937 return;
938 }
939
940 channel.state = ChannelState::Opened {
941 redirected_event_flag,
942 redirected_event,
943 };
944
945 rpc.complete(Ok(OpenOutput {
946 redirected_event_flag,
947 }));
948 }
949
950 fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) -> TriedRelease {
951 let Some(channel_id) = self.inner.teardown_gpadls.remove(&request.gpadl_id) else {
952 panic!("gpadl {:#x} not in teardown list", request.gpadl_id.0);
953 };
954
955 let mut channel = self.channels.get_mut(channel_id);
956 tracing::debug!(
957 gpadl_id = request.gpadl_id.0,
958 channel_id = channel_id.0,
959 key = %OfferKey::from(&channel.offer),
960 "Received GpadlTorndown"
961 );
962
963 let gpadl_state = channel
964 .gpadls
965 .remove(&request.gpadl_id)
966 .expect("gpadl validated above");
967
968 let GpadlState::TearingDown { rpcs } = gpadl_state else {
969 panic!("gpadl should be tearing down if in teardown list, state = {gpadl_state:?}");
970 };
971
972 for rpc in rpcs {
973 rpc.complete(());
974 }
975 channel.try_release(&mut self.inner.messages)
976 }
977
978 fn handle_unload_complete(&mut self) {
979 match std::mem::replace(&mut self.state, ClientState::Disconnected) {
980 ClientState::Disconnecting { version: _, rpc } => {
981 tracing::info!("VmBus client disconnected");
982 rpc.complete(());
983 }
984 state => {
985 tracing::warn!(client_state = %state, "invalid client state for UnloadComplete");
986 }
987 }
988 }
989
990 fn handle_modify_complete(&mut self, response: protocol::ModifyConnectionResponse) {
991 if let Some(request) = self.modify_request.take() {
992 request.complete(response.connection_state)
993 } else {
994 tracing::warn!("Unexpected modify complete request");
995 }
996 }
997
998 fn handle_modify_channel_response(
999 &mut self,
1000 response: protocol::ModifyChannelResponse,
1001 ) -> TriedRelease {
1002 let mut channel = self.channels.get_mut(response.channel_id);
1003 let Some(sender) = channel.modify_response_send.take() else {
1004 panic!(
1005 "unexpected modify channel response for channel {:#x}",
1006 response.channel_id.0
1007 );
1008 };
1009
1010 sender.complete(response.status);
1011 channel.try_release(&mut self.inner.messages)
1012 }
1013
1014 fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) {
1015 if let Some(rpc) = self.hvsock_tracker.check_result(&response) {
1016 rpc.complete(None);
1017 }
1018 }
1019
    /// Parses one synic message from the host and dispatches it to the
    /// matching handler.
    ///
    /// Returns `false` only for `PauseResponse`, and `true` for every other
    /// handled message. NOTE(review): presumably `false` tells the caller that
    /// the pause handshake has completed — confirm in `run`.
    ///
    /// # Panics
    /// - If `data` does not parse for the current version. NOTE(review): this
    ///   is host-supplied input.
    /// - On `CloseReservedChannelResponse`, which is not yet supported.
    /// - On any message type that only a server should receive.
    fn handle_synic_message(&mut self, data: &[u8]) -> bool {
        // Parsing depends on the negotiated version, if there is one yet.
        let msg = Message::parse(data, self.state.get_version()).unwrap();
        tracing::trace!(?msg, "received client message from synic");

        // The `TriedRelease` values returned by some handlers are discarded here.
        match msg {
            Message::VersionResponse3(version_response, ..) => {
                // Only the embedded VersionResponse2 is used; the additional
                // VersionResponse3 fields are ignored.
                self.handle_version_response(version_response.version_response2);
            }
            Message::VersionResponse2(version_response, ..) => {
                self.handle_version_response(version_response);
            }
            Message::VersionResponse(version_response, ..) => {
                self.handle_version_response(version_response.into());
            }
            Message::OfferChannel(offer, ..) => {
                self.handle_offer(offer);
            }
            Message::AllOffersDelivered(..) => {
                self.handle_offers_delivered();
            }
            Message::UnloadComplete(..) => {
                self.handle_unload_complete();
            }
            Message::ModifyConnectionResponse(response, ..) => {
                self.handle_modify_complete(response);
            }
            Message::GpadlCreated(gpadl, ..) => {
                self.handle_gpadl_created(gpadl);
            }
            Message::OpenResult(result, ..) => {
                self.handle_open_result(result);
            }
            Message::GpadlTorndown(gpadl, ..) => {
                self.handle_gpadl_torndown(gpadl);
            }
            Message::RescindChannelOffer(rescind, ..) => {
                self.handle_rescind(rescind);
            }
            Message::ModifyChannelResponse(response, ..) => {
                self.handle_modify_channel_response(response);
            }
            Message::TlConnectResult(response, ..) => self.handle_tl_connect_result(response),
            Message::CloseReservedChannelResponse(..) => {
                // Reserved channels are not supported by this client.
                todo!("Unsupported message {msg:?}")
            }
            Message::PauseResponse(..) => {
                return false;
            }
            // Messages the client sends to the server; receiving one is a protocol error.
            Message::RequestOffers(..)
            | Message::OpenChannel2(..)
            | Message::OpenChannel(..)
            | Message::CloseChannel(..)
            | Message::GpadlHeader(..)
            | Message::GpadlBody(..)
            | Message::GpadlTeardown(..)
            | Message::RelIdReleased(..)
            | Message::InitiateContact(..)
            | Message::InitiateContact2(..)
            | Message::Unload(..)
            | Message::OpenReservedChannel(..)
            | Message::CloseReservedChannel(..)
            | Message::TlConnectRequest2(..)
            | Message::TlConnectRequest(..)
            | Message::ModifyChannel(..)
            | Message::ModifyConnection(..)
            | Message::Pause(..)
            | Message::Resume(..) => {
                unreachable!("Client received server message {msg:?}");
            }
        }
        true
    }
1099
1100 fn handle_open_channel(
1101 &mut self,
1102 channel_id: ChannelId,
1103 rpc: FailableRpc<OpenRequest, OpenOutput>,
1104 ) {
1105 let mut channel = self.channels.get_mut(channel_id);
1106 match &channel.state {
1107 ChannelState::Offered => {}
1108 ChannelState::Revoked => {
1109 rpc.fail(anyhow::anyhow!("channel revoked"));
1110 return;
1111 }
1112 state => {
1113 rpc.fail(anyhow::anyhow!("invalid channel state: {}", state));
1114 return;
1115 }
1116 }
1117
1118 tracing::info!(
1119 channel_id = channel_id.0,
1120 key = %OfferKey::from(&channel.offer),
1121 "opening channel on host"
1122 );
1123
1124 let (request, rpc) = rpc.split();
1125 let open_data = &request.open_data;
1126
1127 let supports_interrupt_redirection =
1128 if let ClientState::Connected { version, .. } = self.state {
1129 version.feature_flags.guest_specified_signal_parameters()
1130 || version.feature_flags.channel_interrupt_redirection()
1131 } else {
1132 false
1133 };
1134
1135 if !supports_interrupt_redirection && open_data.event_flag != channel_id.0 as u16 {
1136 rpc.fail(anyhow::anyhow!(
1137 "host does not support specifying the event flag"
1138 ));
1139 return;
1140 }
1141
1142 let open_channel = protocol::OpenChannel {
1143 channel_id,
1144 open_id: 0,
1145 ring_buffer_gpadl_id: open_data.ring_gpadl_id,
1146 target_vp: open_data.target_vp,
1147 downstream_ring_buffer_page_offset: open_data.ring_offset,
1148 user_data: open_data.user_data,
1149 };
1150
1151 let connection_id = if request.use_vtl2_connection_id {
1152 if !supports_interrupt_redirection {
1153 rpc.fail(anyhow::anyhow!(
1154 "host does not support specfiying the connection ID"
1155 ));
1156 return;
1157 }
1158 protocol::ConnectionId::new(channel_id.0, 2.try_into().unwrap(), 7).0
1159 } else {
1160 open_data.connection_id
1161 };
1162
1163 let mut flags = OpenChannelFlags::new();
1166 let event_flag = if let Some(event) = &request.incoming_event {
1167 if !supports_interrupt_redirection {
1168 rpc.fail(anyhow::anyhow!(
1169 "host does not support redirecting interrupts"
1170 ));
1171 return;
1172 }
1173
1174 flags.set_redirect_interrupt(true);
1175 match self.inner.synic.allocate_event_flag(event) {
1176 Ok(flag) => flag,
1177 Err(err) => {
1178 rpc.fail(err.context("failed to allocate event flag"));
1179 return;
1180 }
1181 }
1182 } else {
1183 open_data.event_flag
1184 };
1185
1186 if supports_interrupt_redirection {
1187 self.inner.messages.send(&protocol::OpenChannel2 {
1188 open_channel,
1189 connection_id,
1190 event_flag,
1191 flags,
1192 });
1193 } else {
1194 self.inner.messages.send(&open_channel);
1195 }
1196
1197 channel
1198 .connection_id
1199 .store(connection_id, Ordering::Release);
1200 channel.state = ChannelState::Opening {
1201 redirected_event_flag: (request.incoming_event.is_some()).then_some(event_flag),
1202 redirected_event: request.incoming_event,
1203 rpc,
1204 }
1205 }
1206
1207 fn handle_restore_channel(
1208 &mut self,
1209 channel_id: ChannelId,
1210 request: RestoreRequest,
1211 ) -> Result<OpenOutput> {
1212 let mut channel = self.channels.get_mut(channel_id);
1213 if !matches!(channel.state, ChannelState::Restored) {
1214 anyhow::bail!("invalid channel state: {}", channel.state);
1215 }
1216
1217 if request.incoming_event.is_some() != request.redirected_event_flag.is_some() {
1218 anyhow::bail!("incoming event and redirected event flag must both be set or unset");
1219 }
1220
1221 if let Some((flag, event)) = request
1222 .redirected_event_flag
1223 .zip(request.incoming_event.as_ref())
1224 {
1225 self.inner.synic.restore_event_flag(flag, event)?;
1226 }
1227
1228 channel
1229 .connection_id
1230 .store(request.connection_id, Ordering::Release);
1231 channel.state = ChannelState::Opened {
1232 redirected_event_flag: request.redirected_event_flag,
1233 redirected_event: request.incoming_event,
1234 };
1235 Ok(OpenOutput {
1236 redirected_event_flag: request.redirected_event_flag,
1237 })
1238 }
1239
1240 fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: FailableRpc<GpadlRequest, ()>) {
1241 let (request, rpc) = rpc.split();
1242 let mut channel = self.channels.get_mut(channel_id);
1243 if channel
1244 .gpadls
1245 .insert(request.id, GpadlState::Offered(rpc))
1246 .is_some()
1247 {
1248 panic!(
1249 "duplicate gpadl ID {:?} for channel {:?}.",
1250 request.id, channel_id
1251 );
1252 }
1253
1254 tracing::trace!(
1255 channel_id = channel_id.0,
1256 key = %OfferKey::from(&channel.offer),
1257 gpadl_id = request.id.0,
1258 count = request.count,
1259 len = request.buf.len(),
1260 "received gpadl request"
1261 );
1262
1263 let (first, remaining) = if request.buf.len() > protocol::GpadlHeader::MAX_DATA_VALUES {
1265 request.buf.split_at(protocol::GpadlHeader::MAX_DATA_VALUES)
1266 } else {
1267 (request.buf.as_slice(), [].as_slice())
1268 };
1269
1270 let message = protocol::GpadlHeader {
1271 channel_id,
1272 gpadl_id: request.id,
1273 len: (request.buf.len() * size_of::<u64>())
1274 .try_into()
1275 .expect("Too many GPA values"),
1276 count: request.count,
1277 };
1278
1279 self.inner
1280 .messages
1281 .send_with_data(&message, first.as_bytes());
1282
1283 let message = protocol::GpadlBody {
1285 rsvd: 0,
1286 gpadl_id: request.id,
1287 };
1288 for chunk in remaining.chunks(protocol::GpadlBody::MAX_DATA_VALUES) {
1289 self.inner
1290 .messages
1291 .send_with_data(&message, chunk.as_bytes());
1292 }
1293 }
1294
1295 fn handle_gpadl_teardown(&mut self, channel_id: ChannelId, rpc: Rpc<GpadlId, ()>) {
1296 let (gpadl_id, rpc) = rpc.split();
1297 let mut channel = self.channels.get_mut(channel_id);
1298 let Some(gpadl_state) = channel.gpadls.get_mut(&gpadl_id) else {
1299 tracing::warn!(
1300 gpadl_id = gpadl_id.0,
1301 channel_id = channel_id.0,
1302 key = %OfferKey::from(&channel.offer),
1303 "Gpadl teardown for unknown gpadl or revoked channel"
1304 );
1305 return;
1306 };
1307
1308 match gpadl_state {
1309 GpadlState::Offered(_) => {
1310 tracing::warn!(
1311 gpadl_id = gpadl_id.0,
1312 channel_id = channel_id.0,
1313 key = %OfferKey::from(&channel.offer),
1314 "gpadl teardown for offered gpadl"
1315 );
1316 }
1317 GpadlState::Created => {
1318 *gpadl_state = GpadlState::TearingDown { rpcs: vec![rpc] };
1319 assert!(
1323 self.inner
1324 .teardown_gpadls
1325 .insert(gpadl_id, channel_id)
1326 .is_none(),
1327 "Gpadl state validated above"
1328 );
1329
1330 self.inner.messages.send(&protocol::GpadlTeardown {
1331 channel_id,
1332 gpadl_id,
1333 });
1334 }
1335 GpadlState::TearingDown { rpcs } => {
1336 rpcs.push(rpc);
1337 }
1338 }
1339 }
1340
1341 fn handle_close_channel(&mut self, channel_id: ChannelId) {
1342 let mut channel = self.channels.get_mut(channel_id);
1343 self.inner.close_channel(channel_id, &mut channel);
1344 }
1345
1346 fn handle_modify_channel(&mut self, channel_id: ChannelId, rpc: Rpc<ModifyRequest, i32>) {
1347 assert!(self.check_version(Version::Iron));
1351 let mut channel = self.channels.get_mut(channel_id);
1352 if channel.modify_response_send.is_some() {
1353 panic!("duplicate channel modify request {channel_id:?}");
1354 }
1355
1356 let (request, response) = rpc.split();
1357 channel.modify_response_send = Some(response);
1358 let payload = match request {
1359 ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel {
1360 channel_id,
1361 target_vp,
1362 },
1363 };
1364
1365 self.inner.messages.send(&payload);
1366 }
1367
1368 fn handle_channel_request(&mut self, channel_id: ChannelId, request: ChannelRequest) {
1369 match request {
1370 ChannelRequest::Open(rpc) => self.handle_open_channel(channel_id, rpc),
1371 ChannelRequest::Restore(rpc) => {
1372 rpc.handle_failable_sync(|request| self.handle_restore_channel(channel_id, request))
1373 }
1374 ChannelRequest::Gpadl(req) => self.handle_gpadl(channel_id, req),
1375 ChannelRequest::TeardownGpadl(req) => self.handle_gpadl_teardown(channel_id, req),
1376 ChannelRequest::Close(req) => {
1377 req.handle_sync(|()| self.handle_close_channel(channel_id))
1378 }
1379 ChannelRequest::Modify(req) => self.handle_modify_channel(channel_id, req),
1380 }
1381 }
1382
1383 async fn handle_task(&mut self, task: TaskRequest) {
1384 match task {
1385 TaskRequest::Inspect(deferred) => {
1386 deferred.inspect(&*self);
1387 }
1388 TaskRequest::Save(rpc) => rpc.handle_sync(|()| self.handle_save()),
1389 TaskRequest::Restore(rpc) => {
1390 rpc.handle_sync(|saved_state| self.handle_restore(saved_state))
1391 }
1392 TaskRequest::PostRestore(rpc) => rpc.handle_sync(|()| self.handle_post_restore()),
1393 TaskRequest::Start => self.handle_start(),
1394 TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
1395 }
1396 }
1397
1398 fn handle_device_removal(&mut self, channel_id: ChannelId) -> TriedRelease {
1400 let mut channel = self.channels.get_mut(channel_id);
1401 channel.is_client_released = true;
1402 if let ChannelState::Opened { .. } = channel.state {
1404 tracing::warn!(
1405 channel_id = channel_id.0,
1406 key = %OfferKey::from(&channel.offer),
1407 "Channel dropped without closing first"
1408 );
1409 self.inner.close_channel(channel_id, &mut channel);
1410 }
1411 channel.try_release(&mut self.inner.messages)
1412 }
1413
1414 fn check_version(&self, version: Version) -> bool {
1416 matches!(self.state, ClientState::Connected { version: v, .. } if v.version >= version)
1417 }
1418
1419 fn handle_start(&mut self) {
1420 assert!(!self.running);
1421 self.msg_source.resume_message_stream();
1422 self.inner.messages.resume();
1423 self.running = true;
1424 }
1425
/// Stops the task, draining host messages until the stream reaches EOF.
///
/// Runs in a loop until a fully drained pause leaves no revoked channel with
/// a pending request:
/// 1. Keep processing messages while any revoked channel still waits on a
///    request. NOTE(review): presumably so saved state never has to describe
///    such channels; confirm.
/// 2. Pause. Hosts that support it get the in-band `Pause` message.
///    Otherwise the incoming stream is paused locally and the outgoing queue
///    is force-paused.
/// 3. Process messages until `process_next_message` reports EOF (size 0).
/// 4. If a pending request appeared in the meantime, undo the pause and
///    repeat.
///
/// Panics if the task is not running.
async fn handle_stop(&mut self) {
    assert!(self.running);

    loop {
        // Step 1. EOF is not expected here because nothing is paused yet.
        while let Some((id, request)) = self.channels.revoked_channel_with_pending_request() {
            tracelimit::info_ratelimited!(
                channel_id = id.0,
                request,
                "waiting for responses for channel"
            );
            assert!(self.process_next_message().await);
        }

        // Step 2.
        if self.can_pause_resume() {
            self.inner.messages.pause();
        } else {
            // No in-band pause: stop the incoming stream locally and stop
            // posting without sending a Pause message.
            self.msg_source.pause_message_stream();
            self.inner.messages.force_pause();
        }

        // Step 3: drain until EOF.
        while self.process_next_message().await {}

        // Step 4.
        if self
            .channels
            .revoked_channel_with_pending_request()
            .is_none()
        {
            break;
        }
        if !self.can_pause_resume() {
            self.msg_source.resume_message_stream();
        }
        self.inner.messages.resume();
    }

    tracing::debug!("messages drained");
    self.running = false;
}
1475
1476 async fn process_next_message(&mut self) -> bool {
1477 let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
1478 let recv = self.msg_source.recv(&mut buf);
1479 let flush = async {
1482 self.inner.messages.flush_messages().await;
1483 std::future::pending().await
1484 };
1485 let size = (recv, flush)
1486 .race()
1487 .await
1488 .expect("Fatal error reading messages from synic");
1489 if size == 0 {
1490 return false;
1491 }
1492 self.handle_synic_message(&buf[..size])
1493 }
1494
1495 fn can_pause_resume(&self) -> bool {
1504 if let ClientState::Connected { version, .. } = self.state {
1505 version.feature_flags.pause_resume()
1506 } else {
1507 false
1508 }
1509 }
1510
/// Main event loop of the client task.
///
/// Each iteration builds the set of futures that may currently make
/// progress, then handles whichever completes first.
///
/// While running, it:
/// - reads host messages;
/// - either flushes the outgoing queue (if non-empty) or accepts new client
///   and channel requests (if empty);
/// - always accepts task-control requests.
///
/// The loop exits when the task or client request stream ends, or when no
/// branch can make progress.
async fn run(&mut self) {
    let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
    loop {
        // Host messages are only read while running.
        let mut message_recv =
            OptionFuture::from(self.running.then(|| self.msg_source.recv(&mut buf).fuse()));

        // A non-empty outgoing queue means the host is not accepting messages
        // as fast as they are produced. In that case only flush, and stop
        // taking new client/channel requests that would grow the queue.
        // Incoming host messages are still processed above.
        let host_backed_up = !self.inner.messages.is_empty();
        let flush_messages = OptionFuture::from(
            (self.running && host_backed_up)
                .then(|| self.inner.messages.flush_messages().fuse()),
        );

        let mut client_request_recv = OptionFuture::from(
            (self.running && !host_backed_up).then(|| self.client_request_recv.next()),
        );

        let mut channel_requests = OptionFuture::from(
            (self.running && !host_backed_up)
                .then(|| self.inner.channel_requests.select_next_some()),
        );

        // Branches built from `None` are meant to be disabled. This relies on
        // `OptionFuture` reporting itself terminated when empty (see the
        // futures docs). That is also why the arms below can `unwrap`.
        futures::select! {
            _r = pin!(flush_messages) => {}
            r = self.task_recv.next() => {
                if let Some(task) = r {
                    self.handle_task(task).await;
                } else {
                    break;
                }
            }
            r = client_request_recv => {
                // `Some(None)` means the client request stream ended.
                if let Some(Some(request)) = r {
                    self.handle_client_request(request);
                } else {
                    break;
                }
            }
            r = channel_requests => {
                match r.unwrap() {
                    (id, Some(request)) => self.handle_channel_request(id, request),
                    // The channel's request stream ended: the client dropped it.
                    (id, _) => {
                        self.handle_device_removal(id);
                    }
                }
            }
            r = message_recv => {
                match r.unwrap() {
                    Ok(size) => {
                        // EOF is only expected while stopping (see handle_stop),
                        // which this loop does not do.
                        if size == 0 {
                            panic!("Unexpected end of file reading messages from synic.");
                        }

                        self.handle_synic_message(&buf[..size]);
                    }
                    Err(err) => {
                        panic!("Error reading messages from synic: {err:?}");
                    }
                }
            }
            complete => break,
        }
    }
}
1583}
1584
1585impl ClientTaskInner {
1586 fn close_channel(&mut self, channel_id: ChannelId, channel: &mut Channel) {
1587 if let ChannelState::Opened {
1588 redirected_event_flag,
1589 ..
1590 } = channel.state
1591 {
1592 if let Some(flag) = redirected_event_flag {
1593 self.synic.free_event_flag(flag);
1594 }
1595 tracing::info!(
1596 channel_id = channel_id.0,
1597 key = %OfferKey::from(&channel.offer),
1598 "closing channel on host"
1599 );
1600
1601 self.messages.send(&protocol::CloseChannel { channel_id });
1602 channel.state = ChannelState::Offered;
1603 channel.connection_id.store(0, Ordering::Release);
1604 } else {
1605 tracing::warn!(
1606 channel_id = channel_id.0,
1607 key = %OfferKey::from(&channel.offer),
1608 channel_state = %channel.state,
1609 "invalid channel state for close channel"
1610 );
1611 }
1612 }
1613}
1614
/// Lifecycle of one GPADL registered on a channel.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum GpadlState {
    /// The GPADL messages have been sent to the host. Holds the client's RPC,
    /// presumably completed when the host responds (handler not visible here).
    Offered(#[inspect(skip)] FailableRpc<(), ()>),
    /// The GPADL is established. Only GPADLs in this state can start a
    /// teardown.
    Created,
    /// `GpadlTeardown` has been sent. Every teardown RPC received for this
    /// GPADL in the meantime waits here.
    TearingDown {
        #[inspect(skip)]
        rpcs: Vec<Rpc<(), ()>>,
    },
}
1627}
1628
1629#[derive(Inspect)]
1630struct OutgoingMessages {
1631 #[inspect(skip)]
1632 poster: Box<dyn PollPostMessage>,
1633 #[inspect(with = "|x| x.len()")]
1634 queued: VecDeque<OutgoingMessage>,
1635 state: OutgoingMessageState,
1636}
1637
1638#[derive(Inspect, PartialEq, Eq, Debug)]
1639enum OutgoingMessageState {
1640 Running,
1641 SendingPauseMessage,
1642 Paused,
1643}
1644
1645impl OutgoingMessages {
1646 fn send<T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout>(
1647 &mut self,
1648 msg: &T,
1649 ) {
1650 self.send_with_data(msg, &[])
1651 }
1652
1653 fn send_with_data<
1654 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
1655 >(
1656 &mut self,
1657 msg: &T,
1658 data: &[u8],
1659 ) {
1660 tracing::trace!(typ = ?T::MESSAGE_TYPE, "Sending message to host");
1661 let msg = OutgoingMessage::with_data(msg, data);
1662 if self.queued.is_empty() && self.state == OutgoingMessageState::Running {
1663 let r = self.poster.poll_post_message(
1664 &mut Context::from_waker(std::task::Waker::noop()),
1665 protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
1666 1,
1667 msg.data(),
1668 );
1669 if let Poll::Ready(()) = r {
1670 return;
1671 }
1672 }
1673 tracing::trace!("queueing message");
1674 self.queued.push_back(msg);
1675 }
1676
1677 async fn flush_messages(&mut self) {
1678 let mut send = async |msg: &OutgoingMessage| {
1679 poll_fn(|cx| {
1680 self.poster.poll_post_message(
1681 cx,
1682 protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
1683 1,
1684 msg.data(),
1685 )
1686 })
1687 .await
1688 };
1689 match self.state {
1690 OutgoingMessageState::Running => {
1691 while let Some(msg) = self.queued.front() {
1692 send(msg).await;
1693 tracing::trace!("sent queued message");
1694 self.queued.pop_front();
1695 }
1696 }
1697 OutgoingMessageState::SendingPauseMessage => {
1698 send(&OutgoingMessage::new(&protocol::Pause)).await;
1699 tracing::trace!("sent pause message");
1700 self.state = OutgoingMessageState::Paused;
1701 }
1702 OutgoingMessageState::Paused => {}
1703 }
1704 }
1705
1706 fn pause(&mut self) {
1709 assert_eq!(self.state, OutgoingMessageState::Running);
1710 self.state = OutgoingMessageState::SendingPauseMessage;
1711 self.queued
1713 .push_front(OutgoingMessage::new(&protocol::Resume));
1714 }
1715
1716 fn force_pause(&mut self) {
1720 assert_eq!(self.state, OutgoingMessageState::Running);
1721 self.state = OutgoingMessageState::Paused;
1722 }
1723
1724 fn resume(&mut self) {
1725 assert_eq!(self.state, OutgoingMessageState::Paused);
1726 self.state = OutgoingMessageState::Running;
1727 }
1728
1729 fn is_empty(&self) -> bool {
1730 self.queued.is_empty()
1731 }
1732}
1733
/// Client task state that channel handlers mutate alongside a `ChannelRef`.
#[derive(Inspect)]
struct ClientTaskInner {
    /// Outgoing message queue to the host.
    messages: OutgoingMessages,
    /// GPADL teardowns sent to the host and not yet completed, mapped to
    /// their owning channel.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")]
    teardown_gpadls: HashMap<GpadlId, ChannelId>,
    /// Merged per-channel request streams, each tagged with its channel ID.
    /// When a stream ends, `run` treats it as device removal.
    #[inspect(skip)]
    channel_requests: SelectAll<TaggedStream<ChannelId, mesh::Receiver<ChannelRequest>>>,
    /// Event-flag bookkeeping and host signalling.
    synic: SynicState,
}
1743
/// Synic event resources owned by the client.
#[derive(Inspect)]
struct SynicState {
    /// Maps/unmaps events to event flags and signals the host.
    #[inspect(skip)]
    event_client: Arc<dyn SynicEventClient>,
    /// Allocation bitmap. Slot `i` is true when event flag `i + 1` is in use.
    #[inspect(iter_by_index)]
    event_flag_state: Vec<bool>,
}
1751
/// All channels currently known to the client, keyed by channel ID.
#[derive(Inspect, Default)]
#[inspect(transparent)]
struct ChannelList(
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] HashMap<ChannelId, Channel>,
);
1757
/// Mutable handle to an existing entry in `ChannelList`.
///
/// Dereferences to the `Channel`. Holding the entry lets
/// `ChannelRef::try_release` remove the channel from the map.
struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>);
1761
/// Token returned by `ChannelRef::try_release`.
///
/// The private `()` field means it can only be built in this module.
/// NOTE(review): presumably this forces functions that return it (such as
/// `handle_device_removal`) to have attempted a release; confirm intent.
struct TriedRelease(());
1766
1767impl ChannelRef<'_> {
1768 fn try_release(self, messages: &mut OutgoingMessages) -> TriedRelease {
1772 if self.is_client_released
1773 && matches!(self.state, ChannelState::Revoked)
1774 && self.pending_request().is_none()
1775 {
1776 let channel_id = *self.0.key();
1777 tracelimit::info_ratelimited!(
1778 channel_id = channel_id.0,
1779 key = %OfferKey::from(&self.offer),
1780 "releasing channel"
1781 );
1782
1783 messages.send(&protocol::RelIdReleased { channel_id });
1784 self.0.remove();
1785 }
1786 TriedRelease(())
1787 }
1788}
1789
1790impl Deref for ChannelRef<'_> {
1791 type Target = Channel;
1792
1793 fn deref(&self) -> &Self::Target {
1794 self.0.get()
1795 }
1796}
1797
1798impl DerefMut for ChannelRef<'_> {
1799 fn deref_mut(&mut self) -> &mut Self::Target {
1800 self.0.get_mut()
1801 }
1802}
1803
1804impl ChannelList {
1805 fn revoked_channel_with_pending_request(&self) -> Option<(ChannelId, &'static str)> {
1806 self.0.iter().find_map(|(&id, channel)| {
1807 if !matches!(channel.state, ChannelState::Revoked) {
1808 return None;
1809 }
1810 Some((id, channel.pending_request()?))
1811 })
1812 }
1813
1814 #[track_caller]
1815 fn get_mut(&mut self, channel_id: ChannelId) -> ChannelRef<'_> {
1816 match self.0.entry(channel_id) {
1817 hash_map::Entry::Occupied(entry) => ChannelRef(entry),
1818 hash_map::Entry::Vacant(_) => {
1819 panic!("channel {:?} not found", channel_id);
1820 }
1821 }
1822 }
1823}
1824
1825impl SynicState {
1826 fn guest_to_host_interrupt(&self, connection_id: Arc<AtomicU32>) -> Interrupt {
1827 Interrupt::from_fn({
1828 let event_client = self.event_client.clone();
1829 move || {
1830 let connection_id = connection_id.load(Ordering::Acquire);
1831 if connection_id == 0 {
1832 tracing::debug!("interrupt signal after close");
1833 return;
1834 }
1835
1836 if let Err(err) = event_client.signal_event(connection_id, 0) {
1837 tracelimit::warn_ratelimited!(
1838 error = &err as &dyn std::error::Error,
1839 "failed to signal event"
1840 );
1841 }
1842 }
1843 })
1844 }
1845
1846 const MAX_EVENT_FLAGS: u16 = 2047;
1847
1848 fn allocate_event_flag(&mut self, event: &Event) -> Result<u16> {
1849 let i = self
1850 .event_flag_state
1851 .iter()
1852 .position(|&used| !used)
1853 .ok_or(())
1854 .or_else(|()| {
1855 if self.event_flag_state.len() >= Self::MAX_EVENT_FLAGS as usize {
1856 anyhow::bail!("out of event flags");
1857 }
1858 self.event_flag_state.push(false);
1859 Ok(self.event_flag_state.len() - 1)
1860 })?;
1861
1862 let event_flag = (i + 1) as u16;
1863 self.event_client
1864 .map_event(event_flag, event)
1865 .context("failed to map event")?;
1866 self.event_flag_state[i] = true;
1867 Ok(event_flag)
1868 }
1869
1870 fn restore_event_flag(&mut self, flag: u16, event: &Event) -> Result<()> {
1871 let i = (flag as usize)
1872 .checked_sub(1)
1873 .context("invalid event flag")?;
1874 if i >= Self::MAX_EVENT_FLAGS as usize {
1875 anyhow::bail!("invalid event flag");
1876 }
1877 if self.event_flag_state.len() <= i {
1878 self.event_flag_state.resize(i + 1, false);
1879 }
1880 if self.event_flag_state[i] {
1881 anyhow::bail!("event flag already in use");
1882 }
1883 self.event_client
1884 .map_event(flag, event)
1885 .context("failed to map event")?;
1886 self.event_flag_state[i] = true;
1887 Ok(())
1888 }
1889
1890 fn free_event_flag(&mut self, flag: u16) {
1891 let i = flag as usize - 1;
1892 assert!(i < self.event_flag_state.len());
1893 self.event_flag_state[i] = false;
1894 }
1895}
1896
1897#[cfg(test)]
1898mod tests {
1899 use super::*;
1900 use futures_concurrency::future::Join;
1901 use guid::Guid;
1902 use pal_async::DefaultDriver;
1903 use pal_async::async_test;
1904 use pal_async::timer::PolledTimer;
1905 use protocol::TargetInfo;
1906 use std::fmt::Debug;
1907 use std::task::ready;
1908 use std::time::Duration;
1909 use test_with_tracing::test;
1910 use vmbus_core::protocol::MessageHeader;
1911 use vmbus_core::protocol::MessageType;
1912 use vmbus_core::protocol::OfferFlags;
1913 use vmbus_core::protocol::UserDefinedData;
1914 use vmbus_core::protocol::VmbusMessage;
1915 use zerocopy::FromBytes;
1916 use zerocopy::FromZeros;
1917 use zerocopy::Immutable;
1918 use zerocopy::IntoBytes;
1919 use zerocopy::KnownLayout;
1920
1921 const VMBUS_TEST_CLIENT_ID: Guid = guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6");
1922
1923 fn in_msg<T: IntoBytes + Immutable + KnownLayout>(message_type: MessageType, t: T) -> Vec<u8> {
1924 let mut data = Vec::new();
1925 data.extend_from_slice(&message_type.0.to_ne_bytes());
1926 data.extend_from_slice(&0u32.to_ne_bytes());
1927 data.extend_from_slice(t.as_bytes());
1928 data
1929 }
1930
1931 #[track_caller]
1932 fn check_message<T>(msg: OutgoingMessage, chk: T)
1933 where
1934 T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1935 {
1936 check_message_with_data(msg, chk, &[]);
1937 }
1938
1939 #[track_caller]
1940 fn check_message_with_data<T>(msg: OutgoingMessage, chk: T, data: &[u8])
1941 where
1942 T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1943 {
1944 let chk_data = OutgoingMessage::with_data(&chk, data);
1945 if msg.data() != chk_data.data() {
1946 let (header, rest) = MessageHeader::read_from_prefix(msg.data()).unwrap();
1947 assert_eq!(header.message_type(), <T as VmbusMessage>::MESSAGE_TYPE);
1948 let (msg, rest) = T::read_from_prefix(rest).expect("incorrect message size");
1949 if msg.as_bytes() != chk.as_bytes() {
1950 panic!("mismatched messages, expected {:#?}, got {:#?}", chk, msg);
1951 }
1952 if rest != data {
1953 panic!("mismatched data, expected {:#?}, got {:#?}", data, rest);
1954 }
1955 }
1956 }
1957
/// Simulated host side of the connection used by the tests.
struct TestServer {
    /// Messages the client posted to the host (via `TestServerClient`).
    messages: mesh::Receiver<OutgoingMessage>,
    /// Raw framed messages delivered to the client (via `TestMessageSource`).
    send: mesh::Sender<Vec<u8>>,
}
1962
impl TestServer {
    /// Waits for the next message the client posted to the host.
    async fn next(&mut self) -> Option<OutgoingMessage> {
        self.messages.next().await
    }

    /// Delivers an already-framed message (see `in_msg`) to the client.
    fn send(&self, msg: Vec<u8>) {
        self.send.send(msg);
    }

    /// Connects `client` without offering any channels.
    async fn connect(&mut self, client: &mut VmbusClient) -> ConnectResult {
        self.connect_with_channels(client, |_| {}).await
    }

    /// Runs the connect handshake against `client` and returns its result.
    ///
    /// The server side:
    /// 1. accepts the first (contact) message without checking it;
    /// 2. replies with a successful `VersionResponse2` advertising
    ///    `SUPPORTED_FEATURE_FLAGS`;
    /// 3. expects `RequestOffers`;
    /// 4. lets `send_offers` push offers;
    /// 5. finishes with `ALL_OFFERS_DELIVERED`.
    ///
    /// Asserts that the connection negotiated Copper with all supported
    /// feature flags.
    async fn connect_with_channels(
        &mut self,
        client: &mut VmbusClient,
        send_offers: impl FnOnce(&mut Self),
    ) -> ConnectResult {
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            let _ = self.next().await.unwrap();

            self.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse2 {
                    version_response: protocol::VersionResponse {
                        version_supported: 1,
                        connection_state: ConnectionState::SUCCESSFUL,
                        padding: 0,
                        selected_version_or_connection_id: 0,
                    },
                    supported_features: SUPPORTED_FEATURE_FLAGS.into(),
                },
            ));

            check_message(self.next().await.unwrap(), protocol::RequestOffers {});

            send_offers(self);
            self.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        // Both sides must run concurrently: each waits on the other.
        let (connection, ()) = (client_connect, server_connect).join().await;

        let connection = connection.unwrap();
        assert_eq!(connection.version.version, Version::Copper);
        assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
        connection
    }

    /// Connects with exactly one offered channel and returns its offer.
    async fn get_channel(&mut self, client: &mut VmbusClient) -> OfferInfo {
        let [channel] = self
            .get_channels(client, 1)
            .await
            .offers
            .try_into()
            .unwrap();
        channel
    }

    /// Connects while offering `count` channels. Channel IDs run 0..count and
    /// interface/instance GUIDs are random.
    async fn get_channels(&mut self, client: &mut VmbusClient, count: usize) -> ConnectResult {
        self.connect_with_channels(client, |this| {
            for i in 0..count {
                let offer = protocol::OfferChannel {
                    interface_id: Guid::new_random(),
                    instance_id: Guid::new_random(),
                    rsvd: [0; 4],
                    flags: OfferFlags::new(),
                    mmio_megabytes: 0,
                    user_defined: UserDefinedData::new_zeroed(),
                    subchannel_index: 0,
                    mmio_megabytes_optional: 0,
                    channel_id: ChannelId(i as u32),
                    monitor_id: 0,
                    monitor_allocated: 0,
                    is_dedicated: 0,
                    connection_id: 0,
                };

                this.send(in_msg(MessageType::OFFER_CHANNEL, offer));
            }
        })
        .await
    }

    /// Stops `client`, expecting it to send an in-band `Pause` and answering
    /// with `PauseResponse`.
    async fn stop_client(&mut self, client: &mut VmbusClient) {
        let client_stop = client.stop();
        let server_stop = async {
            check_message(self.next().await.unwrap(), protocol::Pause);
            self.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
        };
        (client_stop, server_stop).join().await;
    }

    /// Starts `client` and expects the `Resume` message that pausing queued
    /// at the front of its outgoing queue.
    async fn start_client(&mut self, client: &mut VmbusClient) {
        client.start();
        check_message(self.next().await.unwrap(), protocol::Resume);
    }
}
2062
/// Test `PollPostMessage` that forwards posted messages to the `TestServer`.
struct TestServerClient {
    /// Where posted messages are delivered.
    sender: mesh::Sender<OutgoingMessage>,
    /// Timer used to simulate a temporarily busy host.
    timer: PolledTimer,
    /// End of the current simulated stall, if one is in progress.
    deadline: Option<pal_async::timer::Instant>,
}
2068
2069 impl PollPostMessage for TestServerClient {
2070 fn poll_post_message(
2071 &mut self,
2072 cx: &mut Context<'_>,
2073 _connection_id: u32,
2074 _typ: u32,
2075 msg: &[u8],
2076 ) -> Poll<()> {
2077 loop {
2078 if let Some(deadline) = self.deadline {
2079 ready!(self.timer.poll_until(cx, deadline));
2080 self.deadline = None;
2081 }
2082 let mut b = [0];
2087 getrandom::fill(&mut b).unwrap();
2088 if b[0] % 4 == 0 {
2089 self.deadline =
2090 Some(pal_async::timer::Instant::now() + Duration::from_millis(10));
2091 } else {
2092 let msg = OutgoingMessage::from_message(msg).unwrap();
2093 tracing::info!(
2094 msg = ?MessageHeader::read_from_prefix(msg.data()),
2095 "sending message"
2096 );
2097 self.sender.send(msg);
2098 break Poll::Ready(());
2099 }
2100 }
2101 }
2102 }
2103
/// Test `SynicEventClient`: mapping always succeeds, and signalling always
/// fails with `Unsupported`.
struct NoopSynicEvents;
2105
2106 impl SynicEventClient for NoopSynicEvents {
2107 fn map_event(&self, _event_flag: u16, _event: &Event) -> std::io::Result<()> {
2108 Ok(())
2109 }
2110
2111 fn unmap_event(&self, _event_flag: u16) {}
2112
2113 fn signal_event(&self, _connection_id: u32, _event_flag: u16) -> std::io::Result<()> {
2114 Err(std::io::ErrorKind::Unsupported.into())
2115 }
2116 }
2117
/// Test message source fed by `TestServer::send`.
struct TestMessageSource {
    /// Framed host messages to deliver to the client.
    msg_recv: mesh::Receiver<Vec<u8>>,
    /// While true, an empty channel reads as EOF instead of pending.
    paused: bool,
}
2122
2123 impl AsyncRecv for TestMessageSource {
2124 fn poll_recv(
2125 &mut self,
2126 cx: &mut Context<'_>,
2127 mut bufs: &mut [std::io::IoSliceMut<'_>],
2128 ) -> Poll<std::io::Result<usize>> {
2129 let value = match self.msg_recv.poll_recv(cx) {
2130 Poll::Ready(v) => v.unwrap(),
2131 Poll::Pending => {
2132 if self.paused {
2133 return Poll::Ready(Ok(0));
2134 } else {
2135 return Poll::Pending;
2136 }
2137 }
2138 };
2139 let mut remaining = value.as_slice();
2140 let mut total_size = 0;
2141 while !remaining.is_empty() && !bufs.is_empty() {
2142 let size = bufs[0].len().min(remaining.len());
2143 bufs[0][..size].copy_from_slice(&remaining[..size]);
2144 remaining = &remaining[size..];
2145 bufs = &mut bufs[1..];
2146 total_size += size;
2147 }
2148
2149 Ok(total_size).into()
2150 }
2151 }
2152
impl VmbusMessageSource for TestMessageSource {
    /// Makes an empty receive report EOF, standing in for a paused stream.
    fn pause_message_stream(&mut self) {
        self.paused = true;
    }

    /// Restores normal blocking receives.
    fn resume_message_stream(&mut self) {
        self.paused = false;
    }
}
2162
2163 fn test_init(driver: &DefaultDriver) -> (TestServer, VmbusClient) {
2164 let (msg_send, msg_recv) = mesh::channel();
2165 let (synic_send, synic_recv) = mesh::channel();
2166 let server = TestServer {
2167 messages: synic_recv,
2168 send: msg_send,
2169 };
2170 let mut client = VmbusClientBuilder::new(
2171 NoopSynicEvents,
2172 TestMessageSource {
2173 msg_recv,
2174 paused: false,
2175 },
2176 TestServerClient {
2177 sender: synic_send,
2178 deadline: None,
2179 timer: PolledTimer::new(driver),
2180 },
2181 )
2182 .build(driver);
2183 client.start();
2184 (server, client)
2185 }
2186
/// The first message of a connect is `InitiateContact2`. It requests Copper
/// on SINT 2 / VTL 0 with all supported feature flags. Monitor pages and the
/// client ID are zero.
#[async_test]
async fn test_initiate_contact_success(driver: DefaultDriver) {
    let (mut server, client) = test_init(&driver);
    // Keep the pending call's receiver alive. The test never awaits it.
    let _recv = client
        .access
        .client_request_send
        .call(ClientRequest::Connect, ConnectRequest::default());
    check_message(
        server.next().await.unwrap(),
        protocol::InitiateContact2 {
            initiate_contact: protocol::InitiateContact {
                version_requested: Version::Copper as u32,
                target_message_vp: 0,
                interrupt_page_or_target_info: TargetInfo::new()
                    .with_sint(2)
                    .with_vtl(0)
                    .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                    .into(),
                parent_to_child_monitor_page_gpa: 0,
                child_to_parent_monitor_page_gpa: 0,
            },
            ..FromZeros::new_zeroed()
        },
    );
}
2212
/// Full connect handshake. The host accepts Copper with every supported
/// feature, and the resulting connection reports exactly that.
#[async_test]
async fn test_connect_success(driver: DefaultDriver) {
    let (mut server, mut client) = test_init(&driver);
    let client_connect = client.connect(0, None, Guid::ZERO);

    let server_connect = async {
        check_message(
            server.next().await.unwrap(),
            protocol::InitiateContact2 {
                initiate_contact: protocol::InitiateContact {
                    version_requested: Version::Copper as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                        .into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
                ..FromZeros::new_zeroed()
            },
        );

        server.send(in_msg(
            MessageType::VERSION_RESPONSE,
            protocol::VersionResponse2 {
                version_response: protocol::VersionResponse {
                    version_supported: 1,
                    connection_state: ConnectionState::SUCCESSFUL,
                    padding: 0,
                    selected_version_or_connection_id: 0,
                },
                supported_features: SUPPORTED_FEATURE_FLAGS.into_bits(),
            },
        ));

        check_message(server.next().await.unwrap(), protocol::RequestOffers {});
        server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
    };

    let (connection, ()) = (client_connect, server_connect).join().await;
    let connection = connection.unwrap();

    assert_eq!(connection.version.version, Version::Copper);
    assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
}
2260
/// The negotiated feature flags are those the host reports, not the ones
/// the client requested.
#[async_test]
async fn test_feature_flags(driver: DefaultDriver) {
    let (mut server, mut client) = test_init(&driver);
    let client_connect = client.connect(0, None, Guid::ZERO);

    let server_connect = async {
        check_message(
            server.next().await.unwrap(),
            protocol::InitiateContact2 {
                initiate_contact: protocol::InitiateContact {
                    version_requested: Version::Copper as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                        .into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
                ..FromZeros::new_zeroed()
            },
        );

        // The host reports only bit 1 (value 2). The assertion below expects
        // that bit to be `channel_interrupt_redirection`.
        server.send(in_msg(
            MessageType::VERSION_RESPONSE,
            protocol::VersionResponse2 {
                version_response: protocol::VersionResponse {
                    version_supported: 1,
                    connection_state: ConnectionState::SUCCESSFUL,
                    padding: 0,
                    selected_version_or_connection_id: 0,
                },
                supported_features: 2,
            },
        ));

        check_message(server.next().await.unwrap(), protocol::RequestOffers {});
        server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
    };

    let (connection, ()) = (client_connect, server_connect).join().await;
    let connection = connection.unwrap();

    assert_eq!(connection.version.version, Version::Copper);
    assert_eq!(
        connection.version.feature_flags,
        FeatureFlags::new().with_channel_interrupt_redirection(true)
    );
}
2313
/// A client ID supplied in `ConnectRequest` is carried in `InitiateContact2`.
#[async_test]
async fn test_client_id(driver: DefaultDriver) {
    let (mut server, client) = test_init(&driver);
    let initiate_contact = ConnectRequest {
        client_id: VMBUS_TEST_CLIENT_ID,
        ..Default::default()
    };
    // Keep the pending call's receiver alive. The test never awaits it.
    let _recv = client
        .access
        .client_request_send
        .call(ClientRequest::Connect, initiate_contact);

    check_message(
        server.next().await.unwrap(),
        protocol::InitiateContact2 {
            initiate_contact: protocol::InitiateContact {
                version_requested: Version::Copper as u32,
                target_message_vp: 0,
                interrupt_page_or_target_info: TargetInfo::new()
                    .with_sint(2)
                    .with_vtl(0)
                    .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                    .into(),
                parent_to_child_monitor_page_gpa: 0,
                child_to_parent_monitor_page_gpa: 0,
            },
            client_id: VMBUS_TEST_CLIENT_ID,
        },
    );
}
2344
2345 #[async_test]
2346 async fn test_version_negotiation(driver: DefaultDriver) {
2347 let (mut server, mut client) = test_init(&driver);
2348 let client_connect = client.connect(0, None, Guid::ZERO);
2349
2350 let server_connect = async {
2351 check_message(
2352 server.next().await.unwrap(),
2353 protocol::InitiateContact2 {
2354 initiate_contact: protocol::InitiateContact {
2355 version_requested: Version::Copper as u32,
2356 target_message_vp: 0,
2357 interrupt_page_or_target_info: TargetInfo::new()
2358 .with_sint(2)
2359 .with_vtl(0)
2360 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2361 .into(),
2362 parent_to_child_monitor_page_gpa: 0,
2363 child_to_parent_monitor_page_gpa: 0,
2364 },
2365 ..FromZeros::new_zeroed()
2366 },
2367 );
2368
2369 server.send(in_msg(
2370 MessageType::VERSION_RESPONSE,
2371 protocol::VersionResponse {
2372 version_supported: 0,
2373 connection_state: ConnectionState::SUCCESSFUL,
2374 padding: 0,
2375 selected_version_or_connection_id: 0,
2376 },
2377 ));
2378
2379 check_message(
2380 server.next().await.unwrap(),
2381 protocol::InitiateContact {
2382 version_requested: Version::Iron as u32,
2383 target_message_vp: 0,
2384 interrupt_page_or_target_info: TargetInfo::new()
2385 .with_sint(2)
2386 .with_vtl(0)
2387 .with_feature_flags(FeatureFlags::new().into())
2388 .into(),
2389 parent_to_child_monitor_page_gpa: 0,
2390 child_to_parent_monitor_page_gpa: 0,
2391 },
2392 );
2393
2394 server.send(in_msg(
2395 MessageType::VERSION_RESPONSE,
2396 protocol::VersionResponse {
2397 version_supported: 1,
2398 connection_state: ConnectionState::SUCCESSFUL,
2399 padding: 0,
2400 selected_version_or_connection_id: 0,
2401 },
2402 ));
2403
2404 check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2405 server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2406 };
2407
2408 let (connection, ()) = (client_connect, server_connect).join().await;
2409 let connection = connection.unwrap();
2410
2411 assert_eq!(connection.version.version, Version::Iron);
2412 assert_eq!(connection.version.feature_flags, FeatureFlags::new());
2413 }
2414
2415 #[async_test]
2416 async fn test_open_channel_success(driver: DefaultDriver) {
2417 let (mut server, mut client) = test_init(&driver);
2418 let channel = server.get_channel(&mut client).await;
2419
2420 let recv = channel.request_send.call(
2421 ChannelRequest::Open,
2422 OpenRequest {
2423 open_data: OpenData {
2424 target_vp: 0,
2425 ring_offset: 0,
2426 ring_gpadl_id: GpadlId(0),
2427 event_flag: 0,
2428 connection_id: 0,
2429 user_data: UserDefinedData::new_zeroed(),
2430 },
2431 incoming_event: None,
2432 use_vtl2_connection_id: false,
2433 },
2434 );
2435
2436 check_message(
2437 server.next().await.unwrap(),
2438 protocol::OpenChannel2 {
2439 open_channel: protocol::OpenChannel {
2440 channel_id: ChannelId(0),
2441 open_id: 0,
2442 ring_buffer_gpadl_id: GpadlId(0),
2443 target_vp: 0,
2444 downstream_ring_buffer_page_offset: 0,
2445 user_data: UserDefinedData::new_zeroed(),
2446 },
2447 connection_id: 0,
2448 event_flag: 0,
2449 flags: Default::default(),
2450 },
2451 );
2452
2453 server.send(in_msg(
2454 MessageType::OPEN_CHANNEL_RESULT,
2455 protocol::OpenResult {
2456 channel_id: ChannelId(0),
2457 open_id: 0,
2458 status: protocol::STATUS_SUCCESS as u32,
2459 },
2460 ));
2461
2462 recv.await.unwrap().unwrap();
2463 }
2464
2465 #[async_test]
2466 async fn test_open_channel_fail(driver: DefaultDriver) {
2467 let (mut server, mut client) = test_init(&driver);
2468 let channel = server.get_channel(&mut client).await;
2469
2470 let recv = channel.request_send.call(
2471 ChannelRequest::Open,
2472 OpenRequest {
2473 open_data: OpenData {
2474 target_vp: 0,
2475 ring_offset: 0,
2476 ring_gpadl_id: GpadlId(0),
2477 event_flag: 0,
2478 connection_id: 0,
2479 user_data: UserDefinedData::new_zeroed(),
2480 },
2481 incoming_event: None,
2482 use_vtl2_connection_id: false,
2483 },
2484 );
2485
2486 check_message(
2487 server.next().await.unwrap(),
2488 protocol::OpenChannel2 {
2489 open_channel: protocol::OpenChannel {
2490 channel_id: ChannelId(0),
2491 open_id: 0,
2492 ring_buffer_gpadl_id: GpadlId(0),
2493 target_vp: 0,
2494 downstream_ring_buffer_page_offset: 0,
2495 user_data: UserDefinedData::new_zeroed(),
2496 },
2497 connection_id: 0,
2498 event_flag: 0,
2499 flags: Default::default(),
2500 },
2501 );
2502
2503 server.send(in_msg(
2504 MessageType::OPEN_CHANNEL_RESULT,
2505 protocol::OpenResult {
2506 channel_id: ChannelId(0),
2507 open_id: 0,
2508 status: protocol::STATUS_UNSUCCESSFUL as u32,
2509 },
2510 ));
2511
2512 recv.await.unwrap().unwrap_err();
2513 }
2514
2515 #[async_test]
2516 async fn test_modify_channel(driver: DefaultDriver) {
2517 let (mut server, mut client) = test_init(&driver);
2518 let channel = server.get_channel(&mut client).await;
2519
2520 let recv = channel.request_send.call(
2523 ChannelRequest::Modify,
2524 ModifyRequest::TargetVp { target_vp: 1 },
2525 );
2526
2527 check_message(
2528 server.next().await.unwrap(),
2529 protocol::ModifyChannel {
2530 channel_id: ChannelId(0),
2531 target_vp: 1,
2532 },
2533 );
2534
2535 server.send(in_msg(
2536 MessageType::MODIFY_CHANNEL_RESPONSE,
2537 protocol::ModifyChannelResponse {
2538 channel_id: ChannelId(0),
2539 status: protocol::STATUS_SUCCESS,
2540 },
2541 ));
2542
2543 let status = recv.await.unwrap();
2544 assert_eq!(status, protocol::STATUS_SUCCESS);
2545 }
2546
2547 #[async_test]
2548 async fn test_save_restore_connected(driver: DefaultDriver) {
2549 let (mut server, mut client) = test_init(&driver);
2550 server.connect(&mut client).await;
2551 server.stop_client(&mut client).await;
2552 let s0 = client.save().await;
2553 let builder = client.sever().await;
2554 let mut client = builder.build(&driver);
2555 client.restore(s0.clone()).await.unwrap();
2556
2557 let s1 = client.save().await;
2558
2559 assert_eq!(s0, s1);
2560 }
2561
2562 #[async_test]
2563 async fn test_save_restore_connected_with_channel(driver: DefaultDriver) {
2564 let (mut server, mut client) = test_init(&driver);
2565 let c0 = server.get_channel(&mut client).await;
2566 server.stop_client(&mut client).await;
2567 let s0 = client.save().await;
2568 let builder = client.sever().await;
2569 let mut client = builder.build(&driver);
2570 let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2571 let s1 = client.save().await;
2572 assert_eq!(s0, s1);
2573 assert_eq!(connection.offers[0].offer, c0.offer);
2574 }
2575
2576 #[async_test]
2577 async fn test_save_restore_connected_with_revoked_channel(driver: DefaultDriver) {
2578 let (mut server, mut client) = test_init(&driver);
2579 let c0 = server.get_channel(&mut client).await;
2580 server.send(in_msg(
2581 MessageType::RESCIND_CHANNEL_OFFER,
2582 protocol::RescindChannelOffer {
2583 channel_id: ChannelId(0),
2584 },
2585 ));
2586 c0.revoke_recv.await.unwrap();
2587 let rpc = c0.request_send.call(
2588 ChannelRequest::Modify,
2589 ModifyRequest::TargetVp { target_vp: 1 },
2590 );
2591
2592 check_message(
2593 server.next().await.unwrap(),
2594 protocol::ModifyChannel {
2595 channel_id: ChannelId(0),
2596 target_vp: 1,
2597 },
2598 );
2599
2600 let client_stop = client.stop();
2601 let server_stop = async {
2602 server.send(in_msg(
2603 MessageType::MODIFY_CHANNEL_RESPONSE,
2604 protocol::ModifyChannelResponse {
2605 channel_id: ChannelId(0),
2606 status: protocol::STATUS_SUCCESS,
2607 },
2608 ));
2609 check_message(server.next().await.unwrap(), protocol::Pause);
2610 server.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
2611 };
2612 (client_stop, server_stop).join().await;
2613
2614 rpc.await.unwrap();
2615
2616 let s0 = client.save().await;
2617 let builder = client.sever().await;
2618 let mut client = builder.build(&driver);
2619 let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2620 let s1 = client.save().await;
2621 assert_eq!(s0, s1);
2622 assert!(connection.offers.is_empty());
2623 server.start_client(&mut client).await;
2624 check_message(
2625 server.next().await.unwrap(),
2626 protocol::RelIdReleased {
2627 channel_id: ChannelId(0),
2628 },
2629 );
2630 }
2631
2632 #[async_test]
2633 async fn test_connect_fails_on_incorrect_state(driver: DefaultDriver) {
2634 let (mut server, mut client) = test_init(&driver);
2635 server.connect(&mut client).await;
2636 let err = client.connect(0, None, Guid::ZERO).await.unwrap_err();
2637 assert!(matches!(err, ConnectError::InvalidState), "{:?}", err);
2638 }
2639
2640 #[async_test]
2641 async fn test_hot_add_remove(driver: DefaultDriver) {
2642 let (mut server, mut client) = test_init(&driver);
2643
2644 let mut connection = server.connect(&mut client).await;
2645 let offer = protocol::OfferChannel {
2646 interface_id: Guid::new_random(),
2647 instance_id: Guid::new_random(),
2648 rsvd: [0; 4],
2649 flags: OfferFlags::new(),
2650 mmio_megabytes: 0,
2651 user_defined: UserDefinedData::new_zeroed(),
2652 subchannel_index: 0,
2653 mmio_megabytes_optional: 0,
2654 channel_id: ChannelId(5),
2655 monitor_id: 0,
2656 monitor_allocated: 0,
2657 is_dedicated: 0,
2658 connection_id: 0,
2659 };
2660
2661 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2662 let info = connection.offer_recv.next().await.unwrap();
2663
2664 assert_eq!(offer, info.offer);
2665
2666 server.send(in_msg(
2667 MessageType::RESCIND_CHANNEL_OFFER,
2668 protocol::RescindChannelOffer {
2669 channel_id: ChannelId(5),
2670 },
2671 ));
2672
2673 info.revoke_recv.await.unwrap();
2674 drop(info.request_send);
2675
2676 check_message(
2677 server.next().await.unwrap(),
2678 protocol::RelIdReleased {
2679 channel_id: ChannelId(5),
2680 },
2681 );
2682 }
2683
2684 #[async_test]
2685 async fn test_gpadl_success(driver: DefaultDriver) {
2686 let (mut server, mut client) = test_init(&driver);
2687 let channel = server.get_channel(&mut client).await;
2688 let recv = channel.request_send.call(
2689 ChannelRequest::Gpadl,
2690 GpadlRequest {
2691 id: GpadlId(1),
2692 count: 1,
2693 buf: vec![5],
2694 },
2695 );
2696
2697 check_message_with_data(
2698 server.next().await.unwrap(),
2699 protocol::GpadlHeader {
2700 channel_id: ChannelId(0),
2701 gpadl_id: GpadlId(1),
2702 len: 8,
2703 count: 1,
2704 },
2705 0x5u64.as_bytes(),
2706 );
2707
2708 server.send(in_msg(
2709 MessageType::GPADL_CREATED,
2710 protocol::GpadlCreated {
2711 channel_id: ChannelId(0),
2712 gpadl_id: GpadlId(1),
2713 status: protocol::STATUS_SUCCESS,
2714 },
2715 ));
2716
2717 recv.await.unwrap().unwrap();
2718
2719 let rpc = channel
2720 .request_send
2721 .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2722
2723 check_message(
2724 server.next().await.unwrap(),
2725 protocol::GpadlTeardown {
2726 channel_id: ChannelId(0),
2727 gpadl_id: GpadlId(1),
2728 },
2729 );
2730
2731 server.send(in_msg(
2732 MessageType::GPADL_TORNDOWN,
2733 protocol::GpadlTorndown {
2734 gpadl_id: GpadlId(1),
2735 },
2736 ));
2737
2738 rpc.await.unwrap();
2739 }
2740
2741 #[async_test]
2742 async fn test_gpadl_fail(driver: DefaultDriver) {
2743 let (mut server, mut client) = test_init(&driver);
2744 let channel = server.get_channel(&mut client).await;
2745 let recv = channel.request_send.call(
2746 ChannelRequest::Gpadl,
2747 GpadlRequest {
2748 id: GpadlId(1),
2749 count: 1,
2750 buf: vec![7],
2751 },
2752 );
2753
2754 check_message_with_data(
2755 server.next().await.unwrap(),
2756 protocol::GpadlHeader {
2757 channel_id: ChannelId(0),
2758 gpadl_id: GpadlId(1),
2759 len: 8,
2760 count: 1,
2761 },
2762 0x7u64.as_bytes(),
2763 );
2764
2765 server.send(in_msg(
2766 MessageType::GPADL_CREATED,
2767 protocol::GpadlCreated {
2768 channel_id: ChannelId(0),
2769 gpadl_id: GpadlId(1),
2770 status: protocol::STATUS_UNSUCCESSFUL,
2771 },
2772 ));
2773
2774 recv.await.unwrap().unwrap_err();
2775 }
2776
2777 #[async_test]
2778 async fn test_gpadl_with_revoke(driver: DefaultDriver) {
2779 let (mut server, mut client) = test_init(&driver);
2780 let channel = server.get_channel(&mut client).await;
2781 let channel_id = ChannelId(0);
2782 for gpadl_id in [1, 2, 3].map(GpadlId) {
2783 let recv = channel.request_send.call(
2784 ChannelRequest::Gpadl,
2785 GpadlRequest {
2786 id: gpadl_id,
2787 count: 1,
2788 buf: vec![3],
2789 },
2790 );
2791
2792 check_message_with_data(
2793 server.next().await.unwrap(),
2794 protocol::GpadlHeader {
2795 channel_id,
2796 gpadl_id,
2797 len: 8,
2798 count: 1,
2799 },
2800 0x3u64.as_bytes(),
2801 );
2802
2803 server.send(in_msg(
2804 MessageType::GPADL_CREATED,
2805 protocol::GpadlCreated {
2806 channel_id,
2807 gpadl_id,
2808 status: protocol::STATUS_SUCCESS,
2809 },
2810 ));
2811
2812 recv.await.unwrap().unwrap();
2813 }
2814
2815 let rpc = channel
2816 .request_send
2817 .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2818
2819 check_message(
2820 server.next().await.unwrap(),
2821 protocol::GpadlTeardown {
2822 channel_id,
2823 gpadl_id: GpadlId(1),
2824 },
2825 );
2826
2827 server.send(in_msg(
2828 MessageType::RESCIND_CHANNEL_OFFER,
2829 protocol::RescindChannelOffer { channel_id },
2830 ));
2831
2832 let recv = channel.request_send.call_failable(
2833 ChannelRequest::Gpadl,
2834 GpadlRequest {
2835 id: GpadlId(4),
2836 count: 1,
2837 buf: vec![3],
2838 },
2839 );
2840
2841 check_message_with_data(
2842 server.next().await.unwrap(),
2843 protocol::GpadlHeader {
2844 channel_id,
2845 gpadl_id: GpadlId(4),
2846 len: 8,
2847 count: 1,
2848 },
2849 0x3u64.as_bytes(),
2850 );
2851
2852 server.send(in_msg(
2853 MessageType::GPADL_CREATED,
2854 protocol::GpadlCreated {
2855 channel_id,
2856 gpadl_id: GpadlId(4),
2857 status: protocol::STATUS_UNSUCCESSFUL,
2858 },
2859 ));
2860
2861 server.send(in_msg(
2862 MessageType::GPADL_TORNDOWN,
2863 protocol::GpadlTorndown {
2864 gpadl_id: GpadlId(1),
2865 },
2866 ));
2867
2868 rpc.await.unwrap();
2869 recv.await.unwrap_err();
2870
2871 channel.revoke_recv.await.unwrap();
2872
2873 let rpc = channel
2874 .request_send
2875 .call(ChannelRequest::TeardownGpadl, GpadlId(2));
2876 drop(channel.request_send);
2877
2878 check_message(
2879 server.next().await.unwrap(),
2880 protocol::GpadlTeardown {
2881 channel_id,
2882 gpadl_id: GpadlId(2),
2883 },
2884 );
2885
2886 server.send(in_msg(
2887 MessageType::GPADL_TORNDOWN,
2888 protocol::GpadlTorndown {
2889 gpadl_id: GpadlId(2),
2890 },
2891 ));
2892
2893 rpc.await.unwrap();
2894
2895 check_message(
2896 server.next().await.unwrap(),
2897 protocol::RelIdReleased { channel_id },
2898 );
2899 }
2900
2901 #[async_test]
2902 async fn test_modify_connection(driver: DefaultDriver) {
2903 let (mut server, mut client) = test_init(&driver);
2904 server.connect(&mut client).await;
2905 let call = client.access.client_request_send.call(
2906 ClientRequest::Modify,
2907 ModifyConnectionRequest {
2908 monitor_page: Some(MonitorPageGpas {
2909 child_to_parent: 5,
2910 parent_to_child: 6,
2911 }),
2912 },
2913 );
2914
2915 check_message(
2916 server.next().await.unwrap(),
2917 protocol::ModifyConnection {
2918 child_to_parent_monitor_page_gpa: 5,
2919 parent_to_child_monitor_page_gpa: 6,
2920 },
2921 );
2922
2923 server.send(in_msg(
2924 MessageType::MODIFY_CONNECTION_RESPONSE,
2925 protocol::ModifyConnectionResponse {
2926 connection_state: ConnectionState::FAILED_LOW_RESOURCES,
2927 },
2928 ));
2929
2930 let result = call.await.unwrap();
2931 assert_eq!(ConnectionState::FAILED_LOW_RESOURCES, result);
2932 }
2933
2934 #[async_test]
2935 async fn test_hvsock(driver: DefaultDriver) {
2936 let (mut server, mut client) = test_init(&driver);
2937 server.connect(&mut client).await;
2938 let request = HvsockConnectRequest {
2939 service_id: Guid::new_random(),
2940 endpoint_id: Guid::new_random(),
2941 silo_id: Guid::new_random(),
2942 hosted_silo_unaware: false,
2943 };
2944
2945 let resp = client.access().connect_hvsock(request);
2946 check_message(
2947 server.next().await.unwrap(),
2948 protocol::TlConnectRequest2 {
2949 base: protocol::TlConnectRequest {
2950 service_id: request.service_id,
2951 endpoint_id: request.endpoint_id,
2952 },
2953 silo_id: request.silo_id,
2954 },
2955 );
2956
2957 server.send(in_msg(
2959 MessageType::TL_CONNECT_REQUEST_RESULT,
2960 protocol::TlConnectResult {
2961 service_id: request.service_id,
2962 endpoint_id: request.endpoint_id,
2963 status: protocol::STATUS_CONNECTION_REFUSED,
2964 },
2965 ));
2966
2967 let result = resp.await;
2968 assert!(result.is_none());
2969 }
2970
2971 #[async_test]
2972 async fn test_synic_event_flags(driver: DefaultDriver) {
2973 let (mut server, mut client) = test_init(&driver);
2974 let connection = server.get_channels(&mut client, 5).await;
2975 let event = Event::new();
2976
2977 for _ in 0..5 {
2978 for (i, channel) in connection.offers.iter().enumerate() {
2979 let recv = channel.request_send.call(
2980 ChannelRequest::Open,
2981 OpenRequest {
2982 open_data: OpenData {
2983 target_vp: 0,
2984 ring_offset: 0,
2985 ring_gpadl_id: GpadlId(0),
2986 event_flag: 0,
2987 connection_id: 0,
2988 user_data: UserDefinedData::new_zeroed(),
2989 },
2990 incoming_event: Some(event.clone()),
2991 use_vtl2_connection_id: false,
2992 },
2993 );
2994
2995 let expected_event_flag = i as u16 + 1;
2996
2997 check_message(
2998 server.next().await.unwrap(),
2999 protocol::OpenChannel2 {
3000 open_channel: protocol::OpenChannel {
3001 channel_id: channel.offer.channel_id,
3002 open_id: 0,
3003 ring_buffer_gpadl_id: GpadlId(0),
3004 target_vp: 0,
3005 downstream_ring_buffer_page_offset: 0,
3006 user_data: UserDefinedData::new_zeroed(),
3007 },
3008 connection_id: 0,
3009 event_flag: expected_event_flag,
3010 flags: OpenChannelFlags::new().with_redirect_interrupt(true),
3011 },
3012 );
3013
3014 server.send(in_msg(
3015 MessageType::OPEN_CHANNEL_RESULT,
3016 protocol::OpenResult {
3017 channel_id: channel.offer.channel_id,
3018 open_id: 0,
3019 status: protocol::STATUS_SUCCESS as u32,
3020 },
3021 ));
3022
3023 let output = recv.await.unwrap().unwrap();
3024 assert_eq!(output.redirected_event_flag, Some(expected_event_flag));
3025 }
3026
3027 for (i, channel) in connection.offers.iter().enumerate() {
3028 channel
3031 .request_send
3032 .call(ChannelRequest::Close, ())
3033 .await
3034 .unwrap();
3035
3036 check_message(
3037 server.next().await.unwrap(),
3038 protocol::CloseChannel {
3039 channel_id: ChannelId(i as u32),
3040 },
3041 );
3042 }
3043 }
3044 }
3045
3046 #[async_test]
3047 async fn test_revoke(driver: DefaultDriver) {
3048 let (mut server, mut client) = test_init(&driver);
3049 let channel = server.get_channel(&mut client).await;
3050
3051 server.send(in_msg(
3052 MessageType::RESCIND_CHANNEL_OFFER,
3053 protocol::RescindChannelOffer {
3054 channel_id: ChannelId(0),
3055 },
3056 ));
3057
3058 channel.revoke_recv.await.unwrap();
3059
3060 channel
3061 .request_send
3062 .call_failable(
3063 ChannelRequest::Open,
3064 OpenRequest {
3065 open_data: OpenData {
3066 target_vp: 0,
3067 ring_offset: 0,
3068 ring_gpadl_id: GpadlId(0),
3069 event_flag: 0,
3070 connection_id: 0,
3071 user_data: UserDefinedData::new_zeroed(),
3072 },
3073 incoming_event: None,
3074 use_vtl2_connection_id: false,
3075 },
3076 )
3077 .await
3078 .unwrap_err();
3079 }
3080
3081 #[async_test]
3082 #[should_panic(expected = "channel should not exist")]
3083 async fn test_reoffer_in_use_rel_id(driver: DefaultDriver) {
3084 let (mut server, mut client) = test_init(&driver);
3085 let mut connection = server.get_channels(&mut client, 1).await;
3086 let [channel] = connection.offers.try_into().unwrap();
3087
3088 server.send(in_msg(
3089 MessageType::RESCIND_CHANNEL_OFFER,
3090 protocol::RescindChannelOffer {
3091 channel_id: ChannelId(0),
3092 },
3093 ));
3094
3095 channel.revoke_recv.await.unwrap();
3096
3097 let offer = protocol::OfferChannel {
3099 interface_id: Guid::new_random(),
3100 instance_id: Guid::new_random(),
3101 rsvd: [0; 4],
3102 flags: OfferFlags::new(),
3103 mmio_megabytes: 0,
3104 user_defined: UserDefinedData::new_zeroed(),
3105 subchannel_index: 0,
3106 mmio_megabytes_optional: 0,
3107 channel_id: ChannelId(0),
3108 monitor_id: 0,
3109 monitor_allocated: 0,
3110 is_dedicated: 0,
3111 connection_id: 0,
3112 };
3113
3114 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3115
3116 connection.offer_recv.next().await;
3117 }
3118
3119 #[async_test]
3120 async fn test_revoke_release_and_reoffer(driver: DefaultDriver) {
3121 let (mut server, mut client) = test_init(&driver);
3122 let mut connection = server.get_channels(&mut client, 1).await;
3123 let [channel] = connection.offers.try_into().unwrap();
3124
3125 server.send(in_msg(
3126 MessageType::RESCIND_CHANNEL_OFFER,
3127 protocol::RescindChannelOffer {
3128 channel_id: ChannelId(0),
3129 },
3130 ));
3131
3132 channel.revoke_recv.await.unwrap();
3133 drop(channel.request_send);
3134
3135 check_message(
3136 server.next().await.unwrap(),
3137 protocol::RelIdReleased {
3138 channel_id: ChannelId(0),
3139 },
3140 );
3141
3142 let offer = protocol::OfferChannel {
3143 interface_id: Guid::new_random(),
3144 instance_id: Guid::new_random(),
3145 rsvd: [0; 4],
3146 flags: OfferFlags::new(),
3147 mmio_megabytes: 0,
3148 user_defined: UserDefinedData::new_zeroed(),
3149 subchannel_index: 0,
3150 mmio_megabytes_optional: 0,
3151 channel_id: ChannelId(0),
3152 monitor_id: 0,
3153 monitor_allocated: 0,
3154 is_dedicated: 0,
3155 connection_id: 0,
3156 };
3157
3158 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3159
3160 connection.offer_recv.next().await.unwrap();
3161 }
3162
3163 #[async_test]
3164 async fn test_release_revoke_and_reoffer(driver: DefaultDriver) {
3165 let (mut server, mut client) = test_init(&driver);
3166 let mut connection = server.get_channels(&mut client, 1).await;
3167 let [channel] = connection.offers.try_into().unwrap();
3168
3169 let open = channel.request_send.call_failable(
3170 ChannelRequest::Open,
3171 OpenRequest {
3172 open_data: OpenData {
3173 target_vp: 0,
3174 ring_offset: 0,
3175 ring_gpadl_id: GpadlId(0),
3176 event_flag: 0,
3177 connection_id: 0,
3178 user_data: UserDefinedData::new_zeroed(),
3179 },
3180 incoming_event: None,
3181 use_vtl2_connection_id: false,
3182 },
3183 );
3184
3185 let server_open = async {
3186 check_message(
3187 server.next().await.unwrap(),
3188 protocol::OpenChannel2 {
3189 open_channel: protocol::OpenChannel {
3190 channel_id: ChannelId(0),
3191 open_id: 0,
3192 ring_buffer_gpadl_id: GpadlId(0),
3193 target_vp: 0,
3194 downstream_ring_buffer_page_offset: 0,
3195 user_data: UserDefinedData::new_zeroed(),
3196 },
3197 connection_id: 0,
3198 event_flag: 0,
3199 flags: Default::default(),
3200 },
3201 );
3202 server.send(in_msg(
3203 MessageType::OPEN_CHANNEL_RESULT,
3204 protocol::OpenResult {
3205 channel_id: ChannelId(0),
3206 open_id: 0,
3207 status: protocol::STATUS_SUCCESS as u32,
3208 },
3209 ));
3210 };
3211
3212 (open, server_open).join().await.0.unwrap();
3213
3214 drop(channel);
3216
3217 check_message(
3218 server.next().await.unwrap(),
3219 protocol::CloseChannel {
3220 channel_id: ChannelId(0),
3221 },
3222 );
3223
3224 server.send(in_msg(
3225 MessageType::RESCIND_CHANNEL_OFFER,
3226 protocol::RescindChannelOffer {
3227 channel_id: ChannelId(0),
3228 },
3229 ));
3230
3231 check_message(
3233 server.next().await.unwrap(),
3234 protocol::RelIdReleased {
3235 channel_id: ChannelId(0),
3236 },
3237 );
3238
3239 let offer = protocol::OfferChannel {
3240 interface_id: Guid::new_random(),
3241 instance_id: Guid::new_random(),
3242 rsvd: [0; 4],
3243 flags: OfferFlags::new(),
3244 mmio_megabytes: 0,
3245 user_defined: UserDefinedData::new_zeroed(),
3246 subchannel_index: 0,
3247 mmio_megabytes_optional: 0,
3248 channel_id: ChannelId(0),
3249 monitor_id: 0,
3250 monitor_allocated: 0,
3251 is_dedicated: 0,
3252 connection_id: 0,
3253 };
3254
3255 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3256
3257 connection.offer_recv.next().await.unwrap();
3259 }
3260}