1#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9pub mod driver;
10pub mod filter;
11mod hvsock;
12pub mod saved_state;
13
14pub use self::saved_state::SavedState;
15use anyhow::Context as _;
16use anyhow::Result;
17use futures::FutureExt;
18use futures::StreamExt;
19use futures::future::OptionFuture;
20use futures::stream::SelectAll;
21use futures_concurrency::future::Race;
22use guid::Guid;
23use inspect::Inspect;
24use mesh::rpc::FailableRpc;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use pal_async::task::Spawn;
28use pal_async::task::Task;
29use pal_event::Event;
30use std::collections::HashMap;
31use std::collections::VecDeque;
32use std::collections::hash_map;
33use std::convert::TryInto;
34use std::future::Future;
35use std::future::poll_fn;
36use std::ops::Deref;
37use std::ops::DerefMut;
38use std::pin::pin;
39use std::sync::Arc;
40use std::sync::atomic::AtomicU32;
41use std::sync::atomic::Ordering;
42use std::task::Context;
43use std::task::Poll;
44use thiserror::Error;
45use vmbus_async::async_dgram::AsyncRecv;
46use vmbus_async::async_dgram::AsyncRecvExt;
47use vmbus_channel::bus::GpadlRequest;
48use vmbus_channel::bus::ModifyRequest;
49use vmbus_channel::bus::OpenData;
50use vmbus_channel::gpadl::GpadlId;
51use vmbus_core::HvsockConnectRequest;
52use vmbus_core::OutgoingMessage;
53use vmbus_core::TaggedStream;
54use vmbus_core::VersionInfo;
55use vmbus_core::protocol;
56use vmbus_core::protocol::ChannelId;
57use vmbus_core::protocol::ConnectionState;
58use vmbus_core::protocol::FeatureFlags;
59use vmbus_core::protocol::Message;
60use vmbus_core::protocol::OpenChannelFlags;
61use vmbus_core::protocol::Version;
62use vmcore::interrupt::Interrupt;
63use vmcore::synic::MonitorPageGpas;
64use zerocopy::Immutable;
65use zerocopy::IntoBytes;
66use zerocopy::KnownLayout;
67
68const SINT: u8 = 2;
69const VTL: u8 = 0;
70const SUPPORTED_VERSIONS: &[Version] = &[Version::Iron, Version::Copper];
71const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
72 .with_guest_specified_signal_parameters(true)
73 .with_channel_interrupt_redirection(true)
74 .with_modify_connection(true)
75 .with_client_id(true)
76 .with_pause_resume(true);
77
/// SynIC event operations used by the client.
///
/// Implementations are shared behind an `Arc` (see [`VmbusClientBuilder`]),
/// hence the `Send + Sync` bound.
pub trait SynicEventClient: Send + Sync {
    /// Associates `event_flag` with `event`, presumably so that incoming
    /// signals on that flag set `event` (TODO confirm against implementations).
    fn map_event(&self, event_flag: u16, event: &Event) -> std::io::Result<()>;

    /// Removes an association previously created by [`Self::map_event`].
    fn unmap_event(&self, event_flag: u16);

    /// Signals `event_flag` on `connection_id`; presumably used for
    /// guest-to-host channel interrupts (TODO confirm).
    fn signal_event(&self, connection_id: u32, event_flag: u16) -> std::io::Result<()>;
}
89
/// Source of incoming VMBus messages consumed by the client task.
pub trait VmbusMessageSource: AsyncRecv + Send {
    /// Hook for pausing delivery of incoming messages. The default
    /// implementation is a no-op.
    fn pause_message_stream(&mut self) {}

    /// Hook for resuming delivery after [`Self::pause_message_stream`]. The
    /// default implementation is a no-op.
    fn resume_message_stream(&mut self) {}
}
99
/// Sink for outgoing VMBus messages posted by the client task.
pub trait PollPostMessage: Send {
    /// Attempts to post `msg`, with message type `typ`, on `connection_id`.
    ///
    /// Follows the usual `poll` convention: return `Poll::Pending` (arranging
    /// for `cx` to be woken) when the message cannot be posted yet, and
    /// `Poll::Ready(())` once it has been handled.
    fn poll_post_message(
        &mut self,
        cx: &mut Context<'_>,
        connection_id: u32,
        typ: u32,
        msg: &[u8],
    ) -> Poll<()>;
}
109
/// Handle to a spawned VMBus client task, created by
/// [`VmbusClientBuilder::build`].
pub struct VmbusClient {
    /// Control requests (start/stop/save/restore/inspect) for the task.
    task_send: mesh::Sender<TaskRequest>,
    /// Cloneable handle for connection-level requests.
    access: VmbusClientAccess,
    /// The spawned task; it yields its `ClientTask` when it finishes.
    task: Task<ClientTask>,
}
115
/// Errors returned by [`VmbusClient::connect`].
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// A connect was requested while the client was not disconnected.
    #[error("invalid state to connect to the server")]
    InvalidState,
    /// The server rejected every version in `SUPPORTED_VERSIONS`.
    #[error("no supported protocol versions")]
    NoSupportedVersions,
    /// The server accepted the version but reported a non-success
    /// connection state.
    #[error("failed to connect to the server: {0:?}")]
    FailedToConnect(ConnectionState),
}
125
/// Cloneable handle for connection-level requests (modify connection, hvsock
/// connect) to the client task.
#[derive(Clone)]
pub struct VmbusClientAccess {
    /// Requests handled by the task's `handle_client_request`.
    client_request_send: mesh::Sender<ClientRequest>,
}
130
/// Transport dependencies used to build a [`VmbusClient`].
pub struct VmbusClientBuilder {
    /// SynIC event mapping and signaling.
    event_client: Arc<dyn SynicEventClient>,
    /// Source of incoming messages.
    msg_source: Box<dyn VmbusMessageSource>,
    /// Sink for outgoing messages.
    msg_client: Box<dyn PollPostMessage>,
}
137
138impl VmbusClientBuilder {
139 pub fn new(
141 event_client: impl SynicEventClient + 'static,
142 msg_source: impl VmbusMessageSource + 'static,
143 msg_client: impl PollPostMessage + 'static,
144 ) -> Self {
145 Self {
146 event_client: Arc::new(event_client),
147 msg_source: Box::new(msg_source),
148 msg_client: Box::new(msg_client),
149 }
150 }
151
152 pub fn build(self, spawner: &impl Spawn) -> VmbusClient {
154 let (task_send, task_recv) = mesh::channel();
155 let (client_request_send, client_request_recv) = mesh::channel();
156
157 let inner = ClientTaskInner {
158 messages: OutgoingMessages {
159 poster: self.msg_client,
160 queued: VecDeque::new(),
161 state: OutgoingMessageState::Paused,
162 },
163 teardown_gpadls: HashMap::new(),
164 channel_requests: SelectAll::new(),
165 synic: SynicState {
166 event_flag_state: Vec::new(),
167 event_client: self.event_client,
168 },
169 };
170
171 let mut task = ClientTask {
172 inner,
173 channels: ChannelList::default(),
174 task_recv,
175 running: false,
176 msg_source: self.msg_source,
177 client_request_recv,
178 state: ClientState::Disconnected,
179 modify_request: None,
180 hvsock_tracker: hvsock::HvsockRequestTracker::new(),
181 };
182
183 let task = spawner.spawn("vmbus client", async move {
184 task.run().await;
185 task
186 });
187
188 VmbusClient {
189 access: VmbusClientAccess {
190 client_request_send,
191 },
192 task_send,
193 task,
194 }
195 }
196}
197
198impl VmbusClient {
199 pub async fn connect(
202 &mut self,
203 target_message_vp: u32,
204 monitor_page: Option<MonitorPageGpas>,
205 client_id: Guid,
206 ) -> Result<ConnectResult, ConnectError> {
207 let request = ConnectRequest {
208 target_message_vp,
209 monitor_page,
210 client_id,
211 };
212
213 self.access
214 .client_request_send
215 .call(ClientRequest::Connect, request)
216 .await
217 .unwrap()
218 }
219
220 pub async fn unload(self) {
221 self.access
222 .client_request_send
223 .call(ClientRequest::Unload, ())
224 .await
225 .unwrap();
226
227 self.sever().await;
228 }
229
230 pub fn access(&self) -> &VmbusClientAccess {
231 &self.access
232 }
233
234 pub fn start(&mut self) {
235 self.task_send.send(TaskRequest::Start);
236 }
237
238 pub async fn stop(&mut self) {
239 self.task_send
240 .call(TaskRequest::Stop, ())
241 .await
242 .expect("Failed to send stop request");
243 }
244
245 pub async fn save(&self) -> SavedState {
246 self.task_send
247 .call(TaskRequest::Save, ())
248 .await
249 .expect("Failed to send save request")
250 }
251
252 pub async fn restore(
253 &mut self,
254 state: SavedState,
255 ) -> Result<Option<ConnectResult>, RestoreError> {
256 self.task_send
257 .call(TaskRequest::Restore, state)
258 .await
259 .expect("Failed to send restore request")
260 }
261
262 pub async fn post_restore(&mut self) {
263 self.task_send
264 .call(TaskRequest::PostRestore, ())
265 .await
266 .expect("Failed to send post-restore request");
267 }
268
269 async fn sever(self) -> VmbusClientBuilder {
270 drop(self.task_send);
271 let task = self.task.await;
272 VmbusClientBuilder {
273 event_client: task.inner.synic.event_client,
274 msg_source: task.msg_source,
275 msg_client: task.inner.messages.poster,
276 }
277 }
278}
279
280impl Inspect for VmbusClient {
281 fn inspect(&self, req: inspect::Request<'_>) {
282 self.task_send.send(TaskRequest::Inspect(req.defer()));
283 }
284}
285
/// Result of a successful connection.
#[derive(Debug)]
pub struct ConnectResult {
    /// Negotiated protocol version and feature flags.
    pub version: VersionInfo,
    /// Offers received before the server sent AllOffersDelivered.
    pub offers: Vec<OfferInfo>,
    /// Receives later offers, except those answering a pending hvsock
    /// connect request.
    pub offer_recv: mesh::Receiver<OfferInfo>,
}
292
293impl VmbusClientAccess {
294 pub async fn modify(&self, request: ModifyConnectionRequest) -> ConnectionState {
295 self.client_request_send
296 .call(ClientRequest::Modify, request)
297 .await
298 .expect("Failed to send modify request")
299 }
300
301 pub fn connect_hvsock(
302 &self,
303 request: HvsockConnectRequest,
304 ) -> impl Future<Output = Option<OfferInfo>> + use<> {
305 self.client_request_send
306 .call(ClientRequest::HvsockConnect, request)
307 .map(|r| r.ok().flatten())
308 }
309}
310
/// Parameters for opening a channel via [`ChannelRequest::Open`].
#[derive(Debug)]
pub struct OpenRequest {
    /// Ring buffer GPADL, target VP, ring offset, user data, connection ID
    /// and event flag used for the open.
    pub open_data: OpenData,
    /// If set, the open requests interrupt redirection and an event flag is
    /// allocated for this event. Requires interrupt-redirection support from
    /// the server.
    pub incoming_event: Option<Event>,
    /// If set, a VTL2 connection ID derived from the channel ID is used
    /// instead of `open_data.connection_id`. Requires interrupt-redirection
    /// support from the server.
    pub use_vtl2_connection_id: bool,
}
317
/// Parameters for reclaiming a restored channel via
/// [`ChannelRequest::Restore`].
#[derive(Debug)]
pub struct RestoreRequest {
    /// Event for redirected incoming interrupts. Must be set exactly when
    /// `redirected_event_flag` is set.
    pub incoming_event: Option<Event>,
    /// Event flag the channel used for interrupt redirection before the save.
    pub redirected_event_flag: Option<u16>,
    /// Connection ID stored for the channel's guest-to-host interrupt.
    pub connection_id: u32,
}
326
/// Requests a channel's consumer sends to the client task through
/// [`OfferInfo::request_send`].
pub enum ChannelRequest {
    /// Open the channel. Completed when the server's OpenResult arrives, or
    /// failed early on validation errors.
    Open(FailableRpc<OpenRequest, OpenOutput>),
    /// Reclaim a channel recreated from saved state. No message is sent to
    /// the server.
    Restore(FailableRpc<RestoreRequest, OpenOutput>),
    /// Close the channel.
    Close(Rpc<(), ()>),
    /// Create a GPADL. Completed when the server's GpadlCreated arrives.
    Gpadl(FailableRpc<GpadlRequest, ()>),
    /// Tear down a GPADL. Completed when the server's GpadlTorndown arrives.
    TeardownGpadl(Rpc<GpadlId, ()>),
    /// Modify the channel. Presumably completed with the status from
    /// ModifyChannelResponse (TODO confirm in the request handler).
    Modify(Rpc<ModifyRequest, i32>),
}
336
/// Result of a successful open or restore.
#[derive(Debug)]
pub struct OpenOutput {
    /// Event flag used for interrupt redirection, if the channel uses one.
    pub redirected_event_flag: Option<u16>,
}
342
343impl std::fmt::Display for ChannelRequest {
344 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345 let s = match self {
346 ChannelRequest::Open(_) => "Open",
347 ChannelRequest::Close(_) => "Close",
348 ChannelRequest::Restore(_) => "Restore",
349 ChannelRequest::Gpadl(_) => "Gpadl",
350 ChannelRequest::TeardownGpadl(_) => "TeardownGpadl",
351 ChannelRequest::Modify(_) => "Modify",
352 };
353 fmt.pad(s)
354 }
355}
356
/// Errors returned by [`VmbusClient::restore`].
#[derive(Debug, Error)]
pub enum RestoreError {
    /// The saved state contains a protocol version the client does not
    /// support.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// The saved state contains feature flags the client does not support.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    /// The saved state lists the same channel ID twice.
    #[error("duplicate channel id {0}")]
    DuplicateChannelId(u32),

    /// The saved state lists the same GPADL ID twice.
    #[error("duplicate gpadl id {0}")]
    DuplicateGpadlId(u32),

    /// A saved GPADL refers to a channel that is not in the saved state.
    #[error("gpadl for unknown channel id {0}")]
    GpadlForUnknownChannelId(u32),

