1#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9pub mod driver;
10pub mod filter;
11mod hvsock;
12pub mod saved_state;
13
14pub use self::saved_state::SavedState;
15use anyhow::Context as _;
16use anyhow::Result;
17use futures::FutureExt;
18use futures::StreamExt;
19use futures::future::OptionFuture;
20use futures::stream::SelectAll;
21use futures_concurrency::future::Race;
22use guid::Guid;
23use inspect::Inspect;
24use mesh::rpc::FailableRpc;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use pal_async::task::Spawn;
28use pal_async::task::Task;
29use pal_event::Event;
30use std::collections::HashMap;
31use std::collections::VecDeque;
32use std::collections::hash_map;
33use std::convert::TryInto;
34use std::future::Future;
35use std::future::poll_fn;
36use std::ops::Deref;
37use std::ops::DerefMut;
38use std::pin::pin;
39use std::sync::Arc;
40use std::sync::atomic::AtomicU32;
41use std::sync::atomic::Ordering;
42use std::task::Context;
43use std::task::Poll;
44use thiserror::Error;
45use vmbus_async::async_dgram::AsyncRecv;
46use vmbus_async::async_dgram::AsyncRecvExt;
47use vmbus_channel::bus::GpadlRequest;
48use vmbus_channel::bus::ModifyRequest;
49use vmbus_channel::bus::OfferKey;
50use vmbus_channel::bus::OpenData;
51use vmbus_channel::gpadl::GpadlId;
52use vmbus_core::HvsockConnectRequest;
53use vmbus_core::OutgoingMessage;
54use vmbus_core::TaggedStream;
55use vmbus_core::VersionInfo;
56use vmbus_core::protocol;
57use vmbus_core::protocol::ChannelId;
58use vmbus_core::protocol::ConnectionState;
59use vmbus_core::protocol::FeatureFlags;
60use vmbus_core::protocol::Message;
61use vmbus_core::protocol::OpenChannelFlags;
62use vmbus_core::protocol::Version;
63use vmcore::interrupt::Interrupt;
64use vmcore::synic::MonitorPageGpas;
65use zerocopy::Immutable;
66use zerocopy::IntoBytes;
67use zerocopy::KnownLayout;
68
/// Synthetic interrupt source (SINT) requested from the server in the
/// `InitiateContact` target info.
const SINT: u8 = 2;
/// VTL requested from the server in the `InitiateContact` target info.
const VTL: u8 = 0;
/// Protocol versions this client can negotiate, ordered from least to most
/// preferred. Connecting starts with the last entry and, each time the server
/// rejects a version, retries with the previous entry.
const SUPPORTED_VERSIONS: &[Version] = &[Version::Iron, Version::Copper];
/// Feature flags requested when the negotiated version is `Copper` or newer;
/// older versions request no feature flags.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true);
78
/// Synic event operations used by the VMBus client.
///
/// The client keeps a single instance behind an `Arc`, so implementations must
/// be `Send + Sync`.
pub trait SynicEventClient: Send + Sync {
    /// Maps `event` to the given event flag.
    fn map_event(&self, event_flag: u16, event: &Event) -> std::io::Result<()>;

    /// Removes a mapping previously established with [`Self::map_event`].
    fn unmap_event(&self, event_flag: u16);

    /// Signals `event_flag` on the connection identified by `connection_id`.
    fn signal_event(&self, connection_id: u32, event_flag: u16) -> std::io::Result<()>;
}
90
/// Source of incoming synic messages for the VMBus client.
pub trait VmbusMessageSource: AsyncRecv + Send {
    /// Requests that message delivery be paused. The default implementation
    /// does nothing.
    fn pause_message_stream(&mut self) {}

    /// Requests that message delivery resume after a pause. The default
    /// implementation does nothing.
    fn resume_message_stream(&mut self) {}
}
100
/// Sink for outgoing VMBus messages to the server.
pub trait PollPostMessage: Send {
    /// Attempts to post `msg` with message type `typ` on `connection_id`.
    ///
    /// Returns `Poll::Pending` when the message cannot be posted yet.
    /// NOTE(review): presumably implementations register `cx`'s waker to be
    /// woken when posting may succeed — confirm against implementors.
    fn poll_post_message(
        &mut self,
        cx: &mut Context<'_>,
        connection_id: u32,
        typ: u32,
        msg: &[u8],
    ) -> Poll<()>;
}
110
/// Handle to a spawned VMBus client task, created by
/// [`VmbusClientBuilder::build`].
#[derive(Inspect)]
pub struct VmbusClient {
    /// Control requests (start/stop/save/restore/inspect) for the task.
    #[inspect(flatten, send = "TaskRequest::Inspect")]
    task_send: mesh::Sender<TaskRequest>,
    /// Cloneable handle used to send client requests to the task.
    #[inspect(skip)]
    access: VmbusClientAccess,
    /// The spawned task; awaiting it yields the task state back.
    #[inspect(skip)]
    task: Task<ClientTask>,
}
120
/// Errors returned by [`VmbusClient::connect`].
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// A connect was requested while the client was not disconnected.
    #[error("invalid state to connect to the server")]
    InvalidState,
    /// The server rejected every version in `SUPPORTED_VERSIONS`.
    #[error("no supported protocol versions")]
    NoSupportedVersions,
    /// The server accepted the version but reported a connection state other
    /// than `ConnectionState::SUCCESSFUL`.
    #[error("failed to connect to the server: {0:?}")]
    FailedToConnect(ConnectionState),
}
130
/// Cloneable handle for sending client requests (modify connection, hvsocket
/// connect) to the VMBus client task.
#[derive(Clone)]
pub struct VmbusClientAccess {
    /// Channel to the client task's request loop.
    client_request_send: mesh::Sender<ClientRequest>,
}
135
/// Holds the components needed to create a [`VmbusClient`].
///
/// A builder with the original components is also recovered when a client
/// is shut down after unloading.
pub struct VmbusClientBuilder {
    /// Synic event operations, shared by the task.
    event_client: Arc<dyn SynicEventClient>,
    /// Source of incoming synic messages.
    msg_source: Box<dyn VmbusMessageSource>,
    /// Sink for outgoing messages.
    msg_client: Box<dyn PollPostMessage>,
}
142
143impl VmbusClientBuilder {
144 pub fn new(
146 event_client: impl SynicEventClient + 'static,
147 msg_source: impl VmbusMessageSource + 'static,
148 msg_client: impl PollPostMessage + 'static,
149 ) -> Self {
150 Self {
151 event_client: Arc::new(event_client),
152 msg_source: Box::new(msg_source),
153 msg_client: Box::new(msg_client),
154 }
155 }
156
157 pub fn build(self, spawner: &impl Spawn) -> VmbusClient {
159 let (task_send, task_recv) = mesh::channel();
160 let (client_request_send, client_request_recv) = mesh::channel();
161
162 let inner = ClientTaskInner {
163 messages: OutgoingMessages {
164 poster: self.msg_client,
165 queued: VecDeque::new(),
166 state: OutgoingMessageState::Paused,
167 },
168 teardown_gpadls: HashMap::new(),
169 channel_requests: SelectAll::new(),
170 synic: SynicState {
171 event_flag_state: Vec::new(),
172 event_client: self.event_client,
173 },
174 };
175
176 let mut task = ClientTask {
177 inner,
178 channels: ChannelList::default(),
179 task_recv,
180 running: false,
181 msg_source: self.msg_source,
182 client_request_recv,
183 state: ClientState::Disconnected,
184 modify_request: None,
185 hvsock_tracker: hvsock::HvsockRequestTracker::new(),
186 };
187
188 let task = spawner.spawn("vmbus client", async move {
189 task.run().await;
190 task
191 });
192
193 VmbusClient {
194 access: VmbusClientAccess {
195 client_request_send,
196 },
197 task_send,
198 task,
199 }
200 }
201}
202
203impl VmbusClient {
204 pub async fn connect(
207 &mut self,
208 target_message_vp: u32,
209 monitor_page: Option<MonitorPageGpas>,
210 client_id: Guid,
211 ) -> Result<ConnectResult, ConnectError> {
212 let request = ConnectRequest {
213 target_message_vp,
214 monitor_page,
215 client_id,
216 };
217
218 self.access
219 .client_request_send
220 .call(ClientRequest::Connect, request)
221 .await
222 .unwrap()
223 }
224
225 pub async fn unload(self) {
226 self.access
227 .client_request_send
228 .call(ClientRequest::Unload, ())
229 .await
230 .unwrap();
231
232 self.sever().await;
233 }
234
235 pub fn access(&self) -> &VmbusClientAccess {
236 &self.access
237 }
238
239 pub fn start(&mut self) {
240 self.task_send.send(TaskRequest::Start);
241 }
242
243 pub async fn stop(&mut self) {
244 self.task_send
245 .call(TaskRequest::Stop, ())
246 .await
247 .expect("Failed to send stop request");
248 }
249
250 pub async fn save(&self) -> SavedState {
251 self.task_send
252 .call(TaskRequest::Save, ())
253 .await
254 .expect("Failed to send save request")
255 }
256
257 pub async fn restore(
258 &mut self,
259 state: SavedState,
260 ) -> Result<Option<ConnectResult>, RestoreError> {
261 self.task_send
262 .call(TaskRequest::Restore, state)
263 .await
264 .expect("Failed to send restore request")
265 }
266
267 pub async fn post_restore(&mut self) {
268 self.task_send
269 .call(TaskRequest::PostRestore, ())
270 .await
271 .expect("Failed to send post-restore request");
272 }
273
274 async fn sever(self) -> VmbusClientBuilder {
275 drop(self.task_send);
276 let task = self.task.await;
277 VmbusClientBuilder {
278 event_client: task.inner.synic.event_client,
279 msg_source: task.msg_source,
280 msg_client: task.inner.messages.poster,
281 }
282 }
283}
284
/// Result of a successful connection to the VMBus server.
#[derive(Debug)]
pub struct ConnectResult {
    /// Negotiated protocol version and feature flags.
    pub version: VersionInfo,
    /// Offers received before the server reported all offers delivered.
    pub offers: Vec<OfferInfo>,
    /// Receives offers that arrive after the connection is established.
    pub offer_recv: mesh::Receiver<OfferInfo>,
}
291
292impl VmbusClientAccess {
293 pub async fn modify(&self, request: ModifyConnectionRequest) -> ConnectionState {
294 self.client_request_send
295 .call(ClientRequest::Modify, request)
296 .await
297 .expect("Failed to send modify request")
298 }
299
300 pub fn connect_hvsock(
301 &self,
302 request: HvsockConnectRequest,
303 ) -> impl Future<Output = Option<OfferInfo>> + use<> {
304 self.client_request_send
305 .call(ClientRequest::HvsockConnect, request)
306 .map(|r| r.ok().flatten())
307 }
308}
309
/// Parameters for opening a channel on the host.
#[derive(Debug)]
pub struct OpenRequest {
    /// Ring buffer GPADL, target VP, user data, event flag and connection ID.
    pub open_data: OpenData,
    /// If set, the channel's interrupts are redirected to a newly allocated
    /// event flag mapped to this event (requires host support).
    pub incoming_event: Option<Event>,
    /// If set, use a connection ID derived from the channel ID for VTL2
    /// instead of `open_data.connection_id` (requires host support).
    pub use_vtl2_connection_id: bool,
}
316
/// Parameters for reattaching to a channel that was open before a restore.
#[derive(Debug)]
pub struct RestoreRequest {
    /// Event to map to `redirected_event_flag`; must be set if and only if
    /// `redirected_event_flag` is set.
    pub incoming_event: Option<Event>,
    /// Previously allocated redirected event flag to reclaim.
    pub redirected_event_flag: Option<u16>,
    /// Connection ID to use for guest-to-host signaling on this channel.
    pub connection_id: u32,
}
325
/// Requests sent by a channel's consumer to the client task, through
/// `OfferInfo::request_send`.
pub enum ChannelRequest {
    /// Open the channel on the host.
    Open(FailableRpc<OpenRequest, OpenOutput>),
    /// Reattach to a channel restored in the `Restored` state.
    Restore(FailableRpc<RestoreRequest, OpenOutput>),
    /// Close the channel.
    Close(Rpc<(), ()>),
    /// Create a GPADL for the channel.
    Gpadl(FailableRpc<GpadlRequest, ()>),
    /// Tear down a GPADL of the channel.
    TeardownGpadl(Rpc<GpadlId, ()>),
    /// Modify the channel; presumably completed with the status from the
    /// server's ModifyChannelResponse (TODO confirm).
    Modify(Rpc<ModifyRequest, i32>),
}
335
/// Result of successfully opening or restoring a channel.
#[derive(Debug)]
pub struct OpenOutput {
    /// Event flag that the channel's interrupts are redirected to, if an
    /// incoming event was supplied.
    pub redirected_event_flag: Option<u16>,
}
341
342impl std::fmt::Display for ChannelRequest {
343 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344 let s = match self {
345 ChannelRequest::Open(_) => "Open",
346 ChannelRequest::Close(_) => "Close",
347 ChannelRequest::Restore(_) => "Restore",
348 ChannelRequest::Gpadl(_) => "Gpadl",
349 ChannelRequest::TeardownGpadl(_) => "TeardownGpadl",
350 ChannelRequest::Modify(_) => "Modify",
351 };
352 fmt.pad(s)
353 }
354}
355
/// Errors returned by [`VmbusClient::restore`].
#[derive(Debug, Error)]
pub enum RestoreError {
    /// The saved protocol version is not supported.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// The saved feature flags are not supported.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    /// The saved state contains the same channel ID more than once.
    #[error("duplicate channel id {0}")]
    DuplicateChannelId(u32),

    /// The saved state contains the same GPADL ID more than once.
    #[error("duplicate gpadl id {0}")]
    DuplicateGpadlId(u32),

    /// A saved GPADL references a channel not present in the saved state.
    #[error("gpadl for unknown channel id {0}")]
    GpadlForUnknownChannelId(u32),