    /// A saved pending message could not be reconstructed.
    #[error("invalid pending message")]
    InvalidPendingMessage(#[source] vmbus_core::MessageTooLarge),

    /// A restored channel could not be re-offered.
    #[error("failed to offer restored channel")]
    OfferFailed(#[source] anyhow::Error),
}
380
/// Information handed to the consumer of an offered channel.
#[derive(Debug, Inspect)]
pub struct OfferInfo {
    /// The offer as received from the server.
    pub offer: protocol::OfferChannel,
    /// Interrupt for signaling the server on this channel. It shares the
    /// channel's connection ID, which is set when the channel is opened or
    /// restored.
    #[inspect(skip)]
    pub guest_to_host_interrupt: Interrupt,
    /// Sends [`ChannelRequest`]s for this channel to the client task.
    #[inspect(skip)]
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Signaled when the server rescinds the channel.
    #[inspect(skip)]
    pub revoke_recv: mesh::OneshotReceiver<()>,
}
393
/// Connection-level requests sent to the client task by [`VmbusClient`] and
/// [`VmbusClientAccess`].
#[derive(Debug)]
enum ClientRequest {
    /// Negotiate a version and collect the initial offers.
    Connect(Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>),
    /// Send Unload and wait for UnloadComplete.
    Unload(Rpc<(), ()>),
    /// Send ModifyConnection and return the server's connection state.
    Modify(Rpc<ModifyConnectionRequest, ConnectionState>),
    /// Send an hvsock connect request. Resolves to the matching offer, or
    /// `None` when the server sends a connect result instead.
    HvsockConnect(Rpc<HvsockConnectRequest, Option<OfferInfo>>),
}
401
402impl std::fmt::Display for ClientRequest {
403 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404 let s = match self {
405 ClientRequest::Connect(..) => "Connect",
406 ClientRequest::Unload { .. } => "Unload",
407 ClientRequest::Modify(..) => "Modify",
408 ClientRequest::HvsockConnect(..) => "HvsockConnect",
409 };
410 fmt.pad(s)
411 }
412}
413
/// Control requests sent to the client task by [`VmbusClient`].
enum TaskRequest {
    /// Inspect the task's state.
    Inspect(inspect::Deferred),
    /// Capture the client state.
    Save(Rpc<(), SavedState>),
    /// Restore state previously returned by `Save`.
    Restore(Rpc<SavedState, Result<Option<ConnectResult>, RestoreError>>),
    /// Signals that restore has finished.
    PostRestore(Rpc<(), ()>),
    /// Start running; no acknowledgement is sent.
    Start,
    /// Stop running; completed once the task acknowledges.
    Stop(Rpc<(), ()>),
}
422
/// Connection state machine of the client.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ClientState {
    /// Not connected. This is the only state in which a connect is accepted.
    Disconnected,
    /// InitiateContact has been sent; waiting for VersionResponse.
    Connecting {
        /// Version requested in the outstanding InitiateContact.
        version: Version,
        /// Connect request to complete once negotiation finishes.
        #[inspect(skip)]
        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
    },
    /// Connected. New offers are forwarded through `offer_send`.
    Connected {
        /// Negotiated version and feature flags.
        version: VersionInfo,
        /// Sender half of [`ConnectResult::offer_recv`].
        #[inspect(skip)]
        offer_send: mesh::Sender<OfferInfo>,
    },
    /// RequestOffers has been sent; offers are collected until
    /// AllOffersDelivered arrives.
    RequestingOffers {
        /// Negotiated version and feature flags.
        version: VersionInfo,
        /// Connect request to complete with the collected offers.
        #[inspect(skip)]
        rpc: Rpc<(), Result<ConnectResult, ConnectError>>,
        /// Offers received so far.
        #[inspect(skip)]
        offers: Vec<OfferInfo>,
    },
    /// Unload has been sent; waiting for UnloadComplete.
    Disconnecting {
        /// Version negotiated before the unload.
        version: VersionInfo,
        /// Unload request to complete on UnloadComplete.
        #[inspect(skip)]
        rpc: Rpc<(), ()>,
    },
}
458
459impl ClientState {
460 fn get_version(&self) -> Option<VersionInfo> {
461 match self {
462 ClientState::Connected { version, .. } => Some(*version),
463 ClientState::RequestingOffers { version, .. } => Some(*version),
464 ClientState::Disconnecting { version, .. } => Some(*version),
465 ClientState::Disconnected | ClientState::Connecting { .. } => None,
466 }
467 }
468}
469
470impl std::fmt::Display for ClientState {
471 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472 let s = match self {
473 ClientState::Disconnected => "Disconnected",
474 ClientState::Connecting { .. } => "Connecting",
475 ClientState::Connected { .. } => "Connected",
476 ClientState::RequestingOffers { .. } => "RequestingOffers",
477 ClientState::Disconnecting { .. } => "Disconnecting",
478 };
479 fmt.pad(s)
480 }
481}
482
/// Parameters of a connect request, sent to the server in InitiateContact.
#[derive(Copy, Clone, Debug, Default)]
struct ConnectRequest {
    /// Sent as `target_message_vp` in InitiateContact.
    target_message_vp: u32,
    /// Monitor page GPAs. `MonitorPageGpas::default()` is sent when `None`.
    monitor_page: Option<MonitorPageGpas>,
    /// Client ID. Sent only for Copper and newer (InitiateContact2).
    client_id: Guid,
}
489
/// Parameters for a ModifyConnection request.
#[derive(Copy, Clone, Debug, Default)]
pub struct ModifyConnectionRequest {
    /// New monitor page GPAs. `MonitorPageGpas::default()` is sent when
    /// `None`.
    pub monitor_page: Option<MonitorPageGpas>,
}
494
495impl From<ModifyConnectionRequest> for protocol::ModifyConnection {
496 fn from(value: ModifyConnectionRequest) -> Self {
497 let monitor_page = value.monitor_page.unwrap_or_default();
498
499 Self {
500 parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
501 child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
502 }
503 }
504}
505
/// Lifecycle state of a channel as tracked by the client.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum ChannelState {
    /// Offered by the server and not open.
    Offered,
    /// OpenChannel has been sent; waiting for OpenResult.
    Opening {
        /// Event flag allocated for interrupt redirection, if any.
        redirected_event_flag: Option<u16>,
        /// Event registered for redirected interrupts, if any.
        #[inspect(skip)]
        redirected_event: Option<Event>,
        /// Open request to complete when OpenResult arrives.
        #[inspect(skip)]
        rpc: FailableRpc<(), OpenOutput>,
    },
    /// Recreated from saved state; waiting for the consumer to reclaim it
    /// with [`ChannelRequest::Restore`].
    Restored,
    /// Open on the server.
    Opened {
        /// Event flag used for interrupt redirection, if any.
        redirected_event_flag: Option<u16>,
        /// Event registered for redirected interrupts, if any.
        #[inspect(skip)]
        redirected_event: Option<Event>,
    },
    /// Rescinded by the server.
    Revoked,
}
533
534impl std::fmt::Display for ChannelState {
535 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536 let s = match self {
537 ChannelState::Opening { .. } => "Opening",
538 ChannelState::Offered => "Offered",
539 ChannelState::Opened { .. } => "Opened",
540 ChannelState::Restored => "Restored",
541 ChannelState::Revoked => "Revoked",
542 };
543 fmt.pad(s)
544 }
545}
546
/// Client-side bookkeeping for one offered channel.
#[derive(Debug, Inspect)]
struct Channel {
    /// The offer as received from the server.
    offer: protocol::OfferChannel,
    /// Signals the consumer's `revoke_recv` on rescind. Taken (left `None`)
    /// once sent.
    #[inspect(skip)]
    revoke_send: Option<mesh::OneshotSender<()>>,
    /// Lifecycle state.
    state: ChannelState,
    /// Outstanding ModifyChannel request, completed by ModifyChannelResponse.
    #[inspect(with = "|x| x.is_some()")]
    modify_response_send: Option<Rpc<(), i32>>,
    /// GPADLs of this channel and their creation/teardown state.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    gpadls: HashMap<GpadlId, GpadlState>,
    /// Presumably set once the consumer has released the channel; consulted
    /// by `try_release` (TODO confirm).
    is_client_released: bool,
    /// Connection ID for guest-to-host signals, shared with the channel's
    /// `guest_to_host_interrupt`.
    connection_id: Arc<AtomicU32>,
}
561
562impl Channel {
563 fn pending_request(&self) -> Option<&'static str> {
564 if self.modify_response_send.is_some() {
565 return Some("modify");
566 }
567 self.gpadls.iter().find_map(|(_, gpadl)| match gpadl {
568 GpadlState::Offered(_) => Some("creating gpadl"),
569 GpadlState::Created => None,
570 GpadlState::TearingDown { .. } => Some("tearing down gpadl"),
571 })
572 }
573}
574
/// State owned by the spawned client task.
#[derive(Inspect)]
struct ClientTask {
    /// Outgoing messages, GPADL teardown tracking, per-channel request
    /// streams and SynIC event state.
    #[inspect(flatten)]
    inner: ClientTaskInner,
    /// Channels known to the client, keyed by channel ID.
    channels: ChannelList,
    /// Connection state machine.
    state: ClientState,
    /// Pending hvsock connect requests, matched against offers and connect
    /// results.
    hvsock_tracker: hvsock::HvsockRequestTracker,
    /// Whether the task has been started (built as `false`); presumably
    /// toggled by `TaskRequest::Start`/`Stop` (TODO confirm).
    running: bool,
    /// Outstanding ModifyConnection request. At most one at a time.
    #[inspect(with = "|x| x.is_some()")]
    modify_request: Option<Rpc<ModifyConnectionRequest, ConnectionState>>,
    /// Source of incoming messages.
    #[inspect(skip)]
    msg_source: Box<dyn VmbusMessageSource>,
    /// Control requests from [`VmbusClient`].
    #[inspect(skip)]
    task_recv: mesh::Receiver<TaskRequest>,
    /// Connection-level requests from [`VmbusClient`]/[`VmbusClientAccess`].
    #[inspect(skip)]
    client_request_recv: mesh::Receiver<ClientRequest>,
}
592
593impl ClientTask {
594 fn handle_initiate_contact(
595 &mut self,
596 rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
597 version: Version,
598 ) {
599 let ClientState::Disconnected = self.state else {
600 tracing::warn!(client_state = %self.state, "invalid client state for InitiateContact");
601 rpc.complete(Err(ConnectError::InvalidState));
602 return;
603 };
604 let feature_flags = if version >= Version::Copper {
605 SUPPORTED_FEATURE_FLAGS
606 } else {
607 FeatureFlags::new()
608 };
609
610 let request = rpc.input();
611
612 tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting");
613 let target_info = protocol::TargetInfo::new()
614 .with_sint(SINT)
615 .with_vtl(VTL)
616 .with_feature_flags(feature_flags.into());
617 let monitor_page = request.monitor_page.unwrap_or_default();
618 let msg = protocol::InitiateContact2 {
619 initiate_contact: protocol::InitiateContact {
620 version_requested: version as u32,
621 target_message_vp: request.target_message_vp,
622 interrupt_page_or_target_info: target_info.into(),
623 parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
624 child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
625 },
626 client_id: request.client_id,
627 };
628
629 self.state = ClientState::Connecting { version, rpc };
630 if version < Version::Copper {
631 self.inner.messages.send(&msg.initiate_contact)
632 } else {
633 self.inner.messages.send(&msg);
634 }
635 }
636
637 fn handle_unload(&mut self, rpc: Rpc<(), ()>) {
638 tracing::debug!(%self.state, "VmBus client disconnecting");
639 self.state = ClientState::Disconnecting {
640 version: self.state.get_version().expect("invalid state for unload"),
641 rpc,
642 };
643
644 self.inner.messages.send(&protocol::Unload {});
645 }
646
647 fn handle_modify(&mut self, request: Rpc<ModifyConnectionRequest, ConnectionState>) {
648 if !matches!(self.state, ClientState::Connected { version, .. }
649 if version.feature_flags.modify_connection())
650 {
651 tracing::warn!("ModifyConnection not supported");
652 request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
653 return;
654 }
655
656 if self.modify_request.is_some() {
657 tracing::warn!("Duplicate ModifyConnection request");
658 request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
659 return;
660 }
661
662 let message = protocol::ModifyConnection::from(*request.input());
663 self.modify_request = Some(request);
664 self.inner.messages.send(&message);
665 }
666
667 fn handle_tl_connect(&mut self, rpc: Rpc<HvsockConnectRequest, Option<OfferInfo>>) {
668 let message = protocol::TlConnectRequest2::from(*rpc.input());
672 self.hvsock_tracker.add_request(rpc);
673 self.inner.messages.send(&message);
674 }
675
676 fn handle_client_request(&mut self, request: ClientRequest) {
677 match request {
678 ClientRequest::Connect(rpc) => {
679 self.handle_initiate_contact(rpc, *SUPPORTED_VERSIONS.last().unwrap());
680 }
681 ClientRequest::Unload(rpc) => {
682 self.handle_unload(rpc);
683 }
684 ClientRequest::Modify(request) => self.handle_modify(request),
685 ClientRequest::HvsockConnect(request) => self.handle_tl_connect(request),
686 }
687 }
688
689 fn handle_version_response(&mut self, msg: protocol::VersionResponse2) {
690 let old_state = std::mem::replace(&mut self.state, ClientState::Disconnected);
691 let ClientState::Connecting { version, rpc } = old_state else {
692 self.state = old_state;
693 tracing::warn!(
694 client_state = %self.state,
695 "invalid client state to handle VersionResponse"
696 );
697 return;
698 };
699 if msg.version_response.version_supported > 0 {
700 if msg.version_response.connection_state != ConnectionState::SUCCESSFUL {
701 rpc.complete(Err(ConnectError::FailedToConnect(
702 msg.version_response.connection_state,
703 )));
704 return;
705 }
706
707 let feature_flags = if version >= Version::Copper {
708 FeatureFlags::from(msg.supported_features)
709 } else {
710 FeatureFlags::new()
711 };
712
713 let version = VersionInfo {
714 version,
715 feature_flags,
716 };
717
718 self.inner.messages.send(&protocol::RequestOffers {});
719 self.state = ClientState::RequestingOffers {
720 version,
721 rpc: rpc.split().1,
722 offers: Vec::new(),
723 };
724 tracing::info!(?version, "VmBus client connected, requesting offers");
725 } else {
726 let index = SUPPORTED_VERSIONS
727 .iter()
728 .position(|v| *v == version)
729 .unwrap();
730
731 if index == 0 {
732 rpc.complete(Err(ConnectError::NoSupportedVersions));
733 return;
734 }
735 let next_version = SUPPORTED_VERSIONS[index - 1];
736 tracing::debug!(
737 version = version as u32,
738 next_version = next_version as u32,
739 "Unsupported version, retrying"
740 );
741 self.handle_initiate_contact(rpc, next_version);
742 }
743 }
744
745 fn create_channel(&mut self, offer: protocol::OfferChannel) -> Result<OfferInfo> {
746 self.create_channel_core(offer, ChannelState::Offered)
747 }
748
749 fn create_channel_core(
750 &mut self,
751 offer: protocol::OfferChannel,
752 state: ChannelState,
753 ) -> Result<OfferInfo> {
754 if self.channels.0.contains_key(&offer.channel_id) {
755 anyhow::bail!("channel {:?} exists", offer.channel_id);
756 }
757 let (request_send, request_recv) = mesh::channel();
758 let (revoke_send, revoke_recv) = mesh::oneshot();
759
760 let connection_id = Arc::new(AtomicU32::new(0));
761 self.channels.0.insert(
762 offer.channel_id,
763 Channel {
764 revoke_send: Some(revoke_send),
765 offer,
766 state,
767 modify_response_send: None,
768 gpadls: HashMap::new(),
769 is_client_released: false,
770 connection_id: connection_id.clone(),
771 },
772 );
773
774 self.inner
775 .channel_requests
776 .push(TaggedStream::new(offer.channel_id, request_recv));
777
778 Ok(OfferInfo {
779 offer,
780 guest_to_host_interrupt: self.inner.synic.guest_to_host_interrupt(connection_id),
781 revoke_recv,
782 request_send,
783 })
784 }
785
786 fn handle_offer(&mut self, offer: protocol::OfferChannel) {
787 let offer_info = self
788 .create_channel(offer)
789 .expect("channel should not exist");
790
791 tracing::info!(
792 state = %self.state,
793 channel_id = offer.channel_id.0,
794 interface_id = %offer.interface_id,
795 instance_id = %offer.instance_id,
796 subchannel_index = offer.subchannel_index,
797 "received offer");
798
799 if let Some(offer) = self.hvsock_tracker.check_offer(&offer_info.offer) {
800 offer.complete(Some(offer_info));
801 } else {
802 match &mut self.state {
803 ClientState::Connected { offer_send, .. } => {
804 offer_send.send(offer_info);
805 }
806 ClientState::RequestingOffers { offers, .. } => {
807 offers.push(offer_info);
808 }
809 state => unreachable!("invalid client state for OfferChannel: {state}"),
810 }
811 }
812 }
813
814 fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) -> TriedRelease {
815 tracing::info!(state = %self.state, channel_id = rescind.channel_id.0, "received rescind");
816
817 let mut channel = self.channels.get_mut(rescind.channel_id);
818 let event_flag = match std::mem::replace(&mut channel.state, ChannelState::Revoked) {
819 ChannelState::Offered => None,
820 ChannelState::Opening {
821 redirected_event_flag,
822 redirected_event: _,
823 rpc,
824 } => {
825 rpc.fail(anyhow::anyhow!("channel revoked"));
826 redirected_event_flag
827 }
828 ChannelState::Restored => None,
829 ChannelState::Opened {
830 redirected_event_flag,
831 redirected_event: _,
832 } => redirected_event_flag,
833 ChannelState::Revoked => {
834 panic!("channel id {:?} already revoked", rescind.channel_id);
835 }
836 };
837 if let Some(event_flag) = event_flag {
838 self.inner.synic.free_event_flag(event_flag);
839 }
840
841 channel.revoke_send.take().unwrap().send(());
843
844 channel.try_release(&mut self.inner.messages)
845 }
846
847 fn handle_offers_delivered(&mut self) {
848 match std::mem::replace(&mut self.state, ClientState::Disconnected) {
849 ClientState::RequestingOffers {
850 version,
851 rpc,
852 offers,
853 } => {
854 tracing::info!(version = ?version, "VmBus client connected, offers delivered");
855 let (offer_send, offer_recv) = mesh::channel();
856 self.state = ClientState::Connected {
857 version,
858 offer_send,
859 };
860 rpc.complete(Ok(ConnectResult {
861 version,
862 offers,
863 offer_recv,
864 }));
865 }
866 state => {
867 tracing::warn!(client_state = %state, "invalid client state for OffersDelivered");
868 self.state = state;
869 }
870 }
871 }
872
873 fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) -> TriedRelease {
874 let mut channel = self.channels.get_mut(request.channel_id);
875 let Some(gpadl_state) = channel.gpadls.get_mut(&request.gpadl_id) else {
876 panic!("GpadlCreated for unknown gpadl {:#x}", request.gpadl_id.0);
877 };
878
879 let rpc = match std::mem::replace(gpadl_state, GpadlState::Created) {
880 GpadlState::Offered(rpc) => rpc,
881 old_state => {
882 panic!(
883 "invalid state {old_state:?} for gpadl {:#x}:{:#x}",
884 request.channel_id.0, request.gpadl_id.0
885 );
886 }
887 };
888
889 let gpadl_created = request.status == protocol::STATUS_SUCCESS;
890 if gpadl_created {
891 rpc.complete(Ok(()));
892 } else {
893 channel.gpadls.remove(&request.gpadl_id).unwrap();
894 rpc.fail(anyhow::anyhow!(
895 "gpadl creation failed: {:#x}",
896 request.status
897 ));
898 };
899 channel.try_release(&mut self.inner.messages)
900 }
901
902 fn handle_open_result(&mut self, result: protocol::OpenResult) {
903 tracing::debug!(
904 channel_id = result.channel_id.0,
905 result = result.status,
906 "received open result"
907 );
908
909 let mut channel = self.channels.get_mut(result.channel_id);
910
911 let channel_opened = result.status == protocol::STATUS_SUCCESS as u32;
912 let old_state = std::mem::replace(&mut channel.state, ChannelState::Offered);
913 let ChannelState::Opening {
914 redirected_event_flag,
915 redirected_event,
916 rpc,
917 } = old_state
918 else {
919 tracing::warn!(
920 old_state = ?channel.state,
921 channel_opened,
922 "invalid state for open result"
923 );
924 channel.state = old_state;
925 return;
926 };
927
928 if !channel_opened {
929 if let Some(event_flag) = redirected_event_flag {
930 self.inner.synic.free_event_flag(event_flag);
931 }
932 rpc.fail(anyhow::anyhow!("open failed: {:#x}", result.status));
933 return;
934 }
935
936 channel.state = ChannelState::Opened {
937 redirected_event_flag,
938 redirected_event,
939 };
940
941 rpc.complete(Ok(OpenOutput {
942 redirected_event_flag,
943 }));
944 }
945
946 fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) -> TriedRelease {
947 let Some(channel_id) = self.inner.teardown_gpadls.remove(&request.gpadl_id) else {
948 panic!("gpadl {:#x} not in teardown list", request.gpadl_id.0);
949 };
950
951 tracing::debug!(
952 gpadl_id = request.gpadl_id.0,
953 channel_id = channel_id.0,
954 "Received GpadlTorndown"
955 );
956
957 let mut channel = self.channels.get_mut(channel_id);
958 let gpadl_state = channel
959 .gpadls
960 .remove(&request.gpadl_id)
961 .expect("gpadl validated above");
962
963 let GpadlState::TearingDown { rpcs } = gpadl_state else {
964 panic!("gpadl should be tearing down if in teardown list, state = {gpadl_state:?}");
965 };
966
967 for rpc in rpcs {
968 rpc.complete(());
969 }
970 channel.try_release(&mut self.inner.messages)
971 }
972
973 fn handle_unload_complete(&mut self) {
974 match std::mem::replace(&mut self.state, ClientState::Disconnected) {
975 ClientState::Disconnecting { version: _, rpc } => {
976 tracing::info!("VmBus client disconnected");
977 rpc.complete(());
978 }
979 state => {
980 tracing::warn!(client_state = %state, "invalid client state for UnloadComplete");
981 }
982 }
983 }
984
985 fn handle_modify_complete(&mut self, response: protocol::ModifyConnectionResponse) {
986 if let Some(request) = self.modify_request.take() {
987 request.complete(response.connection_state)
988 } else {
989 tracing::warn!("Unexpected modify complete request");
990 }
991 }
992
993 fn handle_modify_channel_response(
994 &mut self,
995 response: protocol::ModifyChannelResponse,
996 ) -> TriedRelease {
997 let mut channel = self.channels.get_mut(response.channel_id);
998 let Some(sender) = channel.modify_response_send.take() else {
999 panic!(
1000 "unexpected modify channel response for channel {:#x}",
1001 response.channel_id.0
1002 );
1003 };
1004
1005 sender.complete(response.status);
1006 channel.try_release(&mut self.inner.messages)
1007 }
1008
1009 fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) {
1010 if let Some(rpc) = self.hvsock_tracker.check_result(&response) {
1011 rpc.complete(None);
1012 }
1013 }
1014
1015 fn handle_synic_message(&mut self, data: &[u8]) -> bool {
1017 let msg = Message::parse(data, self.state.get_version()).unwrap();
1018 tracing::trace!(?msg, "received client message from synic");
1019
1020 match msg {
1021 Message::VersionResponse3(version_response, ..) => {
1022 self.handle_version_response(version_response.version_response2);
1027 }
1028 Message::VersionResponse2(version_response, ..) => {
1029 self.handle_version_response(version_response);
1030 }
1031 Message::VersionResponse(version_response, ..) => {
1032 self.handle_version_response(version_response.into());
1033 }
1034 Message::OfferChannel(offer, ..) => {
1035 self.handle_offer(offer);
1036 }
1037 Message::AllOffersDelivered(..) => {
1038 self.handle_offers_delivered();
1039 }
1040 Message::UnloadComplete(..) => {
1041 self.handle_unload_complete();
1042 }
1043 Message::ModifyConnectionResponse(response, ..) => {
1044 self.handle_modify_complete(response);
1045 }
1046 Message::GpadlCreated(gpadl, ..) => {
1047 self.handle_gpadl_created(gpadl);
1048 }
1049 Message::OpenResult(result, ..) => {
1050 self.handle_open_result(result);
1051 }
1052 Message::GpadlTorndown(gpadl, ..) => {
1053 self.handle_gpadl_torndown(gpadl);
1054 }
1055 Message::RescindChannelOffer(rescind, ..) => {
1056 self.handle_rescind(rescind);
1057 }
1058 Message::ModifyChannelResponse(response, ..) => {
1059 self.handle_modify_channel_response(response);
1060 }
1061 Message::TlConnectResult(response, ..) => self.handle_tl_connect_result(response),
1062 Message::CloseReservedChannelResponse(..) => {
1064 todo!("Unsupported message {msg:?}")
1065 }
1066 Message::PauseResponse(..) => {
1067 return false;
1068 }
1069 Message::RequestOffers(..)
1071 | Message::OpenChannel2(..)
1072 | Message::OpenChannel(..)
1073 | Message::CloseChannel(..)
1074 | Message::GpadlHeader(..)
1075 | Message::GpadlBody(..)
1076 | Message::GpadlTeardown(..)
1077 | Message::RelIdReleased(..)
1078 | Message::InitiateContact(..)
1079 | Message::InitiateContact2(..)
1080 | Message::Unload(..)
1081 | Message::OpenReservedChannel(..)
1082 | Message::CloseReservedChannel(..)
1083 | Message::TlConnectRequest2(..)
1084 | Message::TlConnectRequest(..)
1085 | Message::ModifyChannel(..)
1086 | Message::ModifyConnection(..)
1087 | Message::Pause(..)
1088 | Message::Resume(..) => {
1089 unreachable!("Client received server message {msg:?}");
1090 }
1091 }
1092 true
1093 }
1094
1095 fn handle_open_channel(
1096 &mut self,
1097 channel_id: ChannelId,
1098 rpc: FailableRpc<OpenRequest, OpenOutput>,
1099 ) {
1100 let mut channel = self.channels.get_mut(channel_id);
1101 match &channel.state {
1102 ChannelState::Offered => {}
1103 ChannelState::Revoked => {
1104 rpc.fail(anyhow::anyhow!("channel revoked"));
1105 return;
1106 }
1107 state => {
1108 rpc.fail(anyhow::anyhow!("invalid channel state: {}", state));
1109 return;
1110 }
1111 }
1112
1113 tracing::info!(channel_id = channel_id.0, "opening channel on host");
1114 let (request, rpc) = rpc.split();
1115 let open_data = &request.open_data;
1116
1117 let supports_interrupt_redirection =
1118 if let ClientState::Connected { version, .. } = self.state {
1119 version.feature_flags.guest_specified_signal_parameters()
1120 || version.feature_flags.channel_interrupt_redirection()
1121 } else {
1122 false
1123 };
1124
1125 if !supports_interrupt_redirection && open_data.event_flag != channel_id.0 as u16 {
1126 rpc.fail(anyhow::anyhow!(
1127 "host does not support specifying the event flag"
1128 ));
1129 return;
1130 }
1131
1132 let open_channel = protocol::OpenChannel {
1133 channel_id,
1134 open_id: 0,
1135 ring_buffer_gpadl_id: open_data.ring_gpadl_id,
1136 target_vp: open_data.target_vp,
1137 downstream_ring_buffer_page_offset: open_data.ring_offset,
1138 user_data: open_data.user_data,
1139 };
1140
1141 let connection_id = if request.use_vtl2_connection_id {
1142 if !supports_interrupt_redirection {
1143 rpc.fail(anyhow::anyhow!(
1144 "host does not support specfiying the connection ID"
1145 ));
1146 return;
1147 }
1148 protocol::ConnectionId::new(channel_id.0, 2.try_into().unwrap(), 7).0
1149 } else {
1150 open_data.connection_id
1151 };
1152
1153 let mut flags = OpenChannelFlags::new();
1156 let event_flag = if let Some(event) = &request.incoming_event {
1157 if !supports_interrupt_redirection {
1158 rpc.fail(anyhow::anyhow!(
1159 "host does not support redirecting interrupts"
1160 ));
1161 return;
1162 }
1163
1164 flags.set_redirect_interrupt(true);
1165 match self.inner.synic.allocate_event_flag(event) {
1166 Ok(flag) => flag,
1167 Err(err) => {
1168 rpc.fail(err.context("failed to allocate event flag"));
1169 return;
1170 }
1171 }
1172 } else {
1173 open_data.event_flag
1174 };
1175
1176 if supports_interrupt_redirection {
1177 self.inner.messages.send(&protocol::OpenChannel2 {
1178 open_channel,
1179 connection_id,
1180 event_flag,
1181 flags,
1182 });
1183 } else {
1184 self.inner.messages.send(&open_channel);
1185 }
1186
1187 channel
1188 .connection_id
1189 .store(connection_id, Ordering::Release);
1190 channel.state = ChannelState::Opening {
1191 redirected_event_flag: (request.incoming_event.is_some()).then_some(event_flag),
1192 redirected_event: request.incoming_event,
1193 rpc,
1194 }
1195 }
1196
1197 fn handle_restore_channel(
1198 &mut self,
1199 channel_id: ChannelId,
1200 request: RestoreRequest,
1201 ) -> Result<OpenOutput> {
1202 let mut channel = self.channels.get_mut(channel_id);
1203 if !matches!(channel.state, ChannelState::Restored) {
1204 anyhow::bail!("invalid channel state: {}", channel.state);
1205 }
1206
1207 if request.incoming_event.is_some() != request.redirected_event_flag.is_some() {
1208 anyhow::bail!("incoming event and redirected event flag must both be set or unset");
1209 }
1210
1211 if let Some((flag, event)) = request
1212 .redirected_event_flag
1213 .zip(request.incoming_event.as_ref())
1214 {
1215 self.inner.synic.restore_event_flag(flag, event)?;
1216 }
1217
1218 channel
1219 .connection_id
1220 .store(request.connection_id, Ordering::Release);
1221 channel.state = ChannelState::Opened {
1222 redirected_event_flag: request.redirected_event_flag,
1223 redirected_event: request.incoming_event,
1224 };
1225 Ok(OpenOutput {
1226 redirected_event_flag: request.redirected_event_flag,
1227 })
1228 }
1229
1230 fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: FailableRpc<GpadlRequest, ()>) {
1231 let (request, rpc) = rpc.split();
1232 let mut channel = self.channels.get_mut(channel_id);
1233 if channel
1234 .gpadls
1235 .insert(request.id, GpadlState::Offered(rpc))
1236 .is_some()
1237 {
1238 panic!(
1239 "duplicate gpadl ID {:?} for channel {:?}.",
1240 request.id, channel_id
1241 );
1242 }
1243
1244 tracing::trace!(
1245 channel_id = channel_id.0,
1246 gpadl_id = request.id.0,
1247 count = request.count,
1248 len = request.buf.len(),
1249 "received gpadl request"
1250 );
1251
1252 let (first, remaining) = if request.buf.len() > protocol::GpadlHeader::MAX_DATA_VALUES {
1254 request.buf.split_at(protocol::GpadlHeader::MAX_DATA_VALUES)
1255 } else {
1256 (request.buf.as_slice(), [].as_slice())
1257 };
1258
1259 let message = protocol::GpadlHeader {
1260 channel_id,
1261 gpadl_id: request.id,
1262 len: (request.buf.len() * size_of::<u64>())
1263 .try_into()
1264 .expect("Too many GPA values"),
1265 count: request.count,
1266 };
1267
1268 self.inner
1269 .messages
1270 .send_with_data(&message, first.as_bytes());
1271
1272 let message = protocol::GpadlBody {
1274 rsvd: 0,
1275 gpadl_id: request.id,
1276 };
1277 for chunk in remaining.chunks(protocol::GpadlBody::MAX_DATA_VALUES) {
1278 self.inner
1279 .messages
1280 .send_with_data(&message, chunk.as_bytes());
1281 }
1282 }
1283
1284 fn handle_gpadl_teardown(&mut self, channel_id: ChannelId, rpc: Rpc<GpadlId, ()>) {
1285 let (gpadl_id, rpc) = rpc.split();
1286 let mut channel = self.channels.get_mut(channel_id);
1287 let Some(gpadl_state) = channel.gpadls.get_mut(&gpadl_id) else {
1288 tracing::warn!(
1289 gpadl_id = gpadl_id.0,
1290 channel_id = channel_id.0,
1291 "Gpadl teardown for unknown gpadl or revoked channel"
1292 );
1293 return;
1294 };
1295
1296 match gpadl_state {
1297 GpadlState::Offered(_) => {
1298 tracing::warn!(
1299 gpadl_id = gpadl_id.0,
1300 channel_id = channel_id.0,
1301 "gpadl teardown for offered gpadl"
1302 );
1303 }
1304 GpadlState::Created => {
1305 *gpadl_state = GpadlState::TearingDown { rpcs: vec![rpc] };
1306 assert!(
1310 self.inner
1311 .teardown_gpadls
1312 .insert(gpadl_id, channel_id)
1313 .is_none(),
1314 "Gpadl state validated above"
1315 );
1316
1317 self.inner.messages.send(&protocol::GpadlTeardown {
1318 channel_id,
1319 gpadl_id,
1320 });
1321 }
1322 GpadlState::TearingDown { rpcs } => {
1323 rpcs.push(rpc);
1324 }
1325 }
1326 }
1327
1328 fn handle_close_channel(&mut self, channel_id: ChannelId) {
1329 let mut channel = self.channels.get_mut(channel_id);
1330 self.inner.close_channel(channel_id, &mut channel);
1331 }
1332
1333 fn handle_modify_channel(&mut self, channel_id: ChannelId, rpc: Rpc<ModifyRequest, i32>) {
1334 assert!(self.check_version(Version::Iron));
1338 let mut channel = self.channels.get_mut(channel_id);
1339 if channel.modify_response_send.is_some() {
1340 panic!("duplicate channel modify request {channel_id:?}");
1341 }
1342
1343 let (request, response) = rpc.split();
1344 channel.modify_response_send = Some(response);
1345 let payload = match request {
1346 ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel {
1347 channel_id,
1348 target_vp,
1349 },
1350 };
1351
1352 self.inner.messages.send(&payload);
1353 }
1354
1355 fn handle_channel_request(&mut self, channel_id: ChannelId, request: ChannelRequest) {
1356 match request {
1357 ChannelRequest::Open(rpc) => self.handle_open_channel(channel_id, rpc),
1358 ChannelRequest::Restore(rpc) => {
1359 rpc.handle_failable_sync(|request| self.handle_restore_channel(channel_id, request))
1360 }
1361 ChannelRequest::Gpadl(req) => self.handle_gpadl(channel_id, req),
1362 ChannelRequest::TeardownGpadl(req) => self.handle_gpadl_teardown(channel_id, req),
1363 ChannelRequest::Close(req) => {
1364 req.handle_sync(|()| self.handle_close_channel(channel_id))
1365 }
1366 ChannelRequest::Modify(req) => self.handle_modify_channel(channel_id, req),
1367 }
1368 }
1369
1370 async fn handle_task(&mut self, task: TaskRequest) {
1371 match task {
1372 TaskRequest::Inspect(deferred) => {
1373 deferred.inspect(&*self);
1374 }
1375 TaskRequest::Save(rpc) => rpc.handle_sync(|()| self.handle_save()),
1376 TaskRequest::Restore(rpc) => {
1377 rpc.handle_sync(|saved_state| self.handle_restore(saved_state))
1378 }
1379 TaskRequest::PostRestore(rpc) => rpc.handle_sync(|()| self.handle_post_restore()),
1380 TaskRequest::Start => self.handle_start(),
1381 TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
1382 }
1383 }
1384
1385 fn handle_device_removal(&mut self, channel_id: ChannelId) -> TriedRelease {
1387 let mut channel = self.channels.get_mut(channel_id);
1388 channel.is_client_released = true;
1389 if let ChannelState::Opened { .. } = channel.state {
1391 tracing::warn!(
1392 channel_id = channel_id.0,
1393 "Channel dropped without closing first"
1394 );
1395 self.inner.close_channel(channel_id, &mut channel);
1396 }
1397 channel.try_release(&mut self.inner.messages)
1398 }
1399
1400 fn check_version(&self, version: Version) -> bool {
1402 matches!(self.state, ClientState::Connected { version: v, .. } if v.version >= version)
1403 }
1404
1405 fn handle_start(&mut self) {
1406 assert!(!self.running);
1407 self.msg_source.resume_message_stream();
1408 self.inner.messages.resume();
1409 self.running = true;
1410 }
1411
    /// Stops the task by quiescing message traffic with the host.
    ///
    /// Each iteration of the outer loop runs four steps:
    /// 1. Wait for responses to any requests still outstanding on revoked
    ///    channels.
    /// 2. Pause outgoing messages.
    /// 3. Process incoming messages until `process_next_message` reports a
    ///    zero-length receive.
    /// 4. If a revoked channel still has a pending request, resume and retry.
    ///
    /// When the loop exits, `running` is cleared.
    ///
    /// # Panics
    /// Panics if the task is not running.
    async fn handle_stop(&mut self) {
        assert!(self.running);

        loop {
            // Drain responses for revoked channels first. The message stream is
            // not paused yet, so hitting a zero-length receive here is treated
            // as a bug (assert).
            while let Some((id, request)) = self.channels.revoked_channel_with_pending_request() {
                tracelimit::info_ratelimited!(
                    channel_id = id.0,
                    request,
                    "waiting for responses for channel"
                );
                assert!(self.process_next_message().await);
            }

            if self.can_pause_resume() {
                // The host supports pause/resume: queue a Pause message. It is
                // sent by the flush in `process_next_message`.
                self.inner.messages.pause();
            } else {
                // No protocol-level pause: pause the local message source and
                // move the outgoing queue straight to `Paused`.
                self.msg_source.pause_message_stream();
                self.inner.messages.force_pause();
            }

            // Process everything until the source reports a zero-length receive.
            while self.process_next_message().await {}

            // Stop only once no revoked channel is still waiting on a response;
            // otherwise undo the pause and go around again.
            if self
                .channels
                .revoked_channel_with_pending_request()
                .is_none()
            {
                break;
            }
            if !self.can_pause_resume() {
                self.msg_source.resume_message_stream();
            }
            self.inner.messages.resume();
        }

        tracing::debug!("messages drained");
        self.running = false;
    }