    /// A saved pending outgoing message could not be reconstructed.
    #[error("invalid pending message")]
    InvalidPendingMessage(#[source] vmbus_core::MessageTooLarge),

    /// A restored channel could not be offered.
    #[error("failed to offer restored channel")]
    OfferFailed(#[source] anyhow::Error),
}
379
/// A channel offer as handed to the channel's consumer.
#[derive(Debug, Inspect)]
pub struct OfferInfo {
    /// The offer message received from the server.
    pub offer: protocol::OfferChannel,
    /// Interrupt for signaling the host, using the channel's connection ID.
    #[inspect(skip)]
    pub guest_to_host_interrupt: Interrupt,
    /// Sends [`ChannelRequest`]s for this channel to the client task.
    #[inspect(skip)]
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Fires when the server rescinds this offer.
    #[inspect(skip)]
    pub revoke_recv: mesh::OneshotReceiver<()>,
}
392
/// Requests sent to the client task via [`VmbusClientAccess`].
#[derive(Debug)]
enum ClientRequest {
    /// Connect to the server (see [`VmbusClient::connect`]).
    Connect(Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>),
    /// Send Unload and wait for UnloadComplete.
    Unload(Rpc<(), ()>),
    /// Send a ModifyConnection request.
    Modify(Rpc<ModifyConnectionRequest, ConnectionState>),
    /// Send an hvsocket connect request.
    HvsockConnect(Rpc<HvsockConnectRequest, Option<OfferInfo>>),
}
400
401impl std::fmt::Display for ClientRequest {
402 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403 let s = match self {
404 ClientRequest::Connect(..) => "Connect",
405 ClientRequest::Unload { .. } => "Unload",
406 ClientRequest::Modify(..) => "Modify",
407 ClientRequest::HvsockConnect(..) => "HvsockConnect",
408 };
409 fmt.pad(s)
410 }
411}
412
/// Control requests sent from [`VmbusClient`] to the client task.
enum TaskRequest {
    /// Inspect the task state.
    Inspect(inspect::Deferred),
    /// Produce the client's saved state.
    Save(Rpc<(), SavedState>),
    /// Restore previously saved state.
    Restore(Rpc<SavedState, Result<Option<ConnectResult>, RestoreError>>),
    /// Sent by [`VmbusClient::post_restore`] after a restore.
    PostRestore(Rpc<(), ()>),
    /// Start the task; sent without waiting for a reply.
    Start,
    /// Stop the task.
    Stop(Rpc<(), ()>),
}
421
/// Connection state of the client with respect to the VMBus server.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ClientState {
    /// Not connected; the only state in which a connect may begin.
    Disconnected,
    /// InitiateContact sent; waiting for the VersionResponse.
    Connecting {
        /// The protocol version currently being requested.
        version: Version,
        /// The caller's pending connect request.
        #[inspect(skip)]
        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
    },
    /// Connected with all initial offers delivered; later offers are sent
    /// through `offer_send`.
    Connected {
        /// Negotiated version and feature flags.
        version: VersionInfo,
        #[inspect(skip)]
        offer_send: mesh::Sender<OfferInfo>,
    },
    /// Version accepted and RequestOffers sent; offers are collected until
    /// AllOffersDelivered arrives.
    RequestingOffers {
        /// Negotiated version and feature flags.
        version: VersionInfo,
        /// The caller's pending connect request.
        #[inspect(skip)]
        rpc: Rpc<(), Result<ConnectResult, ConnectError>>,
        /// Offers received so far.
        #[inspect(skip)]
        offers: Vec<OfferInfo>,
    },
    /// Unload sent; waiting for UnloadComplete.
    Disconnecting {
        /// Version that was in effect when unloading began.
        version: VersionInfo,
        /// The caller's pending unload request.
        #[inspect(skip)]
        rpc: Rpc<(), ()>,
    },
}
457
458impl ClientState {
459 fn get_version(&self) -> Option<VersionInfo> {
460 match self {
461 ClientState::Connected { version, .. } => Some(*version),
462 ClientState::RequestingOffers { version, .. } => Some(*version),
463 ClientState::Disconnecting { version, .. } => Some(*version),
464 ClientState::Disconnected | ClientState::Connecting { .. } => None,
465 }
466 }
467}
468
469impl std::fmt::Display for ClientState {
470 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471 let s = match self {
472 ClientState::Disconnected => "Disconnected",
473 ClientState::Connecting { .. } => "Connecting",
474 ClientState::Connected { .. } => "Connected",
475 ClientState::RequestingOffers { .. } => "RequestingOffers",
476 ClientState::Disconnecting { .. } => "Disconnecting",
477 };
478 fmt.pad(s)
479 }
480}
481
/// Parameters of a connect request, used to build InitiateContact.
#[derive(Copy, Clone, Debug, Default)]
struct ConnectRequest {
    /// VP that the server should target with messages.
    target_message_vp: u32,
    /// Monitor page GPAs; `None` sends `MonitorPageGpas::default()`.
    monitor_page: Option<MonitorPageGpas>,
    /// Client ID; sent only with `InitiateContact2` (`Copper` and newer).
    client_id: Guid,
}
488
/// Parameters for modifying an established connection.
#[derive(Copy, Clone, Debug, Default)]
pub struct ModifyConnectionRequest {
    /// New monitor page GPAs; `None` sends `MonitorPageGpas::default()`.
    pub monitor_page: Option<MonitorPageGpas>,
}
493
494impl From<ModifyConnectionRequest> for protocol::ModifyConnection {
495 fn from(value: ModifyConnectionRequest) -> Self {
496 let monitor_page = value.monitor_page.unwrap_or_default();
497
498 Self {
499 parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
500 child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
501 }
502 }
503}
504
/// Client-side state of a single offered channel.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum ChannelState {
    /// Offered by the server and not open (also the state after a failed
    /// open).
    Offered,
    /// OpenChannel sent; waiting for the server's OpenResult.
    Opening {
        /// Event flag allocated for redirected interrupts, if any.
        redirected_event_flag: Option<u16>,
        /// Event mapped to `redirected_event_flag`, if any.
        #[inspect(skip)]
        redirected_event: Option<Event>,
        /// The consumer's pending open request.
        #[inspect(skip)]
        rpc: FailableRpc<(), OpenOutput>,
    },
    /// Restored from saved state; waiting for a `ChannelRequest::Restore`.
    Restored,
    /// Open, either by a successful OpenResult or by a completed restore.
    Opened {
        /// Event flag used for redirected interrupts, if any.
        redirected_event_flag: Option<u16>,
        /// Event mapped to `redirected_event_flag`, if any.
        #[inspect(skip)]
        redirected_event: Option<Event>,
    },
    /// The server rescinded the offer.
    Revoked,
}
532
533impl std::fmt::Display for ChannelState {
534 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535 let s = match self {
536 ChannelState::Opening { .. } => "Opening",
537 ChannelState::Offered => "Offered",
538 ChannelState::Opened { .. } => "Opened",
539 ChannelState::Restored => "Restored",
540 ChannelState::Revoked => "Revoked",
541 };
542 fmt.pad(s)
543 }
544}
545
/// Per-channel bookkeeping kept by the client task.
#[derive(Debug, Inspect)]
struct Channel {
    /// The server's offer for this channel.
    offer: protocol::OfferChannel,
    /// Taken and fired when the server rescinds the offer.
    #[inspect(skip)]
    revoke_send: Option<mesh::OneshotSender<()>>,
    /// Current open/revoke state.
    state: ChannelState,
    /// Pending channel modify request, completed by ModifyChannelResponse.
    #[inspect(with = "|x| x.is_some()")]
    modify_response_send: Option<Rpc<(), i32>>,
    /// GPADLs of this channel keyed by ID.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    gpadls: HashMap<GpadlId, GpadlState>,
    /// NOTE(review): presumably whether the consumer has released the channel
    /// (see `try_release`) — confirm.
    is_client_released: bool,
    /// Connection ID for guest-to-host signaling; stored on open/restore and
    /// shared with the interrupt handed out in `OfferInfo`.
    connection_id: Arc<AtomicU32>,
}
560
561impl Channel {
562 fn pending_request(&self) -> Option<&'static str> {
563 if self.modify_response_send.is_some() {
564 return Some("modify");
565 }
566 self.gpadls.iter().find_map(|(_, gpadl)| match gpadl {
567 GpadlState::Offered(_) => Some("creating gpadl"),
568 GpadlState::Created => None,
569 GpadlState::TearingDown { .. } => Some("tearing down gpadl"),
570 })
571 }
572}
573
/// State owned by the spawned VMBus client task.
#[derive(Inspect)]
struct ClientTask {
    /// Outgoing messages, synic state, GPADL teardown list, and per-channel
    /// request streams.
    #[inspect(flatten)]
    inner: ClientTaskInner,
    /// Known channels, by channel ID.
    channels: ChannelList,
    /// Connection state machine.
    state: ClientState,
    /// Outstanding hvsocket connect requests.
    hvsock_tracker: hvsock::HvsockRequestTracker,
    /// Whether the task is running; `false` until started.
    running: bool,
    /// Outstanding ModifyConnection request, completed by the server's
    /// ModifyConnectionResponse.
    #[inspect(with = "|x| x.is_some()")]
    modify_request: Option<Rpc<ModifyConnectionRequest, ConnectionState>>,
    /// Incoming synic message source.
    #[inspect(skip)]
    msg_source: Box<dyn VmbusMessageSource>,
    /// Control requests from [`VmbusClient`].
    #[inspect(skip)]
    task_recv: mesh::Receiver<TaskRequest>,
    /// Client requests from [`VmbusClientAccess`].
    #[inspect(skip)]
    client_request_recv: mesh::Receiver<ClientRequest>,
}
591
592impl ClientTask {
593 fn handle_initiate_contact(
594 &mut self,
595 rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
596 version: Version,
597 ) {
598 let ClientState::Disconnected = self.state else {
599 tracing::warn!(client_state = %self.state, "invalid client state for InitiateContact");
600 rpc.complete(Err(ConnectError::InvalidState));
601 return;
602 };
603 let feature_flags = if version >= Version::Copper {
604 SUPPORTED_FEATURE_FLAGS
605 } else {
606 FeatureFlags::new()
607 };
608
609 let request = rpc.input();
610
611 tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting");
612 let target_info = protocol::TargetInfo::new()
613 .with_sint(SINT)
614 .with_vtl(VTL)
615 .with_feature_flags(feature_flags.into());
616 let monitor_page = request.monitor_page.unwrap_or_default();
617 let msg = protocol::InitiateContact2 {
618 initiate_contact: protocol::InitiateContact {
619 version_requested: version as u32,
620 target_message_vp: request.target_message_vp,
621 interrupt_page_or_target_info: target_info.into(),
622 parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
623 child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
624 },
625 client_id: request.client_id,
626 };
627
628 self.state = ClientState::Connecting { version, rpc };
629 if version < Version::Copper {
630 self.inner.messages.send(&msg.initiate_contact)
631 } else {
632 self.inner.messages.send(&msg);
633 }
634 }
635
636 fn handle_unload(&mut self, rpc: Rpc<(), ()>) {
637 tracing::debug!(%self.state, "VmBus client disconnecting");
638 self.state = ClientState::Disconnecting {
639 version: self.state.get_version().expect("invalid state for unload"),
640 rpc,
641 };
642
643 self.inner.messages.send(&protocol::Unload {});
644 }
645
646 fn handle_modify(&mut self, request: Rpc<ModifyConnectionRequest, ConnectionState>) {
647 if !matches!(self.state, ClientState::Connected { version, .. }
648 if version.feature_flags.modify_connection())
649 {
650 tracing::warn!("ModifyConnection not supported");
651 request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
652 return;
653 }
654
655 if self.modify_request.is_some() {
656 tracing::warn!("Duplicate ModifyConnection request");
657 request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
658 return;
659 }
660
661 let message = protocol::ModifyConnection::from(*request.input());
662 self.modify_request = Some(request);
663 self.inner.messages.send(&message);
664 }
665
666 fn handle_tl_connect(&mut self, rpc: Rpc<HvsockConnectRequest, Option<OfferInfo>>) {
667 let message = protocol::TlConnectRequest2::from(*rpc.input());
671 self.hvsock_tracker.add_request(rpc);
672 self.inner.messages.send(&message);
673 }
674
675 fn handle_client_request(&mut self, request: ClientRequest) {
676 match request {
677 ClientRequest::Connect(rpc) => {
678 self.handle_initiate_contact(rpc, *SUPPORTED_VERSIONS.last().unwrap());
679 }
680 ClientRequest::Unload(rpc) => {
681 self.handle_unload(rpc);
682 }
683 ClientRequest::Modify(request) => self.handle_modify(request),
684 ClientRequest::HvsockConnect(request) => self.handle_tl_connect(request),
685 }
686 }
687
688 fn handle_version_response(&mut self, msg: protocol::VersionResponse2) {
689 let old_state = std::mem::replace(&mut self.state, ClientState::Disconnected);
690 let ClientState::Connecting { version, rpc } = old_state else {
691 self.state = old_state;
692 tracing::warn!(
693 client_state = %self.state,
694 "invalid client state to handle VersionResponse"
695 );
696 return;
697 };
698 if msg.version_response.version_supported > 0 {
699 if msg.version_response.connection_state != ConnectionState::SUCCESSFUL {
700 rpc.complete(Err(ConnectError::FailedToConnect(
701 msg.version_response.connection_state,
702 )));
703 return;
704 }
705
706 let feature_flags = if version >= Version::Copper {
707 FeatureFlags::from(msg.supported_features)
708 } else {
709 FeatureFlags::new()
710 };
711
712 let version = VersionInfo {
713 version,
714 feature_flags,
715 };
716
717 self.inner.messages.send(&protocol::RequestOffers {});
718 self.state = ClientState::RequestingOffers {
719 version,
720 rpc: rpc.split().1,
721 offers: Vec::new(),
722 };
723 tracing::info!(?version, "VmBus client connected, requesting offers");
724 } else {
725 let index = SUPPORTED_VERSIONS
726 .iter()
727 .position(|v| *v == version)
728 .unwrap();
729
730 if index == 0 {
731 rpc.complete(Err(ConnectError::NoSupportedVersions));
732 return;
733 }
734 let next_version = SUPPORTED_VERSIONS[index - 1];
735 tracing::debug!(
736 version = version as u32,
737 next_version = next_version as u32,
738 "Unsupported version, retrying"
739 );
740 self.handle_initiate_contact(rpc, next_version);
741 }
742 }
743
744 fn create_channel(&mut self, offer: protocol::OfferChannel) -> Result<OfferInfo> {
745 self.create_channel_core(offer, ChannelState::Offered)
746 }
747
748 fn create_channel_core(
749 &mut self,
750 offer: protocol::OfferChannel,
751 state: ChannelState,
752 ) -> Result<OfferInfo> {
753 if self.channels.0.contains_key(&offer.channel_id) {
754 anyhow::bail!("channel {:?} exists", offer.channel_id);
755 }
756 let (request_send, request_recv) = mesh::channel();
757 let (revoke_send, revoke_recv) = mesh::oneshot();
758
759 let connection_id = Arc::new(AtomicU32::new(0));
760 self.channels.0.insert(
761 offer.channel_id,
762 Channel {
763 revoke_send: Some(revoke_send),
764 offer,
765 state,
766 modify_response_send: None,
767 gpadls: HashMap::new(),
768 is_client_released: false,
769 connection_id: connection_id.clone(),
770 },
771 );
772
773 self.inner
774 .channel_requests
775 .push(TaggedStream::new(offer.channel_id, request_recv));
776
777 Ok(OfferInfo {
778 offer,
779 guest_to_host_interrupt: self.inner.synic.guest_to_host_interrupt(connection_id),
780 revoke_recv,
781 request_send,
782 })
783 }
784
785 fn handle_offer(&mut self, offer: protocol::OfferChannel) {
786 let offer_info = self
787 .create_channel(offer)
788 .expect("channel should not exist");
789
790 tracing::info!(
791 state = %self.state,
792 channel_id = offer.channel_id.0,
793 interface_id = %offer.interface_id,
794 instance_id = %offer.instance_id,
795 subchannel_index = offer.subchannel_index,
796 "received offer");
797
798 if let Some(offer) = self.hvsock_tracker.check_offer(&offer_info.offer) {
799 offer.complete(Some(offer_info));
800 } else {
801 match &mut self.state {
802 ClientState::Connected { offer_send, .. } => {
803 offer_send.send(offer_info);
804 }
805 ClientState::RequestingOffers { offers, .. } => {
806 offers.push(offer_info);
807 }
808 state => unreachable!("invalid client state for OfferChannel: {state}"),
809 }
810 }
811 }
812
813 fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) -> TriedRelease {
814 let mut channel = self.channels.get_mut(rescind.channel_id);
815 tracing::info!(
816 state = %self.state,
817 channel_id = rescind.channel_id.0,
818 key = %OfferKey::from(&channel.offer),
819 "received rescind"
820 );
821 let event_flag = match std::mem::replace(&mut channel.state, ChannelState::Revoked) {
822 ChannelState::Offered => None,
823 ChannelState::Opening {
824 redirected_event_flag,
825 redirected_event: _,
826 rpc,
827 } => {
828 rpc.fail(anyhow::anyhow!("channel revoked"));
829 redirected_event_flag
830 }
831 ChannelState::Restored => None,
832 ChannelState::Opened {
833 redirected_event_flag,
834 redirected_event: _,
835 } => redirected_event_flag,
836 ChannelState::Revoked => {
837 panic!("channel id {:?} already revoked", rescind.channel_id);
838 }
839 };
840 if let Some(event_flag) = event_flag {
841 self.inner.synic.free_event_flag(event_flag);
842 }
843
844 channel.revoke_send.take().unwrap().send(());
846
847 channel.try_release(&mut self.inner.messages)
848 }
849
850 fn handle_offers_delivered(&mut self) {
851 match std::mem::replace(&mut self.state, ClientState::Disconnected) {
852 ClientState::RequestingOffers {
853 version,
854 rpc,
855 offers,
856 } => {
857 tracing::info!(version = ?version, "VmBus client connected, offers delivered");
858 let (offer_send, offer_recv) = mesh::channel();
859 self.state = ClientState::Connected {
860 version,
861 offer_send,
862 };
863 rpc.complete(Ok(ConnectResult {
864 version,
865 offers,
866 offer_recv,
867 }));
868 }
869 state => {
870 tracing::warn!(client_state = %state, "invalid client state for OffersDelivered");
871 self.state = state;
872 }
873 }
874 }
875
876 fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) -> TriedRelease {
877 let mut channel = self.channels.get_mut(request.channel_id);
878 let Some(gpadl_state) = channel.gpadls.get_mut(&request.gpadl_id) else {
879 panic!("GpadlCreated for unknown gpadl {:#x}", request.gpadl_id.0);
880 };
881
882 let rpc = match std::mem::replace(gpadl_state, GpadlState::Created) {
883 GpadlState::Offered(rpc) => rpc,
884 old_state => {
885 panic!(
886 "invalid state {old_state:?} for gpadl {:#x}:{:#x}",
887 request.channel_id.0, request.gpadl_id.0
888 );
889 }
890 };
891
892 let gpadl_created = request.status == protocol::STATUS_SUCCESS;
893 if gpadl_created {
894 rpc.complete(Ok(()));
895 } else {
896 channel.gpadls.remove(&request.gpadl_id).unwrap();
897 rpc.fail(anyhow::anyhow!(
898 "gpadl creation failed: {:#x}",
899 request.status
900 ));
901 };
902 channel.try_release(&mut self.inner.messages)
903 }
904
905 fn handle_open_result(&mut self, result: protocol::OpenResult) {
906 let mut channel = self.channels.get_mut(result.channel_id);
907 tracing::debug!(
908 channel_id = result.channel_id.0,
909 key = %OfferKey::from(&channel.offer),
910 result = result.status,
911 "received open result"
912 );
913
914 let channel_opened = result.status == protocol::STATUS_SUCCESS as u32;
915 let old_state = std::mem::replace(&mut channel.state, ChannelState::Offered);
916 let ChannelState::Opening {
917 redirected_event_flag,
918 redirected_event,
919 rpc,
920 } = old_state
921 else {
922 tracing::warn!(
923 key = %OfferKey::from(&channel.offer),
924 old_state = ?channel.state,
925 channel_opened,
926 "invalid state for open result"
927 );
928 channel.state = old_state;
929 return;
930 };
931
932 if !channel_opened {
933 if let Some(event_flag) = redirected_event_flag {
934 self.inner.synic.free_event_flag(event_flag);
935 }
936 rpc.fail(anyhow::anyhow!("open failed: {:#x}", result.status));
937 return;
938 }
939
940 channel.state = ChannelState::Opened {
941 redirected_event_flag,
942 redirected_event,
943 };
944
945 rpc.complete(Ok(OpenOutput {
946 redirected_event_flag,
947 }));
948 }
949
950 fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) -> TriedRelease {
951 let Some(channel_id) = self.inner.teardown_gpadls.remove(&request.gpadl_id) else {
952 panic!("gpadl {:#x} not in teardown list", request.gpadl_id.0);
953 };
954
955 let mut channel = self.channels.get_mut(channel_id);
956 tracing::debug!(
957 gpadl_id = request.gpadl_id.0,
958 channel_id = channel_id.0,
959 key = %OfferKey::from(&channel.offer),
960 "Received GpadlTorndown"
961 );
962
963 let gpadl_state = channel
964 .gpadls
965 .remove(&request.gpadl_id)
966 .expect("gpadl validated above");
967
968 let GpadlState::TearingDown { rpcs } = gpadl_state else {
969 panic!("gpadl should be tearing down if in teardown list, state = {gpadl_state:?}");
970 };
971
972 for rpc in rpcs {
973 rpc.complete(());
974 }
975 channel.try_release(&mut self.inner.messages)
976 }
977
978 fn handle_unload_complete(&mut self) {
979 match std::mem::replace(&mut self.state, ClientState::Disconnected) {
980 ClientState::Disconnecting { version: _, rpc } => {
981 tracing::info!("VmBus client disconnected");
982 rpc.complete(());
983 }
984 state => {
985 tracing::warn!(client_state = %state, "invalid client state for UnloadComplete");
986 }
987 }
988 }
989
990 fn handle_modify_complete(&mut self, response: protocol::ModifyConnectionResponse) {
991 if let Some(request) = self.modify_request.take() {
992 request.complete(response.connection_state)
993 } else {
994 tracing::warn!("Unexpected modify complete request");
995 }
996 }
997
998 fn handle_modify_channel_response(
999 &mut self,
1000 response: protocol::ModifyChannelResponse,
1001 ) -> TriedRelease {
1002 let mut channel = self.channels.get_mut(response.channel_id);
1003 let Some(sender) = channel.modify_response_send.take() else {
1004 panic!(
1005 "unexpected modify channel response for channel {:#x}",
1006 response.channel_id.0
1007 );
1008 };
1009
1010 sender.complete(response.status);
1011 channel.try_release(&mut self.inner.messages)
1012 }
1013
1014 fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) {
1015 if let Some(rpc) = self.hvsock_tracker.check_result(&response) {
1016 rpc.complete(None);
1017 }
1018 }
1019
/// Parses a raw synic message from the server and dispatches it to the
/// matching handler.
///
/// Returns `false` after a `PauseResponse` and `true` for every other
/// message.
///
/// # Panics
///
/// Panics if `data` cannot be parsed for the current negotiated version, on
/// `CloseReservedChannelResponse` (not yet supported), or if the message is
/// one that only a server should receive.
fn handle_synic_message(&mut self, data: &[u8]) -> bool {
    let msg = Message::parse(data, self.state.get_version()).unwrap();
    tracing::trace!(?msg, "received client message from synic");

    match msg {
        Message::VersionResponse3(version_response, ..) => {
            // Only the embedded VersionResponse2 is passed on.
            self.handle_version_response(version_response.version_response2);
        }
        Message::VersionResponse2(version_response, ..) => {
            self.handle_version_response(version_response);
        }
        Message::VersionResponse(version_response, ..) => {
            self.handle_version_response(version_response.into());
        }
        Message::OfferChannel(offer, ..) => {
            self.handle_offer(offer);
        }
        Message::AllOffersDelivered(..) => {
            self.handle_offers_delivered();
        }
        Message::UnloadComplete(..) => {
            self.handle_unload_complete();
        }
        Message::ModifyConnectionResponse(response, ..) => {
            self.handle_modify_complete(response);
        }
        // The `TriedRelease` results of the channel handlers below are
        // discarded here.
        Message::GpadlCreated(gpadl, ..) => {
            self.handle_gpadl_created(gpadl);
        }
        Message::OpenResult(result, ..) => {
            self.handle_open_result(result);
        }
        Message::GpadlTorndown(gpadl, ..) => {
            self.handle_gpadl_torndown(gpadl);
        }
        Message::RescindChannelOffer(rescind, ..) => {
            self.handle_rescind(rescind);
        }
        Message::ModifyChannelResponse(response, ..) => {
            self.handle_modify_channel_response(response);
        }
        Message::TlConnectResult(response, ..) => self.handle_tl_connect_result(response),
        Message::CloseReservedChannelResponse(..) => {
            todo!("Unsupported message {msg:?}")
        }
        Message::PauseResponse(..) => {
            return false;
        }
        // Messages that only flow from client to server.
        Message::RequestOffers(..)
        | Message::OpenChannel2(..)
        | Message::OpenChannel(..)
        | Message::CloseChannel(..)
        | Message::GpadlHeader(..)
        | Message::GpadlBody(..)
        | Message::GpadlTeardown(..)
        | Message::RelIdReleased(..)
        | Message::InitiateContact(..)
        | Message::InitiateContact2(..)
        | Message::Unload(..)
        | Message::OpenReservedChannel(..)
        | Message::CloseReservedChannel(..)
        | Message::TlConnectRequest2(..)
        | Message::TlConnectRequest(..)
        | Message::ModifyChannel(..)
        | Message::ModifyConnection(..)
        | Message::Pause(..)
        | Message::Resume(..) => {
            unreachable!("Client received server message {msg:?}");
        }
    }
    true
}
1099
1100 fn handle_open_channel(
1101 &mut self,
1102 channel_id: ChannelId,
1103 rpc: FailableRpc<OpenRequest, OpenOutput>,
1104 ) {
1105 let mut channel = self.channels.get_mut(channel_id);
1106 match &channel.state {
1107 ChannelState::Offered => {}
1108 ChannelState::Revoked => {
1109 rpc.fail(anyhow::anyhow!("channel revoked"));
1110 return;
1111 }
1112 state => {
1113 rpc.fail(anyhow::anyhow!("invalid channel state: {}", state));
1114 return;
1115 }
1116 }
1117
1118 tracing::info!(
1119 channel_id = channel_id.0,
1120 key = %OfferKey::from(&channel.offer),
1121 "opening channel on host"
1122 );
1123
1124 let (request, rpc) = rpc.split();
1125 let open_data = &request.open_data;
1126
1127 let supports_interrupt_redirection =
1128 if let ClientState::Connected { version, .. } = self.state {
1129 version.feature_flags.guest_specified_signal_parameters()
1130 || version.feature_flags.channel_interrupt_redirection()
1131 } else {
1132 false
1133 };
1134
1135 if !supports_interrupt_redirection && open_data.event_flag != channel_id.0 as u16 {
1136 rpc.fail(anyhow::anyhow!(
1137 "host does not support specifying the event flag"
1138 ));
1139 return;
1140 }
1141
1142 let open_channel = protocol::OpenChannel {
1143 channel_id,
1144 open_id: 0,
1145 ring_buffer_gpadl_id: open_data.ring_gpadl_id,
1146 target_vp: open_data
1147 .target_vp
1148 .unwrap_or(protocol::VP_INDEX_DISABLE_INTERRUPT),
1149 downstream_ring_buffer_page_offset: open_data.ring_offset,
1150 user_data: open_data.user_data,
1151 };
1152
1153 let connection_id = if request.use_vtl2_connection_id {
1154 if !supports_interrupt_redirection {
1155 rpc.fail(anyhow::anyhow!(
1156 "host does not support specfiying the connection ID"
1157 ));
1158 return;
1159 }
1160 protocol::ConnectionId::new(channel_id.0, 2.try_into().unwrap(), 7).0
1161 } else {
1162 open_data.connection_id
1163 };
1164
1165 let mut flags = OpenChannelFlags::new();
1168 let event_flag = if let Some(event) = &request.incoming_event {
1169 if !supports_interrupt_redirection {
1170 rpc.fail(anyhow::anyhow!(
1171 "host does not support redirecting interrupts"
1172 ));
1173 return;
1174 }
1175
1176 flags.set_redirect_interrupt(true);
1177 match self.inner.synic.allocate_event_flag(event) {
1178 Ok(flag) => flag,
1179 Err(err) => {
1180 rpc.fail(err.context("failed to allocate event flag"));
1181 return;
1182 }
1183 }
1184 } else {
1185 open_data.event_flag
1186 };
1187
1188 if supports_interrupt_redirection {
1189 self.inner.messages.send(&protocol::OpenChannel2 {
1190 open_channel,
1191 connection_id,
1192 event_flag,
1193 flags,
1194 });
1195 } else {
1196 self.inner.messages.send(&open_channel);
1197 }
1198
1199 channel
1200 .connection_id
1201 .store(connection_id, Ordering::Release);
1202 channel.state = ChannelState::Opening {
1203 redirected_event_flag: (request.incoming_event.is_some()).then_some(event_flag),
1204 redirected_event: request.incoming_event,
1205 rpc,
1206 }
1207 }
1208
1209 fn handle_restore_channel(
1210 &mut self,
1211 channel_id: ChannelId,
1212 request: RestoreRequest,
1213 ) -> Result<OpenOutput> {
1214 let mut channel = self.channels.get_mut(channel_id);
1215 if !matches!(channel.state, ChannelState::Restored) {
1216 anyhow::bail!("invalid channel state: {}", channel.state);
1217 }
1218
1219 if request.incoming_event.is_some() != request.redirected_event_flag.is_some() {
1220 anyhow::bail!("incoming event and redirected event flag must both be set or unset");
1221 }
1222
1223 if let Some((flag, event)) = request
1224 .redirected_event_flag
1225 .zip(request.incoming_event.as_ref())
1226 {
1227 self.inner.synic.restore_event_flag(flag, event)?;
1228 }
1229
1230 channel
1231 .connection_id
1232 .store(request.connection_id, Ordering::Release);
1233 channel.state = ChannelState::Opened {
1234 redirected_event_flag: request.redirected_event_flag,
1235 redirected_event: request.incoming_event,
1236 };
1237 Ok(OpenOutput {
1238 redirected_event_flag: request.redirected_event_flag,
1239 })
1240 }
1241
1242 fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: FailableRpc<GpadlRequest, ()>) {
1243 let (request, rpc) = rpc.split();
1244 let mut channel = self.channels.get_mut(channel_id);
1245 if channel
1246 .gpadls
1247 .insert(request.id, GpadlState::Offered(rpc))
1248 .is_some()
1249 {
1250 panic!(
1251 "duplicate gpadl ID {:?} for channel {:?}.",
1252 request.id, channel_id
1253 );
1254 }
1255
1256 tracing::trace!(
1257 channel_id = channel_id.0,
1258 key = %OfferKey::from(&channel.offer),
1259 gpadl_id = request.id.0,
1260 count = request.count,
1261 len = request.buf.len(),
1262 "received gpadl request"
1263 );
1264
1265 let (first, remaining) = if request.buf.len() > protocol::GpadlHeader::MAX_DATA_VALUES {
1267 request.buf.split_at(protocol::GpadlHeader::MAX_DATA_VALUES)
1268 } else {
1269 (request.buf.as_slice(), [].as_slice())
1270 };
1271
1272 let message = protocol::GpadlHeader {
1273 channel_id,
1274 gpadl_id: request.id,
1275 len: (request.buf.len() * size_of::<u64>())
1276 .try_into()
1277 .expect("Too many GPA values"),
1278 count: request.count,
1279 };
1280
1281 self.inner
1282 .messages
1283 .send_with_data(&message, first.as_bytes());
1284
1285 let message = protocol::GpadlBody {
1287 rsvd: 0,
1288 gpadl_id: request.id,
1289 };
1290 for chunk in remaining.chunks(protocol::GpadlBody::MAX_DATA_VALUES) {
1291 self.inner
1292 .messages
1293 .send_with_data(&message, chunk.as_bytes());
1294 }
1295 }
1296
    /// Handles a guest request to tear down GPADL `gpadl_id` on `channel_id`.
    ///
    /// Behavior depends on the GPADL's current state:
    /// - Unknown: logs a warning. `rpc` is dropped without a response.
    /// - `Offered` (presumably still awaiting the host's creation response):
    ///   logs a warning. `rpc` is dropped without a response.
    /// - `Created`: moves to `TearingDown` holding `rpc`, records the GPADL in
    ///   `teardown_gpadls`, and sends `GpadlTeardown` to the host.
    /// - `TearingDown`: a teardown is already in flight, so `rpc` is queued
    ///   alongside it.
    fn handle_gpadl_teardown(&mut self, channel_id: ChannelId, rpc: Rpc<GpadlId, ()>) {
        let (gpadl_id, rpc) = rpc.split();
        let mut channel = self.channels.get_mut(channel_id);
        let Some(gpadl_state) = channel.gpadls.get_mut(&gpadl_id) else {
            tracing::warn!(
                gpadl_id = gpadl_id.0,
                channel_id = channel_id.0,
                key = %OfferKey::from(&channel.offer),
                "Gpadl teardown for unknown gpadl or revoked channel"
            );
            return;
        };

        match gpadl_state {
            GpadlState::Offered(_) => {
                tracing::warn!(
                    gpadl_id = gpadl_id.0,
                    channel_id = channel_id.0,
                    key = %OfferKey::from(&channel.offer),
                    "gpadl teardown for offered gpadl"
                );
            }
            GpadlState::Created => {
                *gpadl_state = GpadlState::TearingDown { rpcs: vec![rpc] };
                // A `Created` GPADL cannot already be in `teardown_gpadls`.
                // The insert happens inside the assert for its side effect.
                assert!(
                    self.inner
                        .teardown_gpadls
                        .insert(gpadl_id, channel_id)
                        .is_none(),
                    "Gpadl state validated above"
                );

                self.inner.messages.send(&protocol::GpadlTeardown {
                    channel_id,
                    gpadl_id,
                });
            }
            GpadlState::TearingDown { rpcs } => {
                rpcs.push(rpc);
            }
        }
    }
1342
1343 fn handle_close_channel(&mut self, channel_id: ChannelId) {
1344 let mut channel = self.channels.get_mut(channel_id);
1345 self.inner.close_channel(channel_id, &mut channel);
1346 }
1347
1348 fn handle_modify_channel(&mut self, channel_id: ChannelId, rpc: Rpc<ModifyRequest, i32>) {
1349 assert!(self.check_version(Version::Iron));
1353 let mut channel = self.channels.get_mut(channel_id);
1354 if channel.modify_response_send.is_some() {
1355 panic!("duplicate channel modify request {channel_id:?}");
1356 }
1357
1358 let (request, response) = rpc.split();
1359 channel.modify_response_send = Some(response);
1360 let payload = match request {
1361 ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel {
1362 channel_id,
1363 target_vp,
1364 },
1365 };
1366
1367 self.inner.messages.send(&payload);
1368 }
1369
1370 fn handle_channel_request(&mut self, channel_id: ChannelId, request: ChannelRequest) {
1371 match request {
1372 ChannelRequest::Open(rpc) => self.handle_open_channel(channel_id, rpc),
1373 ChannelRequest::Restore(rpc) => {
1374 rpc.handle_failable_sync(|request| self.handle_restore_channel(channel_id, request))
1375 }
1376 ChannelRequest::Gpadl(req) => self.handle_gpadl(channel_id, req),
1377 ChannelRequest::TeardownGpadl(req) => self.handle_gpadl_teardown(channel_id, req),
1378 ChannelRequest::Close(req) => {
1379 req.handle_sync(|()| self.handle_close_channel(channel_id))
1380 }
1381 ChannelRequest::Modify(req) => self.handle_modify_channel(channel_id, req),
1382 }
1383 }
1384
1385 async fn handle_task(&mut self, task: TaskRequest) {
1386 match task {
1387 TaskRequest::Inspect(deferred) => {
1388 deferred.inspect(&*self);
1389 }
1390 TaskRequest::Save(rpc) => rpc.handle_sync(|()| self.handle_save()),
1391 TaskRequest::Restore(rpc) => {
1392 rpc.handle_sync(|saved_state| self.handle_restore(saved_state))
1393 }
1394 TaskRequest::PostRestore(rpc) => rpc.handle_sync(|()| self.handle_post_restore()),
1395 TaskRequest::Start => self.handle_start(),
1396 TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
1397 }
1398 }
1399
1400 fn handle_device_removal(&mut self, channel_id: ChannelId) -> TriedRelease {
1402 let mut channel = self.channels.get_mut(channel_id);
1403 channel.is_client_released = true;
1404 if let ChannelState::Opened { .. } = channel.state {
1406 tracing::warn!(
1407 channel_id = channel_id.0,
1408 key = %OfferKey::from(&channel.offer),
1409 "Channel dropped without closing first"
1410 );
1411 self.inner.close_channel(channel_id, &mut channel);
1412 }
1413 channel.try_release(&mut self.inner.messages)
1414 }
1415
1416 fn check_version(&self, version: Version) -> bool {
1418 matches!(self.state, ClientState::Connected { version: v, .. } if v.version >= version)
1419 }
1420
    /// Starts (or restarts) the client task.
    ///
    /// Resumes the incoming message stream and outgoing message delivery, then
    /// marks the task as running so `run` polls host and client sources again.
    ///
    /// Panics if already running. `OutgoingMessages::resume` also asserts that
    /// outgoing messages are currently `Paused`.
    fn handle_start(&mut self) {
        assert!(!self.running);
        self.msg_source.resume_message_stream();
        self.inner.messages.resume();
        self.running = true;
    }