1461
    /// Receives and handles one incoming synic message, flushing queued
    /// outgoing messages in the meantime.
    ///
    /// Returns `false` when the message source yields a zero-length receive.
    /// Otherwise it returns whatever `handle_synic_message` returns for the
    /// received message.
    ///
    /// # Panics
    /// Panics if receiving from the message source fails.
    async fn process_next_message(&mut self) -> bool {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        let recv = self.msg_source.recv(&mut buf);
        // The flush future never completes (it ends in `pending()`), so the
        // race below always resolves with the receive result. Flushing simply
        // makes progress while we wait.
        let flush = async {
            self.inner.messages.flush_messages().await;
            std::future::pending().await
        };
        let size = (recv, flush)
            .race()
            .await
            .expect("Fatal error reading messages from synic");
        if size == 0 {
            return false;
        }
        self.handle_synic_message(&buf[..size])
    }
1480
1481 fn can_pause_resume(&self) -> bool {
1490 if let ClientState::Connected { version, .. } = self.state {
1491 version.feature_flags.pause_resume()
1492 } else {
1493 false
1494 }
1495 }
1496
    /// Main event loop of the client task.
    ///
    /// Task requests (inspect, save/restore, start/stop) are always serviced.
    /// The other sources are serviced only while `running`:
    /// - incoming synic messages;
    /// - flushing of queued outgoing messages;
    /// - client requests and per-channel requests.
    ///
    /// While outgoing messages are queued ("host backed up"), client and
    /// channel requests are not polled, so no new work is generated until the
    /// queue drains.
    ///
    /// The loop exits when the task request stream ends, when the client
    /// request stream ends, or when no select branch remains active.
    ///
    /// # Panics
    /// Panics on a zero-length receive or a receive error from the synic
    /// message source.
    async fn run(&mut self) {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        loop {
            let mut message_recv =
                OptionFuture::from(self.running.then(|| self.msg_source.recv(&mut buf).fuse()));

            // Non-empty queue means the host has not accepted everything we
            // tried to post.
            let host_backed_up = !self.inner.messages.is_empty();
            let flush_messages = OptionFuture::from(
                (self.running && host_backed_up)
                    .then(|| self.inner.messages.flush_messages().fuse()),
            );

            // Apply backpressure: accept new requests only when the outgoing
            // queue is empty.
            let mut client_request_recv = OptionFuture::from(
                (self.running && !host_backed_up).then(|| self.client_request_recv.next()),
            );

            let mut channel_requests = OptionFuture::from(
                (self.running && !host_backed_up)
                    .then(|| self.inner.channel_requests.select_next_some()),
            );

            futures::select! {
                // Completion of the flush just restarts the loop so the
                // backpressure gating above is re-evaluated.
                _r = pin!(flush_messages) => {}
                r = self.task_recv.next() => {
                    if let Some(task) = r {
                        self.handle_task(task).await;
                    } else {
                        break;
                    }
                }
                r = client_request_recv => {
                    if let Some(Some(request)) = r {
                        self.handle_client_request(request);
                    } else {
                        break;
                    }
                }
                r = channel_requests => {
                    match r.unwrap() {
                        (id, Some(request)) => self.handle_channel_request(id, request),
                        // A `None` request marks the end of this channel's
                        // request stream; treat it as the device going away.
                        (id, _) => {
                            self.handle_device_removal(id);
                        }
                    }
                }
                r = message_recv => {
                    match r.unwrap() {
                        Ok(size) => {
                            if size == 0 {
                                panic!("Unexpected end of file reading messages from synic.");
                            }

                            self.handle_synic_message(&buf[..size]);
                        }
                        Err(err) => {
                            panic!("Error reading messages from synic: {err:?}");
                        }
                    }
                }
                complete => break,
            }
        }
    }