1427
    /// Stops the client task after draining in-flight host traffic.
    ///
    /// Each pass of the loop does the following:
    /// 1. Processes host messages one at a time until no revoked channel has a
    ///    pending request. The message stream must not end while waiting.
    /// 2. Pauses outgoing traffic. If the host supports pause/resume, this
    ///    schedules a `Pause` message (`OutgoingMessages::pause`). Otherwise
    ///    the incoming stream is paused locally and outgoing messages are
    ///    paused without notifying the host.
    /// 3. Processes messages until `process_next_message` returns `false`.
    /// 4. If a revoked channel has gained a pending request meanwhile, resumes
    ///    both directions and repeats. Otherwise leaves the loop.
    ///
    /// Finally clears `running`. Panics if the task is not running.
    async fn handle_stop(&mut self) {
        assert!(self.running);

        loop {
            while let Some((id, request)) = self.channels.revoked_channel_with_pending_request() {
                tracelimit::info_ratelimited!(
                    channel_id = id.0,
                    request,
                    "waiting for responses for channel"
                );
                assert!(self.process_next_message().await);
            }

            if self.can_pause_resume() {
                self.inner.messages.pause();
            } else {
                // No host-side pause: stop the stream locally instead.
                self.msg_source.pause_message_stream();
                self.inner.messages.force_pause();
            }

            // Drain until the source reports the end of the stream.
            while self.process_next_message().await {}

            // A response observed while draining may have left a revoked
            // channel with a new pending request. If so, go around again.
            if self
                .channels
                .revoked_channel_with_pending_request()
                .is_none()
            {
                break;
            }
            if !self.can_pause_resume() {
                self.msg_source.resume_message_stream();
            }
            self.inner.messages.resume();
        }

        tracing::debug!("messages drained");
        self.running = false;
    }
1477
    /// Receives and handles one host message while flushing queued outgoing
    /// messages concurrently.
    ///
    /// Returns `false` when the message source yields a zero-length read (end
    /// of stream). Otherwise returns the result of `handle_synic_message`.
    ///
    /// Panics if reading from the message source fails.
    async fn process_next_message(&mut self) -> bool {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        let recv = self.msg_source.recv(&mut buf);
        // This future never completes, even after the queue is flushed, so the
        // race below always resolves with the receive result.
        let flush = async {
            self.inner.messages.flush_messages().await;
            std::future::pending().await
        };
        let size = (recv, flush)
            .race()
            .await
            .expect("Fatal error reading messages from synic");
        if size == 0 {
            return false;
        }
        self.handle_synic_message(&buf[..size])
    }
1496
1497 fn can_pause_resume(&self) -> bool {
1506 if let ClientState::Connected { version, .. } = self.state {
1507 version.feature_flags.pause_resume()
1508 } else {
1509 false
1510 }
1511 }
1512
    /// Main loop of the client task.
    ///
    /// Each iteration builds optional futures and waits on whichever completes
    /// first.
    ///
    /// Task-control requests (`task_recv`) are always polled. The remaining
    /// sources depend on state:
    /// - While running, host messages are received.
    /// - While running with a non-empty outgoing queue, the queue is flushed.
    ///   In that case no new client or channel requests are accepted, which
    ///   provides back-pressure while the host is backed up.
    ///
    /// An `OptionFuture` built from `None` reports itself as terminated, so
    /// `select!` skips it.
    ///
    /// The loop exits in three cases: `task_recv` ends, the client request
    /// stream ends, or every future is terminated. It panics if the host
    /// message stream ends or fails while running.
    async fn run(&mut self) {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        loop {
            let mut message_recv =
                OptionFuture::from(self.running.then(|| self.msg_source.recv(&mut buf).fuse()));

            let host_backed_up = !self.inner.messages.is_empty();
            let flush_messages = OptionFuture::from(
                (self.running && host_backed_up)
                    .then(|| self.inner.messages.flush_messages().fuse()),
            );

            let mut client_request_recv = OptionFuture::from(
                (self.running && !host_backed_up).then(|| self.client_request_recv.next()),
            );

            let mut channel_requests = OptionFuture::from(
                (self.running && !host_backed_up)
                    .then(|| self.inner.channel_requests.select_next_some()),
            );

            futures::select! {
                // Flushing makes progress on the outgoing queue. The next
                // iteration re-evaluates `host_backed_up`.
                _r = pin!(flush_messages) => {}
                r = self.task_recv.next() => {
                    if let Some(task) = r {
                        self.handle_task(task).await;
                    } else {
                        break;
                    }
                }
                r = client_request_recv => {
                    if let Some(Some(request)) = r {
                        self.handle_client_request(request);
                    } else {
                        break;
                    }
                }
                r = channel_requests => {
                    match r.unwrap() {
                        (id, Some(request)) => self.handle_channel_request(id, request),
                        // The channel's request stream ended: the device was removed.
                        (id, _) => {
                            self.handle_device_removal(id);
                        }
                    }
                }
                r = message_recv => {
                    match r.unwrap() {
                        Ok(size) => {
                            if size == 0 {
                                panic!("Unexpected end of file reading messages from synic.");
                            }

                            self.handle_synic_message(&buf[..size]);
                        }
                        Err(err) => {
                            panic!("Error reading messages from synic: {err:?}");
                        }
                    }
                }
                complete => break,
            }
        }
    }
1585}
1586
1587impl ClientTaskInner {
1588 fn close_channel(&mut self, channel_id: ChannelId, channel: &mut Channel) {
1589 if let ChannelState::Opened {
1590 redirected_event_flag,
1591 ..
1592 } = channel.state
1593 {
1594 if let Some(flag) = redirected_event_flag {
1595 self.synic.free_event_flag(flag);
1596 }
1597 tracing::info!(
1598 channel_id = channel_id.0,
1599 key = %OfferKey::from(&channel.offer),
1600 "closing channel on host"
1601 );
1602
1603 self.messages.send(&protocol::CloseChannel { channel_id });
1604 channel.state = ChannelState::Offered;
1605 channel.connection_id.store(0, Ordering::Release);
1606 } else {
1607 tracing::warn!(
1608 channel_id = channel_id.0,
1609 key = %OfferKey::from(&channel.offer),
1610 channel_state = %channel.state,
1611 "invalid channel state for close channel"
1612 );
1613 }
1614 }
1615}
1616
/// Lifecycle of a GPADL as tracked by the client.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum GpadlState {
    /// Description sent to the host. Holds the creation RPC, which is
    /// presumably completed when the host responds.
    Offered(#[inspect(skip)] FailableRpc<(), ()>),
    /// The GPADL exists and may be torn down.
    Created,
    /// A `GpadlTeardown` has been sent. Every teardown RPC received since then
    /// is collected here.
    TearingDown {
        #[inspect(skip)]
        rpcs: Vec<Rpc<(), ()>>,
    },
}
1630
/// Guest-to-host message sender with a FIFO backlog for when the host is busy
/// or outgoing traffic is paused.
#[derive(Inspect)]
struct OutgoingMessages {
    /// Posts a message to the host, possibly returning `Pending`.
    #[inspect(skip)]
    poster: Box<dyn PollPostMessage>,
    /// Messages not yet accepted by `poster`, oldest first.
    #[inspect(with = "|x| x.len()")]
    queued: VecDeque<OutgoingMessage>,
    /// Whether queued messages may currently be delivered.
    state: OutgoingMessageState,
}
1638}
1639
1640#[derive(Inspect, PartialEq, Eq, Debug)]
1641enum OutgoingMessageState {
1642 Running,
1643 SendingPauseMessage,
1644 Paused,
1645}
1646
1647impl OutgoingMessages {
1648 fn send<T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout>(
1649 &mut self,
1650 msg: &T,
1651 ) {
1652 self.send_with_data(msg, &[])
1653 }
1654
1655 fn send_with_data<
1656 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
1657 >(
1658 &mut self,
1659 msg: &T,
1660 data: &[u8],
1661 ) {
1662 tracing::trace!(typ = ?T::MESSAGE_TYPE, "Sending message to host");
1663 let msg = OutgoingMessage::with_data(msg, data);
1664 if self.queued.is_empty() && self.state == OutgoingMessageState::Running {
1665 let r = self.poster.poll_post_message(
1666 &mut Context::from_waker(std::task::Waker::noop()),
1667 protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
1668 1,
1669 msg.data(),
1670 );
1671 if let Poll::Ready(()) = r {
1672 return;
1673 }
1674 }
1675 tracing::trace!("queueing message");
1676 self.queued.push_back(msg);
1677 }
1678
1679 async fn flush_messages(&mut self) {
1680 let mut send = async |msg: &OutgoingMessage| {
1681 poll_fn(|cx| {
1682 self.poster.poll_post_message(
1683 cx,
1684 protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
1685 1,
1686 msg.data(),
1687 )
1688 })
1689 .await
1690 };
1691 match self.state {
1692 OutgoingMessageState::Running => {
1693 while let Some(msg) = self.queued.front() {
1694 send(msg).await;
1695 tracing::trace!("sent queued message");
1696 self.queued.pop_front();
1697 }
1698 }
1699 OutgoingMessageState::SendingPauseMessage => {
1700 send(&OutgoingMessage::new(&protocol::Pause)).await;
1701 tracing::trace!("sent pause message");
1702 self.state = OutgoingMessageState::Paused;
1703 }
1704 OutgoingMessageState::Paused => {}
1705 }
1706 }
1707
1708 fn pause(&mut self) {
1711 assert_eq!(self.state, OutgoingMessageState::Running);
1712 self.state = OutgoingMessageState::SendingPauseMessage;
1713 self.queued
1715 .push_front(OutgoingMessage::new(&protocol::Resume));
1716 }
1717
1718 fn force_pause(&mut self) {
1722 assert_eq!(self.state, OutgoingMessageState::Running);
1723 self.state = OutgoingMessageState::Paused;
1724 }
1725
1726 fn resume(&mut self) {
1727 assert_eq!(self.state, OutgoingMessageState::Paused);
1728 self.state = OutgoingMessageState::Running;
1729 }
1730
1731 fn is_empty(&self) -> bool {
1732 self.queued.is_empty()
1733 }
1734}
1735
/// Client-task state that can be borrowed independently of the channel list.
#[derive(Inspect)]
struct ClientTaskInner {
    /// Outgoing messages to the host.
    messages: OutgoingMessages,
    /// GPADLs with a teardown in flight, mapped to their owning channel.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")]
    teardown_gpadls: HashMap<GpadlId, ChannelId>,
    /// Merged per-channel request streams, tagged with the channel ID.
    #[inspect(skip)]
    channel_requests: SelectAll<TaggedStream<ChannelId, mesh::Receiver<ChannelRequest>>>,
    /// Synic event mapping and event-flag allocation.
    synic: SynicState,
}
1745
/// Synic event client plus the allocation table for redirected event flags.
#[derive(Inspect)]
struct SynicState {
    #[inspect(skip)]
    event_client: Arc<dyn SynicEventClient>,
    /// `event_flag_state[i]` is `true` while event flag `i + 1` is in use.
    #[inspect(iter_by_index)]
    event_flag_state: Vec<bool>,
}
1753
/// All channels known to the client, keyed by channel ID.
#[derive(Inspect, Default)]
#[inspect(transparent)]
struct ChannelList(
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] HashMap<ChannelId, Channel>,
);
1759
/// Mutable handle to a channel that exists in a `ChannelList`.
///
/// Dereferences to `Channel`. `try_release` may remove the entry.
struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>);
1763
/// Token returned by `ChannelRef::try_release`.
///
/// NOTE(review): presumably a return type that signals callers a release was
/// attempted. Confirm the intended use.
struct TriedRelease(());
1768
1769impl ChannelRef<'_> {
1770 fn try_release(self, messages: &mut OutgoingMessages) -> TriedRelease {
1774 if self.is_client_released
1775 && matches!(self.state, ChannelState::Revoked)
1776 && self.pending_request().is_none()
1777 {
1778 let channel_id = *self.0.key();
1779 tracelimit::info_ratelimited!(
1780 channel_id = channel_id.0,
1781 key = %OfferKey::from(&self.offer),
1782 "releasing channel"
1783 );
1784
1785 messages.send(&protocol::RelIdReleased { channel_id });
1786 self.0.remove();
1787 }
1788 TriedRelease(())
1789 }
1790}
1791
1792impl Deref for ChannelRef<'_> {
1793 type Target = Channel;
1794
1795 fn deref(&self) -> &Self::Target {
1796 self.0.get()
1797 }
1798}
1799
1800impl DerefMut for ChannelRef<'_> {
1801 fn deref_mut(&mut self) -> &mut Self::Target {
1802 self.0.get_mut()
1803 }
1804}
1805
1806impl ChannelList {
1807 fn revoked_channel_with_pending_request(&self) -> Option<(ChannelId, &'static str)> {
1808 self.0.iter().find_map(|(&id, channel)| {
1809 if !matches!(channel.state, ChannelState::Revoked) {
1810 return None;
1811 }
1812 Some((id, channel.pending_request()?))
1813 })
1814 }
1815
1816 #[track_caller]
1817 fn get_mut(&mut self, channel_id: ChannelId) -> ChannelRef<'_> {
1818 match self.0.entry(channel_id) {
1819 hash_map::Entry::Occupied(entry) => ChannelRef(entry),
1820 hash_map::Entry::Vacant(_) => {
1821 panic!("channel {:?} not found", channel_id);
1822 }
1823 }
1824 }
1825}
1826
1827impl SynicState {
1828 fn guest_to_host_interrupt(&self, connection_id: Arc<AtomicU32>) -> Interrupt {
1829 Interrupt::from_fn({
1830 let event_client = self.event_client.clone();
1831 move || {
1832 let connection_id = connection_id.load(Ordering::Acquire);
1833 if connection_id == 0 {
1834 tracing::debug!("interrupt signal after close");
1835 return;
1836 }
1837
1838 if let Err(err) = event_client.signal_event(connection_id, 0) {
1839 tracelimit::warn_ratelimited!(
1840 error = &err as &dyn std::error::Error,
1841 "failed to signal event"
1842 );
1843 }
1844 }
1845 })
1846 }
1847
1848 const MAX_EVENT_FLAGS: u16 = 2047;
1849
1850 fn allocate_event_flag(&mut self, event: &Event) -> Result<u16> {
1851 let i = self
1852 .event_flag_state
1853 .iter()
1854 .position(|&used| !used)
1855 .ok_or(())
1856 .or_else(|()| {
1857 if self.event_flag_state.len() >= Self::MAX_EVENT_FLAGS as usize {
1858 anyhow::bail!("out of event flags");
1859 }
1860 self.event_flag_state.push(false);
1861 Ok(self.event_flag_state.len() - 1)
1862 })?;
1863
1864 let event_flag = (i + 1) as u16;
1865 self.event_client
1866 .map_event(event_flag, event)
1867 .context("failed to map event")?;
1868 self.event_flag_state[i] = true;
1869 Ok(event_flag)
1870 }
1871
1872 fn restore_event_flag(&mut self, flag: u16, event: &Event) -> Result<()> {
1873 let i = (flag as usize)
1874 .checked_sub(1)
1875 .context("invalid event flag")?;
1876 if i >= Self::MAX_EVENT_FLAGS as usize {
1877 anyhow::bail!("invalid event flag");
1878 }
1879 if self.event_flag_state.len() <= i {
1880 self.event_flag_state.resize(i + 1, false);
1881 }
1882 if self.event_flag_state[i] {
1883 anyhow::bail!("event flag already in use");
1884 }
1885 self.event_client
1886 .map_event(flag, event)
1887 .context("failed to map event")?;
1888 self.event_flag_state[i] = true;
1889 Ok(())
1890 }
1891
1892 fn free_event_flag(&mut self, flag: u16) {
1893 let i = flag as usize - 1;
1894 assert!(i < self.event_flag_state.len());
1895 self.event_flag_state[i] = false;
1896 }
1897}
1898
1899#[cfg(test)]
1900mod tests {
1901 use super::*;
1902 use futures_concurrency::future::Join;
1903 use guid::Guid;
1904 use pal_async::DefaultDriver;
1905 use pal_async::async_test;
1906 use pal_async::timer::PolledTimer;
1907 use protocol::TargetInfo;
1908 use std::fmt::Debug;
1909 use std::task::ready;
1910 use std::time::Duration;
1911 use test_with_tracing::test;
1912 use vmbus_core::protocol::MessageHeader;
1913 use vmbus_core::protocol::MessageType;
1914 use vmbus_core::protocol::OfferFlags;
1915 use vmbus_core::protocol::UserDefinedData;
1916 use vmbus_core::protocol::VmbusMessage;
1917 use zerocopy::FromBytes;
1918 use zerocopy::FromZeros;
1919 use zerocopy::Immutable;
1920 use zerocopy::IntoBytes;
1921 use zerocopy::KnownLayout;
1922
1923 const VMBUS_TEST_CLIENT_ID: Guid = guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6");
1924
1925 fn in_msg<T: IntoBytes + Immutable + KnownLayout>(message_type: MessageType, t: T) -> Vec<u8> {
1926 let mut data = Vec::new();
1927 data.extend_from_slice(&message_type.0.to_ne_bytes());
1928 data.extend_from_slice(&0u32.to_ne_bytes());
1929 data.extend_from_slice(t.as_bytes());
1930 data
1931 }
1932
    /// Asserts that `msg` encodes exactly `chk` with no trailing payload.
    #[track_caller]
    fn check_message<T>(msg: OutgoingMessage, chk: T)
    where
        T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
    {
        check_message_with_data(msg, chk, &[]);
    }
1940
1941 #[track_caller]
1942 fn check_message_with_data<T>(msg: OutgoingMessage, chk: T, data: &[u8])
1943 where
1944 T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1945 {
1946 let chk_data = OutgoingMessage::with_data(&chk, data);
1947 if msg.data() != chk_data.data() {
1948 let (header, rest) = MessageHeader::read_from_prefix(msg.data()).unwrap();
1949 assert_eq!(header.message_type(), <T as VmbusMessage>::MESSAGE_TYPE);
1950 let (msg, rest) = T::read_from_prefix(rest).expect("incorrect message size");
1951 if msg.as_bytes() != chk.as_bytes() {
1952 panic!("mismatched messages, expected {:#?}, got {:#?}", chk, msg);
1953 }
1954 if rest != data {
1955 panic!("mismatched data, expected {:#?}, got {:#?}", data, rest);
1956 }
1957 }
1958 }
1959
    /// Fake VMBus host used by the tests.
    struct TestServer {
        /// Messages the client posted to the "host".
        messages: mesh::Receiver<OutgoingMessage>,
        /// Raw host-to-guest messages fed to the client's message source.
        send: mesh::Sender<Vec<u8>>,
    }
1964
    impl TestServer {
        /// Waits for the next message the client posts.
        async fn next(&mut self) -> Option<OutgoingMessage> {
            self.messages.next().await
        }

        /// Delivers a raw host-to-guest message to the client.
        fn send(&self, msg: Vec<u8>) {
            self.send.send(msg);
        }

        /// Connects `client` with no channel offers.
        async fn connect(&mut self, client: &mut VmbusClient) -> ConnectResult {
            self.connect_with_channels(client, |_| {}).await
        }

        /// Drives a successful connect handshake.
        ///
        /// The server side runs concurrently with `client.connect`:
        /// 1. It takes the first posted message without checking it.
        /// 2. It replies with a `VersionResponse2` that accepts and reports
        ///    `SUPPORTED_FEATURE_FLAGS`.
        /// 3. It expects `RequestOffers`, calls `send_offers`, then sends
        ///    `ALL_OFFERS_DELIVERED`.
        ///
        /// Asserts that `Copper` and `SUPPORTED_FEATURE_FLAGS` were negotiated.
        async fn connect_with_channels(
            &mut self,
            client: &mut VmbusClient,
            send_offers: impl FnOnce(&mut Self),
        ) -> ConnectResult {
            let client_connect = client.connect(0, None, Guid::ZERO);

            let server_connect = async {
                let _ = self.next().await.unwrap();

                self.send(in_msg(
                    MessageType::VERSION_RESPONSE,
                    protocol::VersionResponse2 {
                        version_response: protocol::VersionResponse {
                            version_supported: 1,
                            connection_state: ConnectionState::SUCCESSFUL,
                            padding: 0,
                            selected_version_or_connection_id: 0,
                        },
                        supported_features: SUPPORTED_FEATURE_FLAGS.into(),
                    },
                ));

                check_message(self.next().await.unwrap(), protocol::RequestOffers {});

                send_offers(self);
                self.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
            };

            let (connection, ()) = (client_connect, server_connect).join().await;

            let connection = connection.unwrap();
            assert_eq!(connection.version.version, Version::Copper);
            assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
            connection
        }

        /// Connects and offers exactly one channel, returning its offer.
        async fn get_channel(&mut self, client: &mut VmbusClient) -> OfferInfo {
            let [channel] = self
                .get_channels(client, 1)
                .await
                .offers
                .try_into()
                .unwrap();
            channel
        }

        /// Connects and offers `count` channels with random GUIDs and channel
        /// IDs `0..count`.
        async fn get_channels(&mut self, client: &mut VmbusClient, count: usize) -> ConnectResult {
            self.connect_with_channels(client, |this| {
                for i in 0..count {
                    let offer = protocol::OfferChannel {
                        interface_id: Guid::new_random(),
                        instance_id: Guid::new_random(),
                        rsvd: [0; 4],
                        flags: OfferFlags::new(),
                        mmio_megabytes: 0,
                        user_defined: UserDefinedData::new_zeroed(),
                        subchannel_index: 0,
                        mmio_megabytes_optional: 0,
                        channel_id: ChannelId(i as u32),
                        monitor_id: 0,
                        monitor_allocated: 0,
                        is_dedicated: 0,
                        connection_id: 0,
                    };

                    this.send(in_msg(MessageType::OFFER_CHANNEL, offer));
                }
            })
            .await
        }

        /// Stops `client`. The server expects `Pause` and answers with
        /// `PauseResponse`.
        async fn stop_client(&mut self, client: &mut VmbusClient) {
            let client_stop = client.stop();
            let server_stop = async {
                check_message(self.next().await.unwrap(), protocol::Pause);
                self.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
            };
            (client_stop, server_stop).join().await;
        }

        /// Starts `client` and expects the queued `Resume` message.
        async fn start_client(&mut self, client: &mut VmbusClient) {
            client.start();
            check_message(self.next().await.unwrap(), protocol::Resume);
        }
    }
2064
    /// Message poster that forwards client messages to the `TestServer`,
    /// injecting random delays.
    struct TestServerClient {
        sender: mesh::Sender<OutgoingMessage>,
        timer: PolledTimer,
        /// When set, posting is blocked until this instant.
        deadline: Option<pal_async::timer::Instant>,
    }
2070
    impl PollPostMessage for TestServerClient {
        /// Posts `msg` to the test server, sometimes after a delay.
        ///
        /// On each attempt a random byte decides the outcome. Roughly one time
        /// in four, a 10 ms deadline is armed and the loop waits on it
        /// (returning `Pending` while waiting). Otherwise the message is
        /// forwarded and `Ready` is returned. This exercises the client's
        /// queueing paths.
        fn poll_post_message(
            &mut self,
            cx: &mut Context<'_>,
            _connection_id: u32,
            _typ: u32,
            msg: &[u8],
        ) -> Poll<()> {
            loop {
                if let Some(deadline) = self.deadline {
                    ready!(self.timer.poll_until(cx, deadline));
                    self.deadline = None;
                }
                let mut b = [0];
                getrandom::fill(&mut b).unwrap();
                if b[0] % 4 == 0 {
                    self.deadline =
                        Some(pal_async::timer::Instant::now() + Duration::from_millis(10));
                } else {
                    let msg = OutgoingMessage::from_message(msg).unwrap();
                    tracing::info!(
                        msg = ?MessageHeader::read_from_prefix(msg.data()),
                        "sending message"
                    );
                    self.sender.send(msg);
                    break Poll::Ready(());
                }
            }
        }
    }
2105
    /// Synic event client that accepts mappings but cannot signal.
    struct NoopSynicEvents;
2107
2108 impl SynicEventClient for NoopSynicEvents {
2109 fn map_event(&self, _event_flag: u16, _event: &Event) -> std::io::Result<()> {
2110 Ok(())
2111 }
2112
2113 fn unmap_event(&self, _event_flag: u16) {}
2114
2115 fn signal_event(&self, _connection_id: u32, _event_flag: u16) -> std::io::Result<()> {
2116 Err(std::io::ErrorKind::Unsupported.into())
2117 }
2118 }
2119
    /// Message source fed by `TestServer::send`.
    struct TestMessageSource {
        msg_recv: mesh::Receiver<Vec<u8>>,
        /// While set, an empty channel reads as end of stream (0 bytes).
        paused: bool,
    }
2124
2125 impl AsyncRecv for TestMessageSource {
2126 fn poll_recv(
2127 &mut self,
2128 cx: &mut Context<'_>,
2129 mut bufs: &mut [std::io::IoSliceMut<'_>],
2130 ) -> Poll<std::io::Result<usize>> {
2131 let value = match self.msg_recv.poll_recv(cx) {
2132 Poll::Ready(v) => v.unwrap(),
2133 Poll::Pending => {
2134 if self.paused {
2135 return Poll::Ready(Ok(0));
2136 } else {
2137 return Poll::Pending;
2138 }
2139 }
2140 };
2141 let mut remaining = value.as_slice();
2142 let mut total_size = 0;
2143 while !remaining.is_empty() && !bufs.is_empty() {
2144 let size = bufs[0].len().min(remaining.len());
2145 bufs[0][..size].copy_from_slice(&remaining[..size]);
2146 remaining = &remaining[size..];
2147 bufs = &mut bufs[1..];
2148 total_size += size;
2149 }
2150
2151 Ok(total_size).into()
2152 }
2153 }
2154
    /// Pausing makes an empty queue read as end of stream, which ends the
    /// client's drain loop during stop.
    impl VmbusMessageSource for TestMessageSource {
        fn pause_message_stream(&mut self) {
            self.paused = true;
        }

        fn resume_message_stream(&mut self) {
            self.paused = false;
        }
    }
2164
    /// Builds a started `VmbusClient` wired to a fresh `TestServer`.
    ///
    /// The client uses `NoopSynicEvents`, a `TestMessageSource` fed by the
    /// server, and a `TestServerClient` that forwards posted messages back to
    /// the server.
    fn test_init(driver: &DefaultDriver) -> (TestServer, VmbusClient) {
        let (msg_send, msg_recv) = mesh::channel();
        let (synic_send, synic_recv) = mesh::channel();
        let server = TestServer {
            messages: synic_recv,
            send: msg_send,
        };
        let mut client = VmbusClientBuilder::new(
            NoopSynicEvents,
            TestMessageSource {
                msg_recv,
                paused: false,
            },
            TestServerClient {
                sender: synic_send,
                deadline: None,
                timer: PolledTimer::new(driver),
            },
        )
        .build(driver);
        client.start();
        (server, client)
    }
2188
    /// A connect request sends `InitiateContact2` for `Copper` with SINT 2,
    /// VTL 0, all supported feature flags, and zeroed remaining fields.
    #[async_test]
    async fn test_initiate_contact_success(driver: DefaultDriver) {
        let (mut server, client) = test_init(&driver);
        let _recv = client
            .access
            .client_request_send
            .call(ClientRequest::Connect, ConnectRequest::default());
        check_message(
            server.next().await.unwrap(),
            protocol::InitiateContact2 {
                initiate_contact: protocol::InitiateContact {
                    version_requested: Version::Copper as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                        .into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
                ..FromZeros::new_zeroed()
            },
        );
    }
2214
    /// Full handshake: the host accepts `Copper` with every supported feature,
    /// and the client requests offers and then completes the connection.
    #[async_test]
    async fn test_connect_success(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: Version::Copper as u32,
                        target_message_vp: 0,
                        interrupt_page_or_target_info: TargetInfo::new()
                            .with_sint(2)
                            .with_vtl(0)
                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                            .into(),
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    ..FromZeros::new_zeroed()
                },
            );

            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse2 {
                    version_response: protocol::VersionResponse {
                        version_supported: 1,
                        connection_state: ConnectionState::SUCCESSFUL,
                        padding: 0,
                        selected_version_or_connection_id: 0,
                    },
                    supported_features: SUPPORTED_FEATURE_FLAGS.into_bits(),
                },
            ));

            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        let (connection, ()) = (client_connect, server_connect).join().await;
        let connection = connection.unwrap();

        assert_eq!(connection.version.version, Version::Copper);
        assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
    }