1569}
1570
1571impl ClientTaskInner {
1572 fn close_channel(&mut self, channel_id: ChannelId, channel: &mut Channel) {
1573 if let ChannelState::Opened {
1574 redirected_event_flag,
1575 ..
1576 } = channel.state
1577 {
1578 if let Some(flag) = redirected_event_flag {
1579 self.synic.free_event_flag(flag);
1580 }
1581 tracing::info!(channel_id = channel_id.0, "closing channel on host");
1582 self.messages.send(&protocol::CloseChannel { channel_id });
1583 channel.state = ChannelState::Offered;
1584 channel.connection_id.store(0, Ordering::Release);
1585 } else {
1586 tracing::warn!(
1587 id = %channel_id.0,
1588 channel_state = %channel.state,
1589 "invalid channel state for close channel"
1590 );
1591 }
1592 }
1593}
1594
/// Client-side lifecycle state of a single GPADL on a channel.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum GpadlState {
    /// The GPADL description has been sent to the host. Holds the caller's
    /// RPC; presumably completed when the host responds (handled outside this
    /// view).
    Offered(#[inspect(skip)] FailableRpc<(), ()>),
    /// The GPADL exists on the host and may be torn down. Presumably entered
    /// once the host acknowledges creation; verify against the response
    /// handler.
    Created,
    /// A `GpadlTeardown` has been sent to the host. Every caller waiting for
    /// the teardown to finish is collected here.
    TearingDown {
        #[inspect(skip)]
        rpcs: Vec<Rpc<(), ()>>,
    },
}
1608
/// Outgoing message path to the host, with a queue for messages that could not
/// be posted immediately.
#[derive(Inspect)]
struct OutgoingMessages {
    /// Posts a single message to the host.
    #[inspect(skip)]
    poster: Box<dyn PollPostMessage>,
    /// Messages waiting to be posted, in send order. Only the length is
    /// exposed through inspect.
    #[inspect(with = "|x| x.len()")]
    queued: VecDeque<OutgoingMessage>,
    /// Whether messages may currently be sent.
    state: OutgoingMessageState,
}
1617
/// Send state of [`OutgoingMessages`].
#[derive(Inspect, PartialEq, Eq, Debug)]
enum OutgoingMessageState {
    /// Messages are posted directly or flushed from the queue.
    Running,
    /// A `Pause` message must be sent by the next flush. Other messages stay
    /// queued.
    SendingPauseMessage,
    /// Nothing is sent; new messages are queued until resumed.
    Paused,
}
1624
impl OutgoingMessages {
    /// Sends `msg` with no trailing payload. See [`Self::send_with_data`].
    fn send<T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout>(
        &mut self,
        msg: &T,
    ) {
        self.send_with_data(msg, &[])
    }