2262
    /// The negotiated feature flags are the host-reported subset.
    ///
    /// Here the host reports only bit value 2, which the final assertion
    /// expects to decode as `channel_interrupt_redirection`.
    #[async_test]
    async fn test_feature_flags(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: Version::Copper as u32,
                        target_message_vp: 0,
                        interrupt_page_or_target_info: TargetInfo::new()
                            .with_sint(2)
                            .with_vtl(0)
                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                            .into(),
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    ..FromZeros::new_zeroed()
                },
            );

            // Accept the version, but report a reduced feature set.
            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse2 {
                    version_response: protocol::VersionResponse {
                        version_supported: 1,
                        connection_state: ConnectionState::SUCCESSFUL,
                        padding: 0,
                        selected_version_or_connection_id: 0,
                    },
                    supported_features: 2,
                },
            ));

            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        let (connection, ()) = (client_connect, server_connect).join().await;
        let connection = connection.unwrap();

        assert_eq!(connection.version.version, Version::Copper);
        assert_eq!(
            connection.version.feature_flags,
            FeatureFlags::new().with_channel_interrupt_redirection(true)
        );
    }
2315
    /// The `client_id` from the connect request is carried in
    /// `InitiateContact2`.
    #[async_test]
    async fn test_client_id(driver: DefaultDriver) {
        let (mut server, client) = test_init(&driver);
        let initiate_contact = ConnectRequest {
            client_id: VMBUS_TEST_CLIENT_ID,
            ..Default::default()
        };
        let _recv = client
            .access
            .client_request_send
            .call(ClientRequest::Connect, initiate_contact);

        check_message(
            server.next().await.unwrap(),
            protocol::InitiateContact2 {
                initiate_contact: protocol::InitiateContact {
                    version_requested: Version::Copper as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                        .into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
                client_id: VMBUS_TEST_CLIENT_ID,
            },
        );
    }
2346
2347 #[async_test]
2348 async fn test_version_negotiation(driver: DefaultDriver) {
2349 let (mut server, mut client) = test_init(&driver);
2350 let client_connect = client.connect(0, None, Guid::ZERO);
2351
2352 let server_connect = async {
2353 check_message(
2354 server.next().await.unwrap(),
2355 protocol::InitiateContact2 {
2356 initiate_contact: protocol::InitiateContact {
2357 version_requested: Version::Copper as u32,
2358 target_message_vp: 0,
2359 interrupt_page_or_target_info: TargetInfo::new()
2360 .with_sint(2)
2361 .with_vtl(0)
2362 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2363 .into(),
2364 parent_to_child_monitor_page_gpa: 0,
2365 child_to_parent_monitor_page_gpa: 0,
2366 },
2367 ..FromZeros::new_zeroed()
2368 },
2369 );
2370
2371 server.send(in_msg(
2372 MessageType::VERSION_RESPONSE,
2373 protocol::VersionResponse {
2374 version_supported: 0,
2375 connection_state: ConnectionState::SUCCESSFUL,
2376 padding: 0,
2377 selected_version_or_connection_id: 0,
2378 },
2379 ));
2380
2381 check_message(
2382 server.next().await.unwrap(),
2383 protocol::InitiateContact {
2384 version_requested: Version::Iron as u32,
2385 target_message_vp: 0,
2386 interrupt_page_or_target_info: TargetInfo::new()
2387 .with_sint(2)
2388 .with_vtl(0)
2389 .with_feature_flags(FeatureFlags::new().into())
2390 .into(),
2391 parent_to_child_monitor_page_gpa: 0,
2392 child_to_parent_monitor_page_gpa: 0,
2393 },
2394 );
2395
2396 server.send(in_msg(
2397 MessageType::VERSION_RESPONSE,
2398 protocol::VersionResponse {
2399 version_supported: 1,
2400 connection_state: ConnectionState::SUCCESSFUL,
2401 padding: 0,
2402 selected_version_or_connection_id: 0,
2403 },
2404 ));
2405
2406 check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2407 server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2408 };
2409
2410 let (connection, ()) = (client_connect, server_connect).join().await;
2411 let connection = connection.unwrap();
2412
2413 assert_eq!(connection.version.version, Version::Iron);
2414 assert_eq!(connection.version.feature_flags, FeatureFlags::new());
2415 }
2416
2417 #[async_test]
2418 async fn test_open_channel_success(driver: DefaultDriver) {
2419 let (mut server, mut client) = test_init(&driver);
2420 let channel = server.get_channel(&mut client).await;
2421
2422 let recv = channel.request_send.call(
2423 ChannelRequest::Open,
2424 OpenRequest {
2425 open_data: OpenData {
2426 target_vp: Some(0),
2427 ring_offset: 0,
2428 ring_gpadl_id: GpadlId(0),
2429 event_flag: 0,
2430 connection_id: 0,
2431 user_data: UserDefinedData::new_zeroed(),
2432 },
2433 incoming_event: None,
2434 use_vtl2_connection_id: false,
2435 },
2436 );
2437
2438 check_message(
2439 server.next().await.unwrap(),
2440 protocol::OpenChannel2 {
2441 open_channel: protocol::OpenChannel {
2442 channel_id: ChannelId(0),
2443 open_id: 0,
2444 ring_buffer_gpadl_id: GpadlId(0),
2445 target_vp: 0,
2446 downstream_ring_buffer_page_offset: 0,
2447 user_data: UserDefinedData::new_zeroed(),
2448 },
2449 connection_id: 0,
2450 event_flag: 0,
2451 flags: Default::default(),
2452 },
2453 );
2454
2455 server.send(in_msg(
2456 MessageType::OPEN_CHANNEL_RESULT,
2457 protocol::OpenResult {
2458 channel_id: ChannelId(0),
2459 open_id: 0,
2460 status: protocol::STATUS_SUCCESS as u32,
2461 },
2462 ));
2463
2464 recv.await.unwrap().unwrap();
2465 }
2466
2467 #[async_test]
2468 async fn test_open_channel_fail(driver: DefaultDriver) {
2469 let (mut server, mut client) = test_init(&driver);
2470 let channel = server.get_channel(&mut client).await;
2471
2472 let recv = channel.request_send.call(
2473 ChannelRequest::Open,
2474 OpenRequest {
2475 open_data: OpenData {
2476 target_vp: Some(0),
2477 ring_offset: 0,
2478 ring_gpadl_id: GpadlId(0),
2479 event_flag: 0,
2480 connection_id: 0,
2481 user_data: UserDefinedData::new_zeroed(),
2482 },
2483 incoming_event: None,
2484 use_vtl2_connection_id: false,
2485 },
2486 );
2487
2488 check_message(
2489 server.next().await.unwrap(),
2490 protocol::OpenChannel2 {
2491 open_channel: protocol::OpenChannel {
2492 channel_id: ChannelId(0),
2493 open_id: 0,
2494 ring_buffer_gpadl_id: GpadlId(0),
2495 target_vp: 0,
2496 downstream_ring_buffer_page_offset: 0,
2497 user_data: UserDefinedData::new_zeroed(),
2498 },
2499 connection_id: 0,
2500 event_flag: 0,
2501 flags: Default::default(),
2502 },
2503 );
2504
2505 server.send(in_msg(
2506 MessageType::OPEN_CHANNEL_RESULT,
2507 protocol::OpenResult {
2508 channel_id: ChannelId(0),
2509 open_id: 0,
2510 status: protocol::STATUS_UNSUCCESSFUL as u32,
2511 },
2512 ));
2513
2514 recv.await.unwrap().unwrap_err();
2515 }
2516
2517 #[async_test]
2518 async fn test_modify_channel(driver: DefaultDriver) {
2519 let (mut server, mut client) = test_init(&driver);
2520 let channel = server.get_channel(&mut client).await;
2521
2522 let recv = channel.request_send.call(
2525 ChannelRequest::Modify,
2526 ModifyRequest::TargetVp { target_vp: 1 },
2527 );
2528
2529 check_message(
2530 server.next().await.unwrap(),
2531 protocol::ModifyChannel {
2532 channel_id: ChannelId(0),
2533 target_vp: 1,
2534 },
2535 );
2536
2537 server.send(in_msg(
2538 MessageType::MODIFY_CHANNEL_RESPONSE,
2539 protocol::ModifyChannelResponse {
2540 channel_id: ChannelId(0),
2541 status: protocol::STATUS_SUCCESS,
2542 },
2543 ));
2544
2545 let status = recv.await.unwrap();
2546 assert_eq!(status, protocol::STATUS_SUCCESS);
2547 }
2548
2549 #[async_test]
2550 async fn test_save_restore_connected(driver: DefaultDriver) {
2551 let (mut server, mut client) = test_init(&driver);
2552 server.connect(&mut client).await;
2553 server.stop_client(&mut client).await;
2554 let s0 = client.save().await;
2555 let builder = client.sever().await;
2556 let mut client = builder.build(&driver);
2557 client.restore(s0.clone()).await.unwrap();
2558
2559 let s1 = client.save().await;
2560
2561 assert_eq!(s0, s1);
2562 }
2563
2564 #[async_test]
2565 async fn test_save_restore_connected_with_channel(driver: DefaultDriver) {
2566 let (mut server, mut client) = test_init(&driver);
2567 let c0 = server.get_channel(&mut client).await;
2568 server.stop_client(&mut client).await;
2569 let s0 = client.save().await;
2570 let builder = client.sever().await;
2571 let mut client = builder.build(&driver);
2572 let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2573 let s1 = client.save().await;
2574 assert_eq!(s0, s1);
2575 assert_eq!(connection.offers[0].offer, c0.offer);
2576 }
2577
2578 #[async_test]
2579 async fn test_save_restore_connected_with_revoked_channel(driver: DefaultDriver) {
2580 let (mut server, mut client) = test_init(&driver);
2581 let c0 = server.get_channel(&mut client).await;
2582 server.send(in_msg(
2583 MessageType::RESCIND_CHANNEL_OFFER,
2584 protocol::RescindChannelOffer {
2585 channel_id: ChannelId(0),
2586 },
2587 ));
2588 c0.revoke_recv.await.unwrap();
2589 let rpc = c0.request_send.call(
2590 ChannelRequest::Modify,
2591 ModifyRequest::TargetVp { target_vp: 1 },
2592 );
2593
2594 check_message(
2595 server.next().await.unwrap(),
2596 protocol::ModifyChannel {
2597 channel_id: ChannelId(0),
2598 target_vp: 1,
2599 },
2600 );
2601
2602 let client_stop = client.stop();
2603 let server_stop = async {
2604 server.send(in_msg(
2605 MessageType::MODIFY_CHANNEL_RESPONSE,
2606 protocol::ModifyChannelResponse {
2607 channel_id: ChannelId(0),
2608 status: protocol::STATUS_SUCCESS,
2609 },
2610 ));
2611 check_message(server.next().await.unwrap(), protocol::Pause);
2612 server.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
2613 };
2614 (client_stop, server_stop).join().await;
2615
2616 rpc.await.unwrap();
2617
2618 let s0 = client.save().await;
2619 let builder = client.sever().await;
2620 let mut client = builder.build(&driver);
2621 let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2622 let s1 = client.save().await;
2623 assert_eq!(s0, s1);
2624 assert!(connection.offers.is_empty());
2625 server.start_client(&mut client).await;
2626 check_message(
2627 server.next().await.unwrap(),
2628 protocol::RelIdReleased {
2629 channel_id: ChannelId(0),
2630 },
2631 );
2632 }
2633
2634 #[async_test]
2635 async fn test_connect_fails_on_incorrect_state(driver: DefaultDriver) {
2636 let (mut server, mut client) = test_init(&driver);
2637 server.connect(&mut client).await;
2638 let err = client.connect(0, None, Guid::ZERO).await.unwrap_err();
2639 assert!(matches!(err, ConnectError::InvalidState), "{:?}", err);
2640 }
2641
2642 #[async_test]
2643 async fn test_hot_add_remove(driver: DefaultDriver) {
2644 let (mut server, mut client) = test_init(&driver);
2645
2646 let mut connection = server.connect(&mut client).await;
2647 let offer = protocol::OfferChannel {
2648 interface_id: Guid::new_random(),
2649 instance_id: Guid::new_random(),
2650 rsvd: [0; 4],
2651 flags: OfferFlags::new(),
2652 mmio_megabytes: 0,
2653 user_defined: UserDefinedData::new_zeroed(),
2654 subchannel_index: 0,
2655 mmio_megabytes_optional: 0,
2656 channel_id: ChannelId(5),
2657 monitor_id: 0,
2658 monitor_allocated: 0,
2659 is_dedicated: 0,
2660 connection_id: 0,
2661 };
2662
2663 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2664 let info = connection.offer_recv.next().await.unwrap();
2665
2666 assert_eq!(offer, info.offer);
2667
2668 server.send(in_msg(
2669 MessageType::RESCIND_CHANNEL_OFFER,
2670 protocol::RescindChannelOffer {
2671 channel_id: ChannelId(5),
2672 },
2673 ));
2674
2675 info.revoke_recv.await.unwrap();
2676 drop(info.request_send);
2677
2678 check_message(
2679 server.next().await.unwrap(),
2680 protocol::RelIdReleased {
2681 channel_id: ChannelId(5),
2682 },
2683 );
2684 }
2685
2686 #[async_test]
2687 async fn test_gpadl_success(driver: DefaultDriver) {
2688 let (mut server, mut client) = test_init(&driver);
2689 let channel = server.get_channel(&mut client).await;
2690 let recv = channel.request_send.call(
2691 ChannelRequest::Gpadl,
2692 GpadlRequest {
2693 id: GpadlId(1),
2694 count: 1,
2695 buf: vec![5],
2696 },
2697 );
2698
2699 check_message_with_data(
2700 server.next().await.unwrap(),
2701 protocol::GpadlHeader {
2702 channel_id: ChannelId(0),
2703 gpadl_id: GpadlId(1),
2704 len: 8,
2705 count: 1,
2706 },
2707 0x5u64.as_bytes(),
2708 );
2709
2710 server.send(in_msg(
2711 MessageType::GPADL_CREATED,
2712 protocol::GpadlCreated {
2713 channel_id: ChannelId(0),
2714 gpadl_id: GpadlId(1),
2715 status: protocol::STATUS_SUCCESS,
2716 },
2717 ));
2718
2719 recv.await.unwrap().unwrap();
2720
2721 let rpc = channel
2722 .request_send
2723 .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2724
2725 check_message(
2726 server.next().await.unwrap(),
2727 protocol::GpadlTeardown {
2728 channel_id: ChannelId(0),
2729 gpadl_id: GpadlId(1),
2730 },
2731 );
2732
2733 server.send(in_msg(
2734 MessageType::GPADL_TORNDOWN,
2735 protocol::GpadlTorndown {
2736 gpadl_id: GpadlId(1),
2737 },
2738 ));
2739
2740 rpc.await.unwrap();
2741 }
2742
2743 #[async_test]
2744 async fn test_gpadl_fail(driver: DefaultDriver) {
2745 let (mut server, mut client) = test_init(&driver);
2746 let channel = server.get_channel(&mut client).await;
2747 let recv = channel.request_send.call(
2748 ChannelRequest::Gpadl,
2749 GpadlRequest {
2750 id: GpadlId(1),
2751 count: 1,
2752 buf: vec![7],
2753 },
2754 );
2755
2756 check_message_with_data(
2757 server.next().await.unwrap(),
2758 protocol::GpadlHeader {
2759 channel_id: ChannelId(0),
2760 gpadl_id: GpadlId(1),
2761 len: 8,
2762 count: 1,
2763 },
2764 0x7u64.as_bytes(),
2765 );
2766
2767 server.send(in_msg(
2768 MessageType::GPADL_CREATED,
2769 protocol::GpadlCreated {
2770 channel_id: ChannelId(0),
2771 gpadl_id: GpadlId(1),
2772 status: protocol::STATUS_UNSUCCESSFUL,
2773 },
2774 ));
2775
2776 recv.await.unwrap().unwrap_err();
2777 }
2778
2779 #[async_test]
2780 async fn test_gpadl_with_revoke(driver: DefaultDriver) {
2781 let (mut server, mut client) = test_init(&driver);
2782 let channel = server.get_channel(&mut client).await;
2783 let channel_id = ChannelId(0);
2784 for gpadl_id in [1, 2, 3].map(GpadlId) {
2785 let recv = channel.request_send.call(
2786 ChannelRequest::Gpadl,
2787 GpadlRequest {
2788 id: gpadl_id,
2789 count: 1,
2790 buf: vec![3],
2791 },
2792 );
2793
2794 check_message_with_data(
2795 server.next().await.unwrap(),
2796 protocol::GpadlHeader {
2797 channel_id,
2798 gpadl_id,
2799 len: 8,
2800 count: 1,
2801 },
2802 0x3u64.as_bytes(),
2803 );
2804
2805 server.send(in_msg(
2806 MessageType::GPADL_CREATED,
2807 protocol::GpadlCreated {
2808 channel_id,
2809 gpadl_id,
2810 status: protocol::STATUS_SUCCESS,
2811 },
2812 ));
2813
2814 recv.await.unwrap().unwrap();
2815 }
2816
2817 let rpc = channel
2818 .request_send
2819 .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2820
2821 check_message(
2822 server.next().await.unwrap(),
2823 protocol::GpadlTeardown {
2824 channel_id,
2825 gpadl_id: GpadlId(1),
2826 },
2827 );
2828
2829 server.send(in_msg(
2830 MessageType::RESCIND_CHANNEL_OFFER,
2831 protocol::RescindChannelOffer { channel_id },
2832 ));
2833
2834 let recv = channel.request_send.call_failable(
2835 ChannelRequest::Gpadl,
2836 GpadlRequest {
2837 id: GpadlId(4),
2838 count: 1,
2839 buf: vec![3],
2840 },
2841 );
2842
2843 check_message_with_data(
2844 server.next().await.unwrap(),
2845 protocol::GpadlHeader {
2846 channel_id,
2847 gpadl_id: GpadlId(4),
2848 len: 8,
2849 count: 1,
2850 },
2851 0x3u64.as_bytes(),
2852 );
2853
2854 server.send(in_msg(
2855 MessageType::GPADL_CREATED,
2856 protocol::GpadlCreated {
2857 channel_id,
2858 gpadl_id: GpadlId(4),
2859 status: protocol::STATUS_UNSUCCESSFUL,
2860 },
2861 ));
2862
2863 server.send(in_msg(
2864 MessageType::GPADL_TORNDOWN,
2865 protocol::GpadlTorndown {
2866 gpadl_id: GpadlId(1),
2867 },
2868 ));
2869
2870 rpc.await.unwrap();
2871 recv.await.unwrap_err();
2872
2873 channel.revoke_recv.await.unwrap();
2874
2875 let rpc = channel
2876 .request_send
2877 .call(ChannelRequest::TeardownGpadl, GpadlId(2));
2878 drop(channel.request_send);
2879
2880 check_message(
2881 server.next().await.unwrap(),
2882 protocol::GpadlTeardown {
2883 channel_id,
2884 gpadl_id: GpadlId(2),
2885 },
2886 );
2887
2888 server.send(in_msg(
2889 MessageType::GPADL_TORNDOWN,
2890 protocol::GpadlTorndown {
2891 gpadl_id: GpadlId(2),
2892 },
2893 ));
2894
2895 rpc.await.unwrap();
2896
2897 check_message(
2898 server.next().await.unwrap(),
2899 protocol::RelIdReleased { channel_id },
2900 );
2901 }
2902
2903 #[async_test]
2904 async fn test_modify_connection(driver: DefaultDriver) {
2905 let (mut server, mut client) = test_init(&driver);
2906 server.connect(&mut client).await;
2907 let call = client.access.client_request_send.call(
2908 ClientRequest::Modify,
2909 ModifyConnectionRequest {
2910 monitor_page: Some(MonitorPageGpas {
2911 child_to_parent: 5,
2912 parent_to_child: 6,
2913 }),
2914 },
2915 );
2916
2917 check_message(
2918 server.next().await.unwrap(),
2919 protocol::ModifyConnection {
2920 child_to_parent_monitor_page_gpa: 5,
2921 parent_to_child_monitor_page_gpa: 6,
2922 },
2923 );
2924
2925 server.send(in_msg(
2926 MessageType::MODIFY_CONNECTION_RESPONSE,
2927 protocol::ModifyConnectionResponse {
2928 connection_state: ConnectionState::FAILED_LOW_RESOURCES,
2929 },
2930 ));
2931
2932 let result = call.await.unwrap();
2933 assert_eq!(ConnectionState::FAILED_LOW_RESOURCES, result);
2934 }
2935
2936 #[async_test]
2937 async fn test_hvsock(driver: DefaultDriver) {
2938 let (mut server, mut client) = test_init(&driver);
2939 server.connect(&mut client).await;
2940 let request = HvsockConnectRequest {
2941 service_id: Guid::new_random(),
2942 endpoint_id: Guid::new_random(),
2943 silo_id: Guid::new_random(),
2944 hosted_silo_unaware: false,
2945 };
2946
2947 let resp = client.access().connect_hvsock(request);
2948 check_message(
2949 server.next().await.unwrap(),
2950 protocol::TlConnectRequest2 {
2951 base: protocol::TlConnectRequest {
2952 service_id: request.service_id,
2953 endpoint_id: request.endpoint_id,
2954 },
2955 silo_id: request.silo_id,
2956 },
2957 );
2958
2959 server.send(in_msg(
2961 MessageType::TL_CONNECT_REQUEST_RESULT,
2962 protocol::TlConnectResult {
2963 service_id: request.service_id,
2964 endpoint_id: request.endpoint_id,
2965 status: protocol::STATUS_CONNECTION_REFUSED,
2966 },
2967 ));
2968
2969 let result = resp.await;
2970 assert!(result.is_none());
2971 }
2972
2973 #[async_test]
2974 async fn test_synic_event_flags(driver: DefaultDriver) {
2975 let (mut server, mut client) = test_init(&driver);
2976 let connection = server.get_channels(&mut client, 5).await;
2977 let event = Event::new();
2978
2979 for _ in 0..5 {
2980 for (i, channel) in connection.offers.iter().enumerate() {
2981 let recv = channel.request_send.call(
2982 ChannelRequest::Open,
2983 OpenRequest {
2984 open_data: OpenData {
2985 target_vp: Some(0),
2986 ring_offset: 0,
2987 ring_gpadl_id: GpadlId(0),
2988 event_flag: 0,
2989 connection_id: 0,
2990 user_data: UserDefinedData::new_zeroed(),
2991 },
2992 incoming_event: Some(event.clone()),
2993 use_vtl2_connection_id: false,
2994 },
2995 );
2996
2997 let expected_event_flag = i as u16 + 1;
2998
2999 check_message(
3000 server.next().await.unwrap(),
3001 protocol::OpenChannel2 {
3002 open_channel: protocol::OpenChannel {
3003 channel_id: channel.offer.channel_id,
3004 open_id: 0,
3005 ring_buffer_gpadl_id: GpadlId(0),
3006 target_vp: 0,
3007 downstream_ring_buffer_page_offset: 0,
3008 user_data: UserDefinedData::new_zeroed(),
3009 },
3010 connection_id: 0,
3011 event_flag: expected_event_flag,
3012 flags: OpenChannelFlags::new().with_redirect_interrupt(true),
3013 },
3014 );
3015
3016 server.send(in_msg(
3017 MessageType::OPEN_CHANNEL_RESULT,
3018 protocol::OpenResult {
3019 channel_id: channel.offer.channel_id,
3020 open_id: 0,
3021 status: protocol::STATUS_SUCCESS as u32,
3022 },
3023 ));
3024
3025 let output = recv.await.unwrap().unwrap();
3026 assert_eq!(output.redirected_event_flag, Some(expected_event_flag));
3027 }
3028
3029 for (i, channel) in connection.offers.iter().enumerate() {
3030 channel
3033 .request_send
3034 .call(ChannelRequest::Close, ())
3035 .await
3036 .unwrap();
3037
3038 check_message(
3039 server.next().await.unwrap(),
3040 protocol::CloseChannel {
3041 channel_id: ChannelId(i as u32),
3042 },
3043 );
3044 }
3045 }
3046 }
3047
3048 #[async_test]
3049 async fn test_revoke(driver: DefaultDriver) {
3050 let (mut server, mut client) = test_init(&driver);
3051 let channel = server.get_channel(&mut client).await;
3052
3053 server.send(in_msg(
3054 MessageType::RESCIND_CHANNEL_OFFER,
3055 protocol::RescindChannelOffer {
3056 channel_id: ChannelId(0),
3057 },
3058 ));
3059
3060 channel.revoke_recv.await.unwrap();
3061
3062 channel
3063 .request_send
3064 .call_failable(
3065 ChannelRequest::Open,
3066 OpenRequest {
3067 open_data: OpenData {
3068 target_vp: Some(0),
3069 ring_offset: 0,
3070 ring_gpadl_id: GpadlId(0),
3071 event_flag: 0,
3072 connection_id: 0,
3073 user_data: UserDefinedData::new_zeroed(),
3074 },
3075 incoming_event: None,
3076 use_vtl2_connection_id: false,
3077 },
3078 )
3079 .await
3080 .unwrap_err();
3081 }
3082
3083 #[async_test]
3084 #[should_panic(expected = "channel should not exist")]
3085 async fn test_reoffer_in_use_rel_id(driver: DefaultDriver) {
3086 let (mut server, mut client) = test_init(&driver);
3087 let mut connection = server.get_channels(&mut client, 1).await;
3088 let [channel] = connection.offers.try_into().unwrap();
3089
3090 server.send(in_msg(
3091 MessageType::RESCIND_CHANNEL_OFFER,
3092 protocol::RescindChannelOffer {
3093 channel_id: ChannelId(0),
3094 },
3095 ));
3096
3097 channel.revoke_recv.await.unwrap();
3098
3099 let offer = protocol::OfferChannel {
3101 interface_id: Guid::new_random(),
3102 instance_id: Guid::new_random(),
3103 rsvd: [0; 4],
3104 flags: OfferFlags::new(),
3105 mmio_megabytes: 0,
3106 user_defined: UserDefinedData::new_zeroed(),
3107 subchannel_index: 0,
3108 mmio_megabytes_optional: 0,
3109 channel_id: ChannelId(0),
3110 monitor_id: 0,
3111 monitor_allocated: 0,
3112 is_dedicated: 0,
3113 connection_id: 0,
3114 };
3115
3116 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3117
3118 connection.offer_recv.next().await;
3119 }
3120
3121 #[async_test]
3122 async fn test_revoke_release_and_reoffer(driver: DefaultDriver) {
3123 let (mut server, mut client) = test_init(&driver);
3124 let mut connection = server.get_channels(&mut client, 1).await;
3125 let [channel] = connection.offers.try_into().unwrap();
3126
3127 server.send(in_msg(
3128 MessageType::RESCIND_CHANNEL_OFFER,
3129 protocol::RescindChannelOffer {
3130 channel_id: ChannelId(0),
3131 },
3132 ));
3133
3134 channel.revoke_recv.await.unwrap();
3135 drop(channel.request_send);
3136
3137 check_message(
3138 server.next().await.unwrap(),
3139 protocol::RelIdReleased {
3140 channel_id: ChannelId(0),
3141 },
3142 );
3143
3144 let offer = protocol::OfferChannel {
3145 interface_id: Guid::new_random(),
3146 instance_id: Guid::new_random(),
3147 rsvd: [0; 4],
3148 flags: OfferFlags::new(),
3149 mmio_megabytes: 0,
3150 user_defined: UserDefinedData::new_zeroed(),
3151 subchannel_index: 0,
3152 mmio_megabytes_optional: 0,
3153 channel_id: ChannelId(0),
3154 monitor_id: 0,
3155 monitor_allocated: 0,
3156 is_dedicated: 0,
3157 connection_id: 0,
3158 };
3159
3160 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3161
3162 connection.offer_recv.next().await.unwrap();
3163 }
3164
3165 #[async_test]
3166 async fn test_release_revoke_and_reoffer(driver: DefaultDriver) {
3167 let (mut server, mut client) = test_init(&driver);
3168 let mut connection = server.get_channels(&mut client, 1).await;
3169 let [channel] = connection.offers.try_into().unwrap();
3170
3171 let open = channel.request_send.call_failable(
3172 ChannelRequest::Open,
3173 OpenRequest {
3174 open_data: OpenData {
3175 target_vp: Some(0),
3176 ring_offset: 0,
3177 ring_gpadl_id: GpadlId(0),
3178 event_flag: 0,
3179 connection_id: 0,
3180 user_data: UserDefinedData::new_zeroed(),
3181 },
3182 incoming_event: None,
3183 use_vtl2_connection_id: false,
3184 },
3185 );
3186
3187 let server_open = async {
3188 check_message(
3189 server.next().await.unwrap(),
3190 protocol::OpenChannel2 {
3191 open_channel: protocol::OpenChannel {
3192 channel_id: ChannelId(0),
3193 open_id: 0,
3194 ring_buffer_gpadl_id: GpadlId(0),
3195 target_vp: 0,
3196 downstream_ring_buffer_page_offset: 0,
3197 user_data: UserDefinedData::new_zeroed(),
3198 },
3199 connection_id: 0,
3200 event_flag: 0,
3201 flags: Default::default(),
3202 },
3203 );
3204 server.send(in_msg(
3205 MessageType::OPEN_CHANNEL_RESULT,
3206 protocol::OpenResult {
3207 channel_id: ChannelId(0),
3208 open_id: 0,
3209 status: protocol::STATUS_SUCCESS as u32,
3210 },
3211 ));
3212 };
3213
3214 (open, server_open).join().await.0.unwrap();
3215
3216 drop(channel);
3218
3219 check_message(
3220 server.next().await.unwrap(),
3221 protocol::CloseChannel {
3222 channel_id: ChannelId(0),
3223 },
3224 );
3225
3226 server.send(in_msg(
3227 MessageType::RESCIND_CHANNEL_OFFER,
3228 protocol::RescindChannelOffer {
3229 channel_id: ChannelId(0),
3230 },
3231 ));
3232
3233 check_message(
3235 server.next().await.unwrap(),
3236 protocol::RelIdReleased {
3237 channel_id: ChannelId(0),
3238 },
3239 );
3240
3241 let offer = protocol::OfferChannel {
3242 interface_id: Guid::new_random(),
3243 instance_id: Guid::new_random(),
3244 rsvd: [0; 4],
3245 flags: OfferFlags::new(),
3246 mmio_megabytes: 0,
3247 user_defined: UserDefinedData::new_zeroed(),
3248 subchannel_index: 0,
3249 mmio_megabytes_optional: 0,
3250 channel_id: ChannelId(0),
3251 monitor_id: 0,
3252 monitor_allocated: 0,
3253 is_dedicated: 0,
3254 connection_id: 0,
3255 };
3256
3257 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3258
3259 connection.offer_recv.next().await.unwrap();
3261 }
3262}