    /// Sends `msg` followed by `data` to the host.
    ///
    /// If the queue is empty and the state is `Running`, a single
    /// non-blocking post is attempted. It uses a no-op waker, so it never
    /// registers for wakeup. If that post does not complete immediately, or
    /// sending is paused, or earlier messages are still queued, the message is
    /// appended to the queue. This preserves ordering; queued messages are
    /// sent by `flush_messages`.
    fn send_with_data<
        T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
    >(
        &mut self,
        msg: &T,
        data: &[u8],
    ) {
        tracing::trace!(typ = ?T::MESSAGE_TYPE, "Sending message to host");
        let msg = OutgoingMessage::with_data(msg, data);
        if self.queued.is_empty() && self.state == OutgoingMessageState::Running {
            // The third argument (`1`) is the message type passed to the
            // poster; its meaning is defined by the `PollPostMessage`
            // implementation.
            let r = self.poster.poll_post_message(
                &mut Context::from_waker(std::task::Waker::noop()),
                protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                1,
                msg.data(),
            );
            if let Poll::Ready(()) = r {
                return;
            }
        }
        tracing::trace!("queueing message");
        self.queued.push_back(msg);
    }

    /// Makes as much sending progress as the current state allows.
    ///
    /// - `Running`: posts queued messages in order until the queue is empty.
    ///   Each message is removed only after its post completes.
    /// - `SendingPauseMessage`: posts a `Pause` message, then moves to
    ///   `Paused`. Queued messages are left untouched.
    /// - `Paused`: returns immediately.
    async fn flush_messages(&mut self) {
        // Captures only `self.poster`, so `self.queued` and `self.state`
        // remain usable below.
        let mut send = async |msg: &OutgoingMessage| {
            poll_fn(|cx| {
                self.poster.poll_post_message(
                    cx,
                    protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                    1,
                    msg.data(),
                )
            })
            .await
        };
        match self.state {
            OutgoingMessageState::Running => {
                while let Some(msg) = self.queued.front() {
                    send(msg).await;
                    tracing::trace!("sent queued message");
                    self.queued.pop_front();
                }
            }
            OutgoingMessageState::SendingPauseMessage => {
                send(&OutgoingMessage::new(&protocol::Pause)).await;
                tracing::trace!("sent pause message");
                self.state = OutgoingMessageState::Paused;
            }
            OutgoingMessageState::Paused => {}
        }
    }

    /// Starts a protocol-level pause: the next flush sends `Pause` to the host.
    ///
    /// A `Resume` message is pushed to the front of the queue, so after
    /// [`Self::resume`] it is the first message flushed. It goes ahead of any
    /// messages that were already queued.
    ///
    /// # Panics
    /// Panics if the state is not `Running`.
    fn pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::SendingPauseMessage;
        self.queued
            .push_front(OutgoingMessage::new(&protocol::Resume));
    }

    /// Pauses sending locally without telling the host (no `Pause` or
    /// `Resume` message is queued).
    ///
    /// # Panics
    /// Panics if the state is not `Running`.
    fn force_pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::Paused;
    }

    /// Re-enables sending; queued messages go out on the next flush.
    ///
    /// # Panics
    /// Panics if the state is not `Paused`.
    fn resume(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Paused);
        self.state = OutgoingMessageState::Running;
    }

    /// Returns true if no messages are waiting in the queue.
    fn is_empty(&self) -> bool {
        self.queued.is_empty()
    }
}
1713
/// Task state that channel operations need alongside a borrowed channel. It is
/// kept separate from the channel list so both can be borrowed mutably at once.
#[derive(Inspect)]
struct ClientTaskInner {
    /// Outgoing message path to the host.
    messages: OutgoingMessages,
    /// GPADLs with a teardown in flight, mapped to their owning channel.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")]
    teardown_gpadls: HashMap<GpadlId, ChannelId>,
    /// Merged per-channel request streams, each item tagged with its channel
    /// ID.
    #[inspect(skip)]
    channel_requests: SelectAll<TaggedStream<ChannelId, mesh::Receiver<ChannelRequest>>>,
    /// Event-flag bookkeeping and host signaling.
    synic: SynicState,
}
1723
/// Synic event resources owned by the client.
#[derive(Inspect)]
struct SynicState {
    /// Maps/unmaps events to event flags and signals the host.
    #[inspect(skip)]
    event_client: Arc<dyn SynicEventClient>,
    /// `event_flag_state[i]` is true when event flag `i + 1` is allocated.
    #[inspect(iter_by_index)]
    event_flag_state: Vec<bool>,
}
1731
/// All channels known to the client, keyed by channel ID.
#[derive(Inspect, Default)]
#[inspect(transparent)]
struct ChannelList(
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] HashMap<ChannelId, Channel>,
);
1737
/// Mutable handle to an existing entry in [`ChannelList`].
///
/// It derefs to [`Channel`] and can remove the entry via `try_release`.
struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>);
1741
/// Marker returned by `ChannelRef::try_release`.
///
/// Its private field means it can only be constructed in this module. This is
/// presumably so that functions returning it show they attempted a release.
struct TriedRelease(());
1746
1747impl ChannelRef<'_> {
1748 fn try_release(self, messages: &mut OutgoingMessages) -> TriedRelease {
1752 if self.is_client_released
1753 && matches!(self.state, ChannelState::Revoked)
1754 && self.pending_request().is_none()
1755 {
1756 let channel_id = *self.0.key();
1757 tracelimit::info_ratelimited!(channel_id = channel_id.0, "releasing channel");
1758 messages.send(&protocol::RelIdReleased { channel_id });
1759 self.0.remove();
1760 }
1761 TriedRelease(())
1762 }
1763}
1764
1765impl Deref for ChannelRef<'_> {
1766 type Target = Channel;
1767
1768 fn deref(&self) -> &Self::Target {
1769 self.0.get()
1770 }
1771}
1772
1773impl DerefMut for ChannelRef<'_> {
1774 fn deref_mut(&mut self) -> &mut Self::Target {
1775 self.0.get_mut()
1776 }
1777}
1778
1779impl ChannelList {
1780 fn revoked_channel_with_pending_request(&self) -> Option<(ChannelId, &'static str)> {
1781 self.0.iter().find_map(|(&id, channel)| {
1782 if !matches!(channel.state, ChannelState::Revoked) {
1783 return None;
1784 }
1785 Some((id, channel.pending_request()?))
1786 })
1787 }
1788
1789 #[track_caller]
1790 fn get_mut(&mut self, channel_id: ChannelId) -> ChannelRef<'_> {
1791 match self.0.entry(channel_id) {
1792 hash_map::Entry::Occupied(entry) => ChannelRef(entry),
1793 hash_map::Entry::Vacant(_) => {
1794 panic!("channel {:?} not found", channel_id);
1795 }
1796 }
1797 }
1798}
1799
1800impl SynicState {
1801 fn guest_to_host_interrupt(&self, connection_id: Arc<AtomicU32>) -> Interrupt {
1802 Interrupt::from_fn({
1803 let event_client = self.event_client.clone();
1804 move || {
1805 let connection_id = connection_id.load(Ordering::Acquire);
1806 if connection_id == 0 {
1807 tracing::debug!("interrupt signal after close");
1808 return;
1809 }
1810
1811 if let Err(err) = event_client.signal_event(connection_id, 0) {
1812 tracelimit::warn_ratelimited!(
1813 error = &err as &dyn std::error::Error,
1814 "failed to signal event"
1815 );
1816 }
1817 }
1818 })
1819 }
1820
1821 const MAX_EVENT_FLAGS: u16 = 2047;
1822
1823 fn allocate_event_flag(&mut self, event: &Event) -> Result<u16> {
1824 let i = self
1825 .event_flag_state
1826 .iter()
1827 .position(|&used| !used)
1828 .ok_or(())
1829 .or_else(|()| {
1830 if self.event_flag_state.len() >= Self::MAX_EVENT_FLAGS as usize {
1831 anyhow::bail!("out of event flags");
1832 }
1833 self.event_flag_state.push(false);
1834 Ok(self.event_flag_state.len() - 1)
1835 })?;
1836
1837 let event_flag = (i + 1) as u16;
1838 self.event_client
1839 .map_event(event_flag, event)
1840 .context("failed to map event")?;
1841 self.event_flag_state[i] = true;
1842 Ok(event_flag)
1843 }
1844
1845 fn restore_event_flag(&mut self, flag: u16, event: &Event) -> Result<()> {
1846 let i = (flag as usize)
1847 .checked_sub(1)
1848 .context("invalid event flag")?;
1849 if i >= Self::MAX_EVENT_FLAGS as usize {
1850 anyhow::bail!("invalid event flag");
1851 }
1852 if self.event_flag_state.len() <= i {
1853 self.event_flag_state.resize(i + 1, false);
1854 }
1855 if self.event_flag_state[i] {
1856 anyhow::bail!("event flag already in use");
1857 }
1858 self.event_client
1859 .map_event(flag, event)
1860 .context("failed to map event")?;
1861 self.event_flag_state[i] = true;
1862 Ok(())
1863 }
1864
1865 fn free_event_flag(&mut self, flag: u16) {
1866 let i = flag as usize - 1;
1867 assert!(i < self.event_flag_state.len());
1868 self.event_flag_state[i] = false;
1869 }
1870}
1871
1872#[cfg(test)]
1873mod tests {
1874 use super::*;
1875 use futures_concurrency::future::Join;
1876 use guid::Guid;
1877 use pal_async::DefaultDriver;
1878 use pal_async::async_test;
1879 use pal_async::timer::PolledTimer;
1880 use protocol::TargetInfo;
1881 use std::fmt::Debug;
1882 use std::task::ready;
1883 use std::time::Duration;
1884 use test_with_tracing::test;
1885 use vmbus_core::protocol::MessageHeader;
1886 use vmbus_core::protocol::MessageType;
1887 use vmbus_core::protocol::OfferFlags;
1888 use vmbus_core::protocol::UserDefinedData;
1889 use vmbus_core::protocol::VmbusMessage;
1890 use zerocopy::FromBytes;
1891 use zerocopy::FromZeros;
1892 use zerocopy::Immutable;
1893 use zerocopy::IntoBytes;
1894 use zerocopy::KnownLayout;
1895
1896 const VMBUS_TEST_CLIENT_ID: Guid = guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6");
1897
1898 fn in_msg<T: IntoBytes + Immutable + KnownLayout>(message_type: MessageType, t: T) -> Vec<u8> {
1899 let mut data = Vec::new();
1900 data.extend_from_slice(&message_type.0.to_ne_bytes());
1901 data.extend_from_slice(&0u32.to_ne_bytes());
1902 data.extend_from_slice(t.as_bytes());
1903 data
1904 }
1905
1906 #[track_caller]
1907 fn check_message<T>(msg: OutgoingMessage, chk: T)
1908 where
1909 T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1910 {
1911 check_message_with_data(msg, chk, &[]);
1912 }
1913
1914 #[track_caller]
1915 fn check_message_with_data<T>(msg: OutgoingMessage, chk: T, data: &[u8])
1916 where
1917 T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1918 {
1919 let chk_data = OutgoingMessage::with_data(&chk, data);
1920 if msg.data() != chk_data.data() {
1921 let (header, rest) = MessageHeader::read_from_prefix(msg.data()).unwrap();
1922 assert_eq!(header.message_type(), <T as VmbusMessage>::MESSAGE_TYPE);
1923 let (msg, rest) = T::read_from_prefix(rest).expect("incorrect message size");
1924 if msg.as_bytes() != chk.as_bytes() {
1925 panic!("mismatched messages, expected {:#?}, got {:#?}", chk, msg);
1926 }
1927 if rest != data {
1928 panic!("mismatched data, expected {:#?}, got {:#?}", data, rest);
1929 }
1930 }
1931 }
1932
1933 struct TestServer {
1934 messages: mesh::Receiver<OutgoingMessage>,
1935 send: mesh::Sender<Vec<u8>>,
1936 }
1937
1938 impl TestServer {
1939 async fn next(&mut self) -> Option<OutgoingMessage> {
1940 self.messages.next().await
1941 }
1942
1943 fn send(&self, msg: Vec<u8>) {
1944 self.send.send(msg);
1945 }
1946
1947 async fn connect(&mut self, client: &mut VmbusClient) -> ConnectResult {
1948 self.connect_with_channels(client, |_| {}).await
1949 }
1950
1951 async fn connect_with_channels(
1952 &mut self,
1953 client: &mut VmbusClient,
1954 send_offers: impl FnOnce(&mut Self),
1955 ) -> ConnectResult {
1956 let client_connect = client.connect(0, None, Guid::ZERO);
1957
1958 let server_connect = async {
1959 let _ = self.next().await.unwrap();
1960
1961 self.send(in_msg(
1962 MessageType::VERSION_RESPONSE,
1963 protocol::VersionResponse2 {
1964 version_response: protocol::VersionResponse {
1965 version_supported: 1,
1966 connection_state: ConnectionState::SUCCESSFUL,
1967 padding: 0,
1968 selected_version_or_connection_id: 0,
1969 },
1970 supported_features: SUPPORTED_FEATURE_FLAGS.into(),
1971 },
1972 ));
1973
1974 check_message(self.next().await.unwrap(), protocol::RequestOffers {});
1975
1976 send_offers(self);
1977 self.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
1978 };
1979
1980 let (connection, ()) = (client_connect, server_connect).join().await;
1981
1982 let connection = connection.unwrap();
1983 assert_eq!(connection.version.version, Version::Copper);
1984 assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
1985 connection
1986 }
1987
1988 async fn get_channel(&mut self, client: &mut VmbusClient) -> OfferInfo {
1989 let [channel] = self
1990 .get_channels(client, 1)
1991 .await
1992 .offers
1993 .try_into()
1994 .unwrap();
1995 channel
1996 }
1997
1998 async fn get_channels(&mut self, client: &mut VmbusClient, count: usize) -> ConnectResult {
1999 self.connect_with_channels(client, |this| {
2000 for i in 0..count {
2001 let offer = protocol::OfferChannel {
2002 interface_id: Guid::new_random(),
2003 instance_id: Guid::new_random(),
2004 rsvd: [0; 4],
2005 flags: OfferFlags::new(),
2006 mmio_megabytes: 0,
2007 user_defined: UserDefinedData::new_zeroed(),
2008 subchannel_index: 0,
2009 mmio_megabytes_optional: 0,
2010 channel_id: ChannelId(i as u32),
2011 monitor_id: 0,
2012 monitor_allocated: 0,
2013 is_dedicated: 0,
2014 connection_id: 0,
2015 };
2016
2017 this.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2018 }
2019 })
2020 .await
2021 }
2022
2023 async fn stop_client(&mut self, client: &mut VmbusClient) {
2024 let client_stop = client.stop();
2025 let server_stop = async {
2026 check_message(self.next().await.unwrap(), protocol::Pause);
2027 self.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
2028 };
2029 (client_stop, server_stop).join().await;
2030 }
2031
2032 async fn start_client(&mut self, client: &mut VmbusClient) {
2033 client.start();
2034 check_message(self.next().await.unwrap(), protocol::Resume);
2035 }
2036 }
2037
2038 struct TestServerClient {
2039 sender: mesh::Sender<OutgoingMessage>,
2040 timer: PolledTimer,
2041 deadline: Option<pal_async::timer::Instant>,
2042 }
2043
2044 impl PollPostMessage for TestServerClient {
2045 fn poll_post_message(
2046 &mut self,
2047 cx: &mut Context<'_>,
2048 _connection_id: u32,
2049 _typ: u32,
2050 msg: &[u8],
2051 ) -> Poll<()> {
2052 loop {
2053 if let Some(deadline) = self.deadline {
2054 ready!(self.timer.poll_until(cx, deadline));
2055 self.deadline = None;
2056 }
2057 let mut b = [0];
2062 getrandom::fill(&mut b).unwrap();
2063 if b[0] % 4 == 0 {
2064 self.deadline =
2065 Some(pal_async::timer::Instant::now() + Duration::from_millis(10));
2066 } else {
2067 let msg = OutgoingMessage::from_message(msg).unwrap();
2068 tracing::info!(
2069 msg = ?MessageHeader::read_from_prefix(msg.data()),
2070 "sending message"
2071 );
2072 self.sender.send(msg);
2073 break Poll::Ready(());
2074 }
2075 }
2076 }
2077 }
2078
2079 struct NoopSynicEvents;
2080
2081 impl SynicEventClient for NoopSynicEvents {
2082 fn map_event(&self, _event_flag: u16, _event: &Event) -> std::io::Result<()> {
2083 Ok(())
2084 }
2085
2086 fn unmap_event(&self, _event_flag: u16) {}
2087
2088 fn signal_event(&self, _connection_id: u32, _event_flag: u16) -> std::io::Result<()> {
2089 Err(std::io::ErrorKind::Unsupported.into())
2090 }
2091 }
2092
2093 struct TestMessageSource {
2094 msg_recv: mesh::Receiver<Vec<u8>>,
2095 paused: bool,
2096 }
2097
2098 impl AsyncRecv for TestMessageSource {
2099 fn poll_recv(
2100 &mut self,
2101 cx: &mut Context<'_>,
2102 mut bufs: &mut [std::io::IoSliceMut<'_>],
2103 ) -> Poll<std::io::Result<usize>> {
2104 let value = match self.msg_recv.poll_recv(cx) {
2105 Poll::Ready(v) => v.unwrap(),
2106 Poll::Pending => {
2107 if self.paused {
2108 return Poll::Ready(Ok(0));
2109 } else {
2110 return Poll::Pending;
2111 }
2112 }
2113 };
2114 let mut remaining = value.as_slice();
2115 let mut total_size = 0;
2116 while !remaining.is_empty() && !bufs.is_empty() {
2117 let size = bufs[0].len().min(remaining.len());
2118 bufs[0][..size].copy_from_slice(&remaining[..size]);
2119 remaining = &remaining[size..];
2120 bufs = &mut bufs[1..];
2121 total_size += size;
2122 }
2123
2124 Ok(total_size).into()
2125 }
2126 }
2127
2128 impl VmbusMessageSource for TestMessageSource {
2129 fn pause_message_stream(&mut self) {
2130 self.paused = true;
2131 }
2132
2133 fn resume_message_stream(&mut self) {
2134 self.paused = false;
2135 }
2136 }
2137
2138 fn test_init(driver: &DefaultDriver) -> (TestServer, VmbusClient) {
2139 let (msg_send, msg_recv) = mesh::channel();
2140 let (synic_send, synic_recv) = mesh::channel();
2141 let server = TestServer {
2142 messages: synic_recv,
2143 send: msg_send,
2144 };
2145 let mut client = VmbusClientBuilder::new(
2146 NoopSynicEvents,
2147 TestMessageSource {
2148 msg_recv,
2149 paused: false,
2150 },
2151 TestServerClient {
2152 sender: synic_send,
2153 deadline: None,
2154 timer: PolledTimer::new(driver),
2155 },
2156 )
2157 .build(driver);
2158 client.start();
2159 (server, client)
2160 }
2161
2162 #[async_test]
2163 async fn test_initiate_contact_success(driver: DefaultDriver) {
2164 let (mut server, client) = test_init(&driver);
2165 let _recv = client
2166 .access
2167 .client_request_send
2168 .call(ClientRequest::Connect, ConnectRequest::default());
2169 check_message(
2170 server.next().await.unwrap(),
2171 protocol::InitiateContact2 {
2172 initiate_contact: protocol::InitiateContact {
2173 version_requested: Version::Copper as u32,
2174 target_message_vp: 0,
2175 interrupt_page_or_target_info: TargetInfo::new()
2176 .with_sint(2)
2177 .with_vtl(0)
2178 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2179 .into(),
2180 parent_to_child_monitor_page_gpa: 0,
2181 child_to_parent_monitor_page_gpa: 0,
2182 },
2183 ..FromZeros::new_zeroed()
2184 },
2185 );
2186 }
2187
2188 #[async_test]
2189 async fn test_connect_success(driver: DefaultDriver) {
2190 let (mut server, mut client) = test_init(&driver);
2191 let client_connect = client.connect(0, None, Guid::ZERO);
2192
2193 let server_connect = async {
2194 check_message(
2195 server.next().await.unwrap(),
2196 protocol::InitiateContact2 {
2197 initiate_contact: protocol::InitiateContact {
2198 version_requested: Version::Copper as u32,
2199 target_message_vp: 0,
2200 interrupt_page_or_target_info: TargetInfo::new()
2201 .with_sint(2)
2202 .with_vtl(0)
2203 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2204 .into(),
2205 parent_to_child_monitor_page_gpa: 0,
2206 child_to_parent_monitor_page_gpa: 0,
2207 },
2208 ..FromZeros::new_zeroed()
2209 },
2210 );
2211
2212 server.send(in_msg(
2213 MessageType::VERSION_RESPONSE,
2214 protocol::VersionResponse2 {
2215 version_response: protocol::VersionResponse {
2216 version_supported: 1,
2217 connection_state: ConnectionState::SUCCESSFUL,
2218 padding: 0,
2219 selected_version_or_connection_id: 0,
2220 },
2221 supported_features: SUPPORTED_FEATURE_FLAGS.into_bits(),
2222 },
2223 ));
2224
2225 check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2226 server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2227 };
2228
2229 let (connection, ()) = (client_connect, server_connect).join().await;
2230 let connection = connection.unwrap();
2231
2232 assert_eq!(connection.version.version, Version::Copper);
2233 assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
2234 }
2235
2236 #[async_test]
2237 async fn test_feature_flags(driver: DefaultDriver) {
2238 let (mut server, mut client) = test_init(&driver);
2239 let client_connect = client.connect(0, None, Guid::ZERO);
2240
2241 let server_connect = async {
2242 check_message(
2243 server.next().await.unwrap(),
2244 protocol::InitiateContact2 {
2245 initiate_contact: protocol::InitiateContact {
2246 version_requested: Version::Copper as u32,
2247 target_message_vp: 0,
2248 interrupt_page_or_target_info: TargetInfo::new()
2249 .with_sint(2)
2250 .with_vtl(0)
2251 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2252 .into(),
2253 parent_to_child_monitor_page_gpa: 0,
2254 child_to_parent_monitor_page_gpa: 0,
2255 },
2256 ..FromZeros::new_zeroed()
2257 },
2258 );
2259
2260 server.send(in_msg(
2263 MessageType::VERSION_RESPONSE,
2264 protocol::VersionResponse2 {
2265 version_response: protocol::VersionResponse {
2266 version_supported: 1,
2267 connection_state: ConnectionState::SUCCESSFUL,
2268 padding: 0,
2269 selected_version_or_connection_id: 0,
2270 },
2271 supported_features: 2,
2272 },
2273 ));
2274
2275 check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2276 server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2277 };
2278
2279 let (connection, ()) = (client_connect, server_connect).join().await;
2280 let connection = connection.unwrap();
2281
2282 assert_eq!(connection.version.version, Version::Copper);
2283 assert_eq!(
2284 connection.version.feature_flags,
2285 FeatureFlags::new().with_channel_interrupt_redirection(true)
2286 );
2287 }
2288
2289 #[async_test]
2290 async fn test_client_id(driver: DefaultDriver) {
2291 let (mut server, client) = test_init(&driver);
2292 let initiate_contact = ConnectRequest {
2293 client_id: VMBUS_TEST_CLIENT_ID,
2294 ..Default::default()
2295 };
2296 let _recv = client
2297 .access
2298 .client_request_send
2299 .call(ClientRequest::Connect, initiate_contact);
2300
2301 check_message(
2302 server.next().await.unwrap(),
2303 protocol::InitiateContact2 {
2304 initiate_contact: protocol::InitiateContact {
2305 version_requested: Version::Copper as u32,
2306 target_message_vp: 0,
2307 interrupt_page_or_target_info: TargetInfo::new()
2308 .with_sint(2)
2309 .with_vtl(0)
2310 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2311 .into(),
2312 parent_to_child_monitor_page_gpa: 0,
2313 child_to_parent_monitor_page_gpa: 0,
2314 },
2315 client_id: VMBUS_TEST_CLIENT_ID,
2316 },
2317 );
2318 }
2319
2320 #[async_test]
2321 async fn test_version_negotiation(driver: DefaultDriver) {
2322 let (mut server, mut client) = test_init(&driver);
2323 let client_connect = client.connect(0, None, Guid::ZERO);
2324
2325 let server_connect = async {
2326 check_message(
2327 server.next().await.unwrap(),
2328 protocol::InitiateContact2 {
2329 initiate_contact: protocol::InitiateContact {
2330 version_requested: Version::Copper as u32,
2331 target_message_vp: 0,
2332 interrupt_page_or_target_info: TargetInfo::new()
2333 .with_sint(2)
2334 .with_vtl(0)
2335 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2336 .into(),
2337 parent_to_child_monitor_page_gpa: 0,
2338 child_to_parent_monitor_page_gpa: 0,
2339 },
2340 ..FromZeros::new_zeroed()
2341 },
2342 );
2343
2344 server.send(in_msg(
2345 MessageType::VERSION_RESPONSE,
2346 protocol::VersionResponse {
2347 version_supported: 0,
2348 connection_state: ConnectionState::SUCCESSFUL,
2349 padding: 0,
2350 selected_version_or_connection_id: 0,
2351 },
2352 ));
2353
2354 check_message(
2355 server.next().await.unwrap(),
2356 protocol::InitiateContact {
2357 version_requested: Version::Iron as u32,
2358 target_message_vp: 0,
2359 interrupt_page_or_target_info: TargetInfo::new()
2360 .with_sint(2)
2361 .with_vtl(0)
2362 .with_feature_flags(FeatureFlags::new().into())
2363 .into(),
2364 parent_to_child_monitor_page_gpa: 0,
2365 child_to_parent_monitor_page_gpa: 0,
2366 },
2367 );
2368
2369 server.send(in_msg(
2370 MessageType::VERSION_RESPONSE,
2371 protocol::VersionResponse {
2372 version_supported: 1,
2373 connection_state: ConnectionState::SUCCESSFUL,
2374 padding: 0,
2375 selected_version_or_connection_id: 0,
2376 },
2377 ));
2378
2379 check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2380 server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2381 };
2382
2383 let (connection, ()) = (client_connect, server_connect).join().await;
2384 let connection = connection.unwrap();
2385
2386 assert_eq!(connection.version.version, Version::Iron);
2387 assert_eq!(connection.version.feature_flags, FeatureFlags::new());
2388 }
2389
2390 #[async_test]
2391 async fn test_open_channel_success(driver: DefaultDriver) {
2392 let (mut server, mut client) = test_init(&driver);
2393 let channel = server.get_channel(&mut client).await;
2394
2395 let recv = channel.request_send.call(
2396 ChannelRequest::Open,
2397 OpenRequest {
2398 open_data: OpenData {
2399 target_vp: 0,
2400 ring_offset: 0,
2401 ring_gpadl_id: GpadlId(0),
2402 event_flag: 0,
2403 connection_id: 0,
2404 user_data: UserDefinedData::new_zeroed(),
2405 },
2406 incoming_event: None,
2407 use_vtl2_connection_id: false,
2408 },
2409 );
2410
2411 check_message(
2412 server.next().await.unwrap(),
2413 protocol::OpenChannel2 {
2414 open_channel: protocol::OpenChannel {
2415 channel_id: ChannelId(0),
2416 open_id: 0,
2417 ring_buffer_gpadl_id: GpadlId(0),
2418 target_vp: 0,
2419 downstream_ring_buffer_page_offset: 0,
2420 user_data: UserDefinedData::new_zeroed(),
2421 },
2422 connection_id: 0,
2423 event_flag: 0,
2424 flags: Default::default(),
2425 },
2426 );
2427
2428 server.send(in_msg(
2429 MessageType::OPEN_CHANNEL_RESULT,
2430 protocol::OpenResult {
2431 channel_id: ChannelId(0),
2432 open_id: 0,
2433 status: protocol::STATUS_SUCCESS as u32,
2434 },
2435 ));
2436
2437 recv.await.unwrap().unwrap();
2438 }
2439
2440 #[async_test]
2441 async fn test_open_channel_fail(driver: DefaultDriver) {
2442 let (mut server, mut client) = test_init(&driver);
2443 let channel = server.get_channel(&mut client).await;
2444
2445 let recv = channel.request_send.call(
2446 ChannelRequest::Open,
2447 OpenRequest {
2448 open_data: OpenData {
2449 target_vp: 0,
2450 ring_offset: 0,
2451 ring_gpadl_id: GpadlId(0),
2452 event_flag: 0,
2453 connection_id: 0,
2454 user_data: UserDefinedData::new_zeroed(),
2455 },
2456 incoming_event: None,
2457 use_vtl2_connection_id: false,
2458 },
2459 );
2460
2461 check_message(
2462 server.next().await.unwrap(),
2463 protocol::OpenChannel2 {
2464 open_channel: protocol::OpenChannel {
2465 channel_id: ChannelId(0),
2466 open_id: 0,
2467 ring_buffer_gpadl_id: GpadlId(0),
2468 target_vp: 0,
2469 downstream_ring_buffer_page_offset: 0,
2470 user_data: UserDefinedData::new_zeroed(),
2471 },
2472 connection_id: 0,
2473 event_flag: 0,
2474 flags: Default::default(),
2475 },
2476 );
2477
2478 server.send(in_msg(
2479 MessageType::OPEN_CHANNEL_RESULT,
2480 protocol::OpenResult {
2481 channel_id: ChannelId(0),
2482 open_id: 0,
2483 status: protocol::STATUS_UNSUCCESSFUL as u32,
2484 },
2485 ));
2486
2487 recv.await.unwrap().unwrap_err();
2488 }
2489
2490 #[async_test]
2491 async fn test_modify_channel(driver: DefaultDriver) {
2492 let (mut server, mut client) = test_init(&driver);
2493 let channel = server.get_channel(&mut client).await;
2494
2495 let recv = channel.request_send.call(
2498 ChannelRequest::Modify,
2499 ModifyRequest::TargetVp { target_vp: 1 },
2500 );
2501
2502 check_message(
2503 server.next().await.unwrap(),
2504 protocol::ModifyChannel {
2505 channel_id: ChannelId(0),
2506 target_vp: 1,
2507 },
2508 );
2509
2510 server.send(in_msg(
2511 MessageType::MODIFY_CHANNEL_RESPONSE,
2512 protocol::ModifyChannelResponse {
2513 channel_id: ChannelId(0),
2514 status: protocol::STATUS_SUCCESS,
2515 },
2516 ));
2517
2518 let status = recv.await.unwrap();
2519 assert_eq!(status, protocol::STATUS_SUCCESS);
2520 }
2521
2522 #[async_test]
2523 async fn test_save_restore_connected(driver: DefaultDriver) {
2524 let (mut server, mut client) = test_init(&driver);
2525 server.connect(&mut client).await;
2526 server.stop_client(&mut client).await;
2527 let s0 = client.save().await;
2528 let builder = client.sever().await;
2529 let mut client = builder.build(&driver);
2530 client.restore(s0.clone()).await.unwrap();
2531
2532 let s1 = client.save().await;
2533
2534 assert_eq!(s0, s1);
2535 }
2536
2537 #[async_test]
2538 async fn test_save_restore_connected_with_channel(driver: DefaultDriver) {
2539 let (mut server, mut client) = test_init(&driver);
2540 let c0 = server.get_channel(&mut client).await;
2541 server.stop_client(&mut client).await;
2542 let s0 = client.save().await;
2543 let builder = client.sever().await;
2544 let mut client = builder.build(&driver);
2545 let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2546 let s1 = client.save().await;
2547 assert_eq!(s0, s1);
2548 assert_eq!(connection.offers[0].offer, c0.offer);
2549 }
2550
2551 #[async_test]
2552 async fn test_save_restore_connected_with_revoked_channel(driver: DefaultDriver) {
2553 let (mut server, mut client) = test_init(&driver);
2554 let c0 = server.get_channel(&mut client).await;
2555 server.send(in_msg(
2556 MessageType::RESCIND_CHANNEL_OFFER,
2557 protocol::RescindChannelOffer {
2558 channel_id: ChannelId(0),
2559 },
2560 ));
2561 c0.revoke_recv.await.unwrap();
2562 let rpc = c0.request_send.call(
2563 ChannelRequest::Modify,
2564 ModifyRequest::TargetVp { target_vp: 1 },
2565 );
2566
2567 check_message(
2568 server.next().await.unwrap(),
2569 protocol::ModifyChannel {
2570 channel_id: ChannelId(0),
2571 target_vp: 1,
2572 },
2573 );
2574
2575 let client_stop = client.stop();
2576 let server_stop = async {
2577 server.send(in_msg(
2578 MessageType::MODIFY_CHANNEL_RESPONSE,
2579 protocol::ModifyChannelResponse {
2580 channel_id: ChannelId(0),
2581 status: protocol::STATUS_SUCCESS,
2582 },
2583 ));
2584 check_message(server.next().await.unwrap(), protocol::Pause);
2585 server.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
2586 };
2587 (client_stop, server_stop).join().await;
2588
2589 rpc.await.unwrap();
2590
2591 let s0 = client.save().await;
2592 let builder = client.sever().await;
2593 let mut client = builder.build(&driver);
2594 let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2595 let s1 = client.save().await;
2596 assert_eq!(s0, s1);
2597 assert!(connection.offers.is_empty());
2598 server.start_client(&mut client).await;
2599 check_message(
2600 server.next().await.unwrap(),
2601 protocol::RelIdReleased {
2602 channel_id: ChannelId(0),
2603 },
2604 );
2605 }
2606
2607 #[async_test]
2608 async fn test_connect_fails_on_incorrect_state(driver: DefaultDriver) {
2609 let (mut server, mut client) = test_init(&driver);
2610 server.connect(&mut client).await;
2611 let err = client.connect(0, None, Guid::ZERO).await.unwrap_err();
2612 assert!(matches!(err, ConnectError::InvalidState), "{:?}", err);
2613 }
2614
2615 #[async_test]
2616 async fn test_hot_add_remove(driver: DefaultDriver) {
2617 let (mut server, mut client) = test_init(&driver);
2618
2619 let mut connection = server.connect(&mut client).await;
2620 let offer = protocol::OfferChannel {
2621 interface_id: Guid::new_random(),
2622 instance_id: Guid::new_random(),
2623 rsvd: [0; 4],
2624 flags: OfferFlags::new(),
2625 mmio_megabytes: 0,
2626 user_defined: UserDefinedData::new_zeroed(),
2627 subchannel_index: 0,
2628 mmio_megabytes_optional: 0,
2629 channel_id: ChannelId(5),
2630 monitor_id: 0,
2631 monitor_allocated: 0,
2632 is_dedicated: 0,
2633 connection_id: 0,
2634 };
2635
2636 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2637 let info = connection.offer_recv.next().await.unwrap();
2638
2639 assert_eq!(offer, info.offer);
2640
2641 server.send(in_msg(
2642 MessageType::RESCIND_CHANNEL_OFFER,
2643 protocol::RescindChannelOffer {
2644 channel_id: ChannelId(5),
2645 },
2646 ));
2647
2648 info.revoke_recv.await.unwrap();
2649 drop(info.request_send);
2650
2651 check_message(
2652 server.next().await.unwrap(),
2653 protocol::RelIdReleased {
2654 channel_id: ChannelId(5),
2655 },
2656 );
2657 }
2658
2659 #[async_test]
2660 async fn test_gpadl_success(driver: DefaultDriver) {
2661 let (mut server, mut client) = test_init(&driver);
2662 let channel = server.get_channel(&mut client).await;
2663 let recv = channel.request_send.call(
2664 ChannelRequest::Gpadl,
2665 GpadlRequest {
2666 id: GpadlId(1),
2667 count: 1,
2668 buf: vec![5],
2669 },
2670 );
2671
2672 check_message_with_data(
2673 server.next().await.unwrap(),
2674 protocol::GpadlHeader {
2675 channel_id: ChannelId(0),
2676 gpadl_id: GpadlId(1),
2677 len: 8,
2678 count: 1,
2679 },
2680 0x5u64.as_bytes(),
2681 );
2682
2683 server.send(in_msg(
2684 MessageType::GPADL_CREATED,
2685 protocol::GpadlCreated {
2686 channel_id: ChannelId(0),
2687 gpadl_id: GpadlId(1),
2688 status: protocol::STATUS_SUCCESS,
2689 },
2690 ));
2691
2692 recv.await.unwrap().unwrap();
2693
2694 let rpc = channel
2695 .request_send
2696 .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2697
2698 check_message(
2699 server.next().await.unwrap(),
2700 protocol::GpadlTeardown {
2701 channel_id: ChannelId(0),
2702 gpadl_id: GpadlId(1),
2703 },
2704 );
2705
2706 server.send(in_msg(
2707 MessageType::GPADL_TORNDOWN,
2708 protocol::GpadlTorndown {
2709 gpadl_id: GpadlId(1),
2710 },
2711 ));
2712
2713 rpc.await.unwrap();
2714 }
2715
2716 #[async_test]
2717 async fn test_gpadl_fail(driver: DefaultDriver) {
2718 let (mut server, mut client) = test_init(&driver);
2719 let channel = server.get_channel(&mut client).await;
2720 let recv = channel.request_send.call(
2721 ChannelRequest::Gpadl,
2722 GpadlRequest {
2723 id: GpadlId(1),
2724 count: 1,
2725 buf: vec![7],
2726 },
2727 );
2728
2729 check_message_with_data(
2730 server.next().await.unwrap(),
2731 protocol::GpadlHeader {
2732 channel_id: ChannelId(0),
2733 gpadl_id: GpadlId(1),
2734 len: 8,
2735 count: 1,
2736 },
2737 0x7u64.as_bytes(),
2738 );
2739
2740 server.send(in_msg(
2741 MessageType::GPADL_CREATED,
2742 protocol::GpadlCreated {
2743 channel_id: ChannelId(0),
2744 gpadl_id: GpadlId(1),
2745 status: protocol::STATUS_UNSUCCESSFUL,
2746 },
2747 ));
2748
2749 recv.await.unwrap().unwrap_err();
2750 }
2751
2752 #[async_test]
2753 async fn test_gpadl_with_revoke(driver: DefaultDriver) {
2754 let (mut server, mut client) = test_init(&driver);
2755 let channel = server.get_channel(&mut client).await;
2756 let channel_id = ChannelId(0);
2757 for gpadl_id in [1, 2, 3].map(GpadlId) {
2758 let recv = channel.request_send.call(
2759 ChannelRequest::Gpadl,
2760 GpadlRequest {
2761 id: gpadl_id,
2762 count: 1,
2763 buf: vec![3],
2764 },
2765 );
2766
2767 check_message_with_data(
2768 server.next().await.unwrap(),
2769 protocol::GpadlHeader {
2770 channel_id,
2771 gpadl_id,
2772 len: 8,
2773 count: 1,
2774 },
2775 0x3u64.as_bytes(),
2776 );
2777
2778 server.send(in_msg(
2779 MessageType::GPADL_CREATED,
2780 protocol::GpadlCreated {
2781 channel_id,
2782 gpadl_id,
2783 status: protocol::STATUS_SUCCESS,
2784 },
2785 ));
2786
2787 recv.await.unwrap().unwrap();
2788 }
2789
2790 let rpc = channel
2791 .request_send
2792 .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2793
2794 check_message(
2795 server.next().await.unwrap(),
2796 protocol::GpadlTeardown {
2797 channel_id,
2798 gpadl_id: GpadlId(1),
2799 },
2800 );
2801
2802 server.send(in_msg(
2803 MessageType::RESCIND_CHANNEL_OFFER,
2804 protocol::RescindChannelOffer { channel_id },
2805 ));
2806
2807 let recv = channel.request_send.call_failable(
2808 ChannelRequest::Gpadl,
2809 GpadlRequest {
2810 id: GpadlId(4),
2811 count: 1,
2812 buf: vec![3],
2813 },
2814 );
2815
2816 check_message_with_data(
2817 server.next().await.unwrap(),
2818 protocol::GpadlHeader {
2819 channel_id,
2820 gpadl_id: GpadlId(4),
2821 len: 8,
2822 count: 1,
2823 },
2824 0x3u64.as_bytes(),
2825 );
2826
2827 server.send(in_msg(
2828 MessageType::GPADL_CREATED,
2829 protocol::GpadlCreated {
2830 channel_id,
2831 gpadl_id: GpadlId(4),
2832 status: protocol::STATUS_UNSUCCESSFUL,
2833 },
2834 ));
2835
2836 server.send(in_msg(
2837 MessageType::GPADL_TORNDOWN,
2838 protocol::GpadlTorndown {
2839 gpadl_id: GpadlId(1),
2840 },
2841 ));
2842
2843 rpc.await.unwrap();
2844 recv.await.unwrap_err();
2845
2846 channel.revoke_recv.await.unwrap();
2847
2848 let rpc = channel
2849 .request_send
2850 .call(ChannelRequest::TeardownGpadl, GpadlId(2));
2851 drop(channel.request_send);
2852
2853 check_message(
2854 server.next().await.unwrap(),
2855 protocol::GpadlTeardown {
2856 channel_id,
2857 gpadl_id: GpadlId(2),
2858 },
2859 );
2860
2861 server.send(in_msg(
2862 MessageType::GPADL_TORNDOWN,
2863 protocol::GpadlTorndown {
2864 gpadl_id: GpadlId(2),
2865 },
2866 ));
2867
2868 rpc.await.unwrap();
2869
2870 check_message(
2871 server.next().await.unwrap(),
2872 protocol::RelIdReleased { channel_id },
2873 );
2874 }
2875
2876 #[async_test]
2877 async fn test_modify_connection(driver: DefaultDriver) {
2878 let (mut server, mut client) = test_init(&driver);
2879 server.connect(&mut client).await;
2880 let call = client.access.client_request_send.call(
2881 ClientRequest::Modify,
2882 ModifyConnectionRequest {
2883 monitor_page: Some(MonitorPageGpas {
2884 child_to_parent: 5,
2885 parent_to_child: 6,
2886 }),
2887 },
2888 );
2889
2890 check_message(
2891 server.next().await.unwrap(),
2892 protocol::ModifyConnection {
2893 child_to_parent_monitor_page_gpa: 5,
2894 parent_to_child_monitor_page_gpa: 6,
2895 },
2896 );
2897
2898 server.send(in_msg(
2899 MessageType::MODIFY_CONNECTION_RESPONSE,
2900 protocol::ModifyConnectionResponse {
2901 connection_state: ConnectionState::FAILED_LOW_RESOURCES,
2902 },
2903 ));
2904
2905 let result = call.await.unwrap();
2906 assert_eq!(ConnectionState::FAILED_LOW_RESOURCES, result);
2907 }
2908
2909 #[async_test]
2910 async fn test_hvsock(driver: DefaultDriver) {
2911 let (mut server, mut client) = test_init(&driver);
2912 server.connect(&mut client).await;
2913 let request = HvsockConnectRequest {
2914 service_id: Guid::new_random(),
2915 endpoint_id: Guid::new_random(),
2916 silo_id: Guid::new_random(),
2917 hosted_silo_unaware: false,
2918 };
2919
2920 let resp = client.access().connect_hvsock(request);
2921 check_message(
2922 server.next().await.unwrap(),
2923 protocol::TlConnectRequest2 {
2924 base: protocol::TlConnectRequest {
2925 service_id: request.service_id,
2926 endpoint_id: request.endpoint_id,
2927 },
2928 silo_id: request.silo_id,
2929 },
2930 );
2931
2932 server.send(in_msg(
2934 MessageType::TL_CONNECT_REQUEST_RESULT,
2935 protocol::TlConnectResult {
2936 service_id: request.service_id,
2937 endpoint_id: request.endpoint_id,
2938 status: protocol::STATUS_CONNECTION_REFUSED,
2939 },
2940 ));
2941
2942 let result = resp.await;
2943 assert!(result.is_none());
2944 }
2945
2946 #[async_test]
2947 async fn test_synic_event_flags(driver: DefaultDriver) {
2948 let (mut server, mut client) = test_init(&driver);
2949 let connection = server.get_channels(&mut client, 5).await;
2950 let event = Event::new();
2951
2952 for _ in 0..5 {
2953 for (i, channel) in connection.offers.iter().enumerate() {
2954 let recv = channel.request_send.call(
2955 ChannelRequest::Open,
2956 OpenRequest {
2957 open_data: OpenData {
2958 target_vp: 0,
2959 ring_offset: 0,
2960 ring_gpadl_id: GpadlId(0),
2961 event_flag: 0,
2962 connection_id: 0,
2963 user_data: UserDefinedData::new_zeroed(),
2964 },
2965 incoming_event: Some(event.clone()),
2966 use_vtl2_connection_id: false,
2967 },
2968 );
2969
2970 let expected_event_flag = i as u16 + 1;
2971
2972 check_message(
2973 server.next().await.unwrap(),
2974 protocol::OpenChannel2 {
2975 open_channel: protocol::OpenChannel {
2976 channel_id: channel.offer.channel_id,
2977 open_id: 0,
2978 ring_buffer_gpadl_id: GpadlId(0),
2979 target_vp: 0,
2980 downstream_ring_buffer_page_offset: 0,
2981 user_data: UserDefinedData::new_zeroed(),
2982 },
2983 connection_id: 0,
2984 event_flag: expected_event_flag,
2985 flags: OpenChannelFlags::new().with_redirect_interrupt(true),
2986 },
2987 );
2988
2989 server.send(in_msg(
2990 MessageType::OPEN_CHANNEL_RESULT,
2991 protocol::OpenResult {
2992 channel_id: channel.offer.channel_id,
2993 open_id: 0,
2994 status: protocol::STATUS_SUCCESS as u32,
2995 },
2996 ));
2997
2998 let output = recv.await.unwrap().unwrap();
2999 assert_eq!(output.redirected_event_flag, Some(expected_event_flag));
3000 }
3001
3002 for (i, channel) in connection.offers.iter().enumerate() {
3003 channel
3006 .request_send
3007 .call(ChannelRequest::Close, ())
3008 .await
3009 .unwrap();
3010
3011 check_message(
3012 server.next().await.unwrap(),
3013 protocol::CloseChannel {
3014 channel_id: ChannelId(i as u32),
3015 },
3016 );
3017 }
3018 }
3019 }
3020
3021 #[async_test]
3022 async fn test_revoke(driver: DefaultDriver) {
3023 let (mut server, mut client) = test_init(&driver);
3024 let channel = server.get_channel(&mut client).await;
3025
3026 server.send(in_msg(
3027 MessageType::RESCIND_CHANNEL_OFFER,
3028 protocol::RescindChannelOffer {
3029 channel_id: ChannelId(0),
3030 },
3031 ));
3032
3033 channel.revoke_recv.await.unwrap();
3034
3035 channel
3036 .request_send
3037 .call_failable(
3038 ChannelRequest::Open,
3039 OpenRequest {
3040 open_data: OpenData {
3041 target_vp: 0,
3042 ring_offset: 0,
3043 ring_gpadl_id: GpadlId(0),
3044 event_flag: 0,
3045 connection_id: 0,
3046 user_data: UserDefinedData::new_zeroed(),
3047 },
3048 incoming_event: None,
3049 use_vtl2_connection_id: false,
3050 },
3051 )
3052 .await
3053 .unwrap_err();
3054 }
3055
3056 #[async_test]
3057 #[should_panic(expected = "channel should not exist")]
3058 async fn test_reoffer_in_use_rel_id(driver: DefaultDriver) {
3059 let (mut server, mut client) = test_init(&driver);
3060 let mut connection = server.get_channels(&mut client, 1).await;
3061 let [channel] = connection.offers.try_into().unwrap();
3062
3063 server.send(in_msg(
3064 MessageType::RESCIND_CHANNEL_OFFER,
3065 protocol::RescindChannelOffer {
3066 channel_id: ChannelId(0),
3067 },
3068 ));
3069
3070 channel.revoke_recv.await.unwrap();
3071
3072 let offer = protocol::OfferChannel {
3074 interface_id: Guid::new_random(),
3075 instance_id: Guid::new_random(),
3076 rsvd: [0; 4],
3077 flags: OfferFlags::new(),
3078 mmio_megabytes: 0,
3079 user_defined: UserDefinedData::new_zeroed(),
3080 subchannel_index: 0,
3081 mmio_megabytes_optional: 0,
3082 channel_id: ChannelId(0),
3083 monitor_id: 0,
3084 monitor_allocated: 0,
3085 is_dedicated: 0,
3086 connection_id: 0,
3087 };
3088
3089 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3090
3091 connection.offer_recv.next().await;
3092 }
3093
3094 #[async_test]
3095 async fn test_revoke_release_and_reoffer(driver: DefaultDriver) {
3096 let (mut server, mut client) = test_init(&driver);
3097 let mut connection = server.get_channels(&mut client, 1).await;
3098 let [channel] = connection.offers.try_into().unwrap();
3099
3100 server.send(in_msg(
3101 MessageType::RESCIND_CHANNEL_OFFER,
3102 protocol::RescindChannelOffer {
3103 channel_id: ChannelId(0),
3104 },
3105 ));
3106
3107 channel.revoke_recv.await.unwrap();
3108 drop(channel.request_send);
3109
3110 check_message(
3111 server.next().await.unwrap(),
3112 protocol::RelIdReleased {
3113 channel_id: ChannelId(0),
3114 },
3115 );
3116
3117 let offer = protocol::OfferChannel {
3118 interface_id: Guid::new_random(),
3119 instance_id: Guid::new_random(),
3120 rsvd: [0; 4],
3121 flags: OfferFlags::new(),
3122 mmio_megabytes: 0,
3123 user_defined: UserDefinedData::new_zeroed(),
3124 subchannel_index: 0,
3125 mmio_megabytes_optional: 0,
3126 channel_id: ChannelId(0),
3127 monitor_id: 0,
3128 monitor_allocated: 0,
3129 is_dedicated: 0,
3130 connection_id: 0,
3131 };
3132
3133 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3134
3135 connection.offer_recv.next().await.unwrap();
3136 }
3137
3138 #[async_test]
3139 async fn test_release_revoke_and_reoffer(driver: DefaultDriver) {
3140 let (mut server, mut client) = test_init(&driver);
3141 let mut connection = server.get_channels(&mut client, 1).await;
3142 let [channel] = connection.offers.try_into().unwrap();
3143
3144 let open = channel.request_send.call_failable(
3145 ChannelRequest::Open,
3146 OpenRequest {
3147 open_data: OpenData {
3148 target_vp: 0,
3149 ring_offset: 0,
3150 ring_gpadl_id: GpadlId(0),
3151 event_flag: 0,
3152 connection_id: 0,
3153 user_data: UserDefinedData::new_zeroed(),
3154 },
3155 incoming_event: None,
3156 use_vtl2_connection_id: false,
3157 },
3158 );
3159
3160 let server_open = async {
3161 check_message(
3162 server.next().await.unwrap(),
3163 protocol::OpenChannel2 {
3164 open_channel: protocol::OpenChannel {
3165 channel_id: ChannelId(0),
3166 open_id: 0,
3167 ring_buffer_gpadl_id: GpadlId(0),
3168 target_vp: 0,
3169 downstream_ring_buffer_page_offset: 0,
3170 user_data: UserDefinedData::new_zeroed(),
3171 },
3172 connection_id: 0,
3173 event_flag: 0,
3174 flags: Default::default(),
3175 },
3176 );
3177 server.send(in_msg(
3178 MessageType::OPEN_CHANNEL_RESULT,
3179 protocol::OpenResult {
3180 channel_id: ChannelId(0),
3181 open_id: 0,
3182 status: protocol::STATUS_SUCCESS as u32,
3183 },
3184 ));
3185 };
3186
3187 (open, server_open).join().await.0.unwrap();
3188
3189 drop(channel);
3191
3192 check_message(
3193 server.next().await.unwrap(),
3194 protocol::CloseChannel {
3195 channel_id: ChannelId(0),
3196 },
3197 );
3198
3199 server.send(in_msg(
3200 MessageType::RESCIND_CHANNEL_OFFER,
3201 protocol::RescindChannelOffer {
3202 channel_id: ChannelId(0),
3203 },
3204 ));
3205
3206 check_message(
3208 server.next().await.unwrap(),
3209 protocol::RelIdReleased {
3210 channel_id: ChannelId(0),
3211 },
3212 );
3213
3214 let offer = protocol::OfferChannel {
3215 interface_id: Guid::new_random(),
3216 instance_id: Guid::new_random(),
3217 rsvd: [0; 4],
3218 flags: OfferFlags::new(),
3219 mmio_megabytes: 0,
3220 user_defined: UserDefinedData::new_zeroed(),
3221 subchannel_index: 0,
3222 mmio_megabytes_optional: 0,
3223 channel_id: ChannelId(0),
3224 monitor_id: 0,
3225 monitor_allocated: 0,
3226 is_dedicated: 0,
3227 connection_id: 0,
3228 };
3229
3230 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3231
3232 connection.offer_recv.next().await.unwrap();
3234 }
3235}