1#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9pub mod driver;
10pub mod filter;
11mod hvsock;
12pub mod saved_state;
13
14pub use self::saved_state::SavedState;
15use anyhow::Context as _;
16use anyhow::Result;
17use futures::FutureExt;
18use futures::StreamExt;
19use futures::future::OptionFuture;
20use futures::stream::SelectAll;
21use futures_concurrency::future::Race;
22use guid::Guid;
23use inspect::Inspect;
24use mesh::rpc::FailableRpc;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use pal_async::task::Spawn;
28use pal_async::task::Task;
29use pal_event::Event;
30use std::collections::HashMap;
31use std::collections::VecDeque;
32use std::collections::hash_map;
33use std::convert::TryInto;
34use std::future::Future;
35use std::future::poll_fn;
36use std::ops::Deref;
37use std::ops::DerefMut;
38use std::pin::pin;
39use std::sync::Arc;
40use std::task::Context;
41use std::task::Poll;
42use thiserror::Error;
43use vmbus_async::async_dgram::AsyncRecv;
44use vmbus_async::async_dgram::AsyncRecvExt;
45use vmbus_channel::bus::GpadlRequest;
46use vmbus_channel::bus::ModifyRequest;
47use vmbus_channel::bus::OpenData;
48use vmbus_channel::gpadl::GpadlId;
49use vmbus_core::HvsockConnectRequest;
50use vmbus_core::OutgoingMessage;
51use vmbus_core::TaggedStream;
52use vmbus_core::VersionInfo;
53use vmbus_core::protocol;
54use vmbus_core::protocol::ChannelId;
55use vmbus_core::protocol::ConnectionState;
56use vmbus_core::protocol::FeatureFlags;
57use vmbus_core::protocol::Message;
58use vmbus_core::protocol::OpenChannelFlags;
59use vmbus_core::protocol::Version;
60use vmcore::interrupt::Interrupt;
61use vmcore::synic::MonitorPageGpas;
62use zerocopy::Immutable;
63use zerocopy::IntoBytes;
64use zerocopy::KnownLayout;
65
/// Synthetic interrupt source (SINT) requested from the host in the
/// `TargetInfo` sent with `InitiateContact`.
const SINT: u8 = 2;
/// VTL requested from the host in the `TargetInfo` sent with `InitiateContact`.
const VTL: u8 = 0;
/// Protocol versions this client can negotiate, ordered oldest to newest.
///
/// Connecting starts with the last entry; each time the host rejects a
/// version, the client retries with the preceding entry until none remain.
const SUPPORTED_VERSIONS: &[Version] = &[Version::Iron, Version::Copper];
/// Feature flags requested when negotiating `Version::Copper` or newer.
/// Older versions request no feature flags.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true);
75
/// Synic event-flag operations used by the VMBus client.
pub trait SynicEventClient: Send + Sync {
    /// Associates `event` with the synic event flag `event_flag`.
    ///
    /// NOTE(review): presumably `event` is signaled when the host sets this
    /// flag — confirm against implementations.
    fn map_event(&self, event_flag: u16, event: &Event) -> std::io::Result<()>;

    /// Removes the association previously created for `event_flag` by
    /// [`Self::map_event`].
    fn unmap_event(&self, event_flag: u16);

    /// Signals the host on `connection_id` using `event_flag`.
    fn signal_event(&self, connection_id: u32, event_flag: u16) -> std::io::Result<()>;
}
87
/// Source of incoming VMBus messages from the host.
pub trait VmbusMessageSource: AsyncRecv + Send {
    /// Asks the source to stop delivering messages. The default implementation
    /// does nothing.
    fn pause_message_stream(&mut self) {}

    /// Asks the source to resume delivering messages after
    /// [`Self::pause_message_stream`]. The default implementation does nothing.
    fn resume_message_stream(&mut self) {}
}
97
/// Posts outgoing VMBus messages to the host.
pub trait PollPostMessage: Send {
    /// Polls to post a message of type `typ` with payload `msg` on
    /// `connection_id`.
    ///
    /// NOTE(review): presumably returns `Poll::Ready(())` once the message has
    /// been accepted and `Poll::Pending` otherwise — confirm against
    /// implementations.
    fn poll_post_message(
        &mut self,
        cx: &mut Context<'_>,
        connection_id: u32,
        typ: u32,
        msg: &[u8],
    ) -> Poll<()>;
}
107
/// Handle to a running VMBus client task, created by
/// [`VmbusClientBuilder::build`].
pub struct VmbusClient {
    /// Task-level requests: start, stop, save, restore, and inspect.
    task_send: mesh::Sender<TaskRequest>,
    /// Cloneable handle for client requests (connect, unload, modify, hvsock).
    access: VmbusClientAccess,
    /// The spawned client task; awaiting it yields the `ClientTask` back.
    task: Task<ClientTask>,
}
113
/// Errors returned by [`VmbusClient::connect`].
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// A connect was requested while the client was not disconnected.
    #[error("invalid state to connect to the server")]
    InvalidState,
    /// The host rejected every version in `SUPPORTED_VERSIONS`.
    #[error("no supported protocol versions")]
    NoSupportedVersions,
    /// The host accepted the version but reported a non-successful
    /// connection state.
    #[error("failed to connect to the server: {0:?}")]
    FailedToConnect(ConnectionState),
}
123
/// Cloneable handle for sending client requests to the VMBus client task.
#[derive(Clone)]
pub struct VmbusClientAccess {
    client_request_send: mesh::Sender<ClientRequest>,
}
128
/// Builder for a [`VmbusClient`].
///
/// Holds the objects used to talk to the host. When a client is severed, they
/// are returned in a new builder.
pub struct VmbusClientBuilder {
    /// Synic event-flag operations.
    event_client: Arc<dyn SynicEventClient>,
    /// Incoming messages from the host.
    msg_source: Box<dyn VmbusMessageSource>,
    /// Outgoing message poster.
    msg_client: Box<dyn PollPostMessage>,
}
135
136impl VmbusClientBuilder {
137 pub fn new(
139 event_client: impl SynicEventClient + 'static,
140 msg_source: impl VmbusMessageSource + 'static,
141 msg_client: impl PollPostMessage + 'static,
142 ) -> Self {
143 Self {
144 event_client: Arc::new(event_client),
145 msg_source: Box::new(msg_source),
146 msg_client: Box::new(msg_client),
147 }
148 }
149
150 pub fn build(self, spawner: &impl Spawn) -> VmbusClient {
152 let (task_send, task_recv) = mesh::channel();
153 let (client_request_send, client_request_recv) = mesh::channel();
154
155 let inner = ClientTaskInner {
156 messages: OutgoingMessages {
157 poster: self.msg_client,
158 queued: VecDeque::new(),
159 state: OutgoingMessageState::Paused,
160 },
161 teardown_gpadls: HashMap::new(),
162 channel_requests: SelectAll::new(),
163 synic: SynicState {
164 event_flag_state: Vec::new(),
165 event_client: self.event_client,
166 },
167 };
168
169 let mut task = ClientTask {
170 inner,
171 channels: ChannelList::default(),
172 task_recv,
173 running: false,
174 msg_source: self.msg_source,
175 client_request_recv,
176 state: ClientState::Disconnected,
177 modify_request: None,
178 hvsock_tracker: hvsock::HvsockRequestTracker::new(),
179 };
180
181 let task = spawner.spawn("vmbus client", async move {
182 task.run().await;
183 task
184 });
185
186 VmbusClient {
187 access: VmbusClientAccess {
188 client_request_send,
189 },
190 task_send,
191 task,
192 }
193 }
194}
195
196impl VmbusClient {
197 pub async fn connect(
200 &mut self,
201 target_message_vp: u32,
202 monitor_page: Option<MonitorPageGpas>,
203 client_id: Guid,
204 ) -> Result<ConnectResult, ConnectError> {
205 let request = ConnectRequest {
206 target_message_vp,
207 monitor_page,
208 client_id,
209 };
210
211 self.access
212 .client_request_send
213 .call(ClientRequest::Connect, request)
214 .await
215 .unwrap()
216 }
217
218 pub async fn unload(self) {
219 self.access
220 .client_request_send
221 .call(ClientRequest::Unload, ())
222 .await
223 .unwrap();
224
225 self.sever().await;
226 }
227
228 pub fn access(&self) -> &VmbusClientAccess {
229 &self.access
230 }
231
232 pub fn start(&mut self) {
233 self.task_send.send(TaskRequest::Start);
234 }
235
236 pub async fn stop(&mut self) {
237 self.task_send
238 .call(TaskRequest::Stop, ())
239 .await
240 .expect("Failed to send stop request");
241 }
242
243 pub async fn save(&self) -> SavedState {
244 self.task_send
245 .call(TaskRequest::Save, ())
246 .await
247 .expect("Failed to send save request")
248 }
249
250 pub async fn restore(
251 &mut self,
252 state: SavedState,
253 ) -> Result<Option<ConnectResult>, RestoreError> {
254 self.task_send
255 .call(TaskRequest::Restore, state)
256 .await
257 .expect("Failed to send restore request")
258 }
259
260 pub async fn post_restore(&mut self) {
261 self.task_send
262 .call(TaskRequest::PostRestore, ())
263 .await
264 .expect("Failed to send post-restore request");
265 }
266
267 async fn sever(self) -> VmbusClientBuilder {
268 drop(self.task_send);
269 let task = self.task.await;
270 VmbusClientBuilder {
271 event_client: task.inner.synic.event_client,
272 msg_source: task.msg_source,
273 msg_client: task.inner.messages.poster,
274 }
275 }
276}
277
278impl Inspect for VmbusClient {
279 fn inspect(&self, req: inspect::Request<'_>) {
280 self.task_send.send(TaskRequest::Inspect(req.defer()));
281 }
282}
283
/// Result of a successful connection to the host.
#[derive(Debug)]
pub struct ConnectResult {
    /// The negotiated protocol version and feature flags.
    pub version: VersionInfo,
    /// For a fresh connection, the offers received before the host reported
    /// that all offers were delivered.
    pub offers: Vec<OfferInfo>,
    /// Receives offers that arrive after the connection is established.
    pub offer_recv: mesh::Receiver<OfferInfo>,
}
290
291impl VmbusClientAccess {
292 pub async fn modify(&self, request: ModifyConnectionRequest) -> ConnectionState {
293 self.client_request_send
294 .call(ClientRequest::Modify, request)
295 .await
296 .expect("Failed to send modify request")
297 }
298
299 pub fn connect_hvsock(
300 &self,
301 request: HvsockConnectRequest,
302 ) -> impl Future<Output = Option<OfferInfo>> + use<> {
303 self.client_request_send
304 .call(ClientRequest::HvsockConnect, request)
305 .map(|r| r.ok().flatten())
306 }
307}
308
/// Parameters for opening a channel on the host.
#[derive(Debug)]
pub struct OpenRequest {
    /// Ring buffer GPADL, target VP, ring offset, event flag, connection ID,
    /// and user data for the open message.
    pub open_data: OpenData,
    /// If set, the client allocates an event flag for this event and asks the
    /// host to redirect the channel's interrupt to it. This requires host
    /// support.
    pub incoming_event: Option<Event>,
    /// If set, a connection ID derived from the channel ID is used instead of
    /// `open_data.connection_id`. This requires host support.
    pub use_vtl2_connection_id: bool,
}
315
/// Parameters for reattaching to a channel that was open before a restore.
#[derive(Debug)]
pub struct RestoreRequest {
    /// Event for redirected interrupts. Set exactly when
    /// `redirected_event_flag` is set.
    pub incoming_event: Option<Event>,
    /// The event flag previously allocated for `incoming_event`.
    pub redirected_event_flag: Option<u16>,
    /// The connection ID the channel was opened with.
    pub connection_id: u32,
}
324
/// Requests sent by a channel's owner to the client task through
/// [`OfferInfo::request_send`].
pub enum ChannelRequest {
    /// Open the channel on the host.
    Open(FailableRpc<OpenRequest, OpenOutput>),
    /// Reattach to a channel that was open before a restore.
    Restore(FailableRpc<RestoreRequest, OpenOutput>),
    /// Close the channel.
    Close(Rpc<(), ()>),
    /// Create a GPADL for the channel.
    Gpadl(FailableRpc<GpadlRequest, ()>),
    /// Tear down a GPADL of the channel.
    TeardownGpadl(Rpc<GpadlId, ()>),
    /// Modify the channel. The reply is the status from the host's response.
    Modify(Rpc<ModifyRequest, i32>),
}
334
/// Result of opening or restoring a channel.
#[derive(Debug)]
pub struct OpenOutput {
    /// The event flag allocated for redirected interrupts, if any.
    pub redirected_event_flag: Option<u16>,
    /// Interrupt that signals the host on the channel's connection ID.
    pub guest_to_host_signal: Interrupt,
}
341
342impl std::fmt::Display for ChannelRequest {
343 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344 let s = match self {
345 ChannelRequest::Open(_) => "Open",
346 ChannelRequest::Close(_) => "Close",
347 ChannelRequest::Restore(_) => "Restore",
348 ChannelRequest::Gpadl(_) => "Gpadl",
349 ChannelRequest::TeardownGpadl(_) => "TeardownGpadl",
350 ChannelRequest::Modify(_) => "Modify",
351 };
352 fmt.pad(s)
353 }
354}
355
/// Errors returned by [`VmbusClient::restore`].
#[derive(Debug, Error)]
pub enum RestoreError {
    /// The saved protocol version is not one this client supports.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// The saved feature flags include flags this client does not support.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    /// The saved state lists the same channel ID more than once.
    #[error("duplicate channel id {0}")]
    DuplicateChannelId(u32),

    /// The saved state lists the same GPADL ID more than once.
    #[error("duplicate gpadl id {0}")]
    DuplicateGpadlId(u32),

    /// A saved GPADL refers to a channel that is not in the saved state.
    #[error("gpadl for unknown channel id {0}")]
    GpadlForUnknownChannelId(u32),

    /// A saved pending message could not be reconstructed.
    #[error("invalid pending message")]
    InvalidPendingMessage(#[source] vmbus_core::MessageTooLarge),

    /// A restored channel could not be offered.
    #[error("failed to offer restored channel")]
    OfferFailed(#[source] anyhow::Error),
}
379
/// A channel offer handed to the client's user, together with the endpoints
/// used to operate on the channel.
#[derive(Debug, Inspect)]
pub struct OfferInfo {
    /// The offer as received from the host.
    pub offer: protocol::OfferChannel,
    /// Sends [`ChannelRequest`]s for this channel to the client task.
    #[inspect(skip)]
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Signaled when the host rescinds the offer.
    #[inspect(skip)]
    pub revoke_recv: mesh::OneshotReceiver<()>,
}
390
/// Requests sent through [`VmbusClientAccess`] to the client task.
#[derive(Debug)]
enum ClientRequest {
    /// Connect to the host, negotiating a protocol version.
    Connect(Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>),
    /// Unload the connection with the host.
    Unload(Rpc<(), ()>),
    /// Modify connection parameters.
    Modify(Rpc<ModifyConnectionRequest, ConnectionState>),
    /// Request an hvsocket connection.
    HvsockConnect(Rpc<HvsockConnectRequest, Option<OfferInfo>>),
}
398
399impl std::fmt::Display for ClientRequest {
400 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401 let s = match self {
402 ClientRequest::Connect(..) => "Connect",
403 ClientRequest::Unload { .. } => "Unload",
404 ClientRequest::Modify(..) => "Modify",
405 ClientRequest::HvsockConnect(..) => "HvsockConnect",
406 };
407 fmt.pad(s)
408 }
409}
410
/// Requests sent from [`VmbusClient`] to the client task.
enum TaskRequest {
    /// Inspect the task's state.
    Inspect(inspect::Deferred),
    /// Save the client's state.
    Save(Rpc<(), SavedState>),
    /// Restore previously saved state.
    Restore(Rpc<SavedState, Result<Option<ConnectResult>, RestoreError>>),
    /// Post-restore step. NOTE(review): handled outside this view — confirm
    /// its exact semantics.
    PostRestore(Rpc<(), ()>),
    /// Start the task (sent without waiting for a reply).
    Start,
    /// Stop the task and reply once handled.
    Stop(Rpc<(), ()>),
}
419
/// Connection state of the client with respect to the host.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ClientState {
    /// Not connected. A connect request is accepted only in this state.
    Disconnected,
    /// `InitiateContact` was sent; waiting for the host's version response.
    Connecting {
        /// The version requested from the host.
        version: Version,
        #[inspect(skip)]
        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
    },
    /// Connected. New offers are forwarded through `offer_send`.
    Connected {
        /// The negotiated version and feature flags.
        version: VersionInfo,
        #[inspect(skip)]
        offer_send: mesh::Sender<OfferInfo>,
    },
    /// Version negotiated and `RequestOffers` sent. Offers are collected until
    /// the host reports that all offers have been delivered.
    RequestingOffers {
        /// The negotiated version and feature flags.
        version: VersionInfo,
        #[inspect(skip)]
        rpc: Rpc<(), Result<ConnectResult, ConnectError>>,
        #[inspect(skip)]
        offers: Vec<OfferInfo>,
    },
    /// `Unload` sent; waiting for `UnloadComplete`.
    Disconnecting {
        /// The negotiated version and feature flags.
        version: VersionInfo,
        #[inspect(skip)]
        rpc: Rpc<(), ()>,
    },
}
455
456impl ClientState {
457 fn get_version(&self) -> Option<VersionInfo> {
458 match self {
459 ClientState::Connected { version, .. } => Some(*version),
460 ClientState::RequestingOffers { version, .. } => Some(*version),
461 ClientState::Disconnecting { version, .. } => Some(*version),
462 ClientState::Disconnected | ClientState::Connecting { .. } => None,
463 }
464 }
465}
466
467impl std::fmt::Display for ClientState {
468 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469 let s = match self {
470 ClientState::Disconnected => "Disconnected",
471 ClientState::Connecting { .. } => "Connecting",
472 ClientState::Connected { .. } => "Connected",
473 ClientState::RequestingOffers { .. } => "RequestingOffers",
474 ClientState::Disconnecting { .. } => "Disconnecting",
475 };
476 fmt.pad(s)
477 }
478}
479
/// Parameters of a connect request, reused for each version attempted during
/// negotiation.
#[derive(Copy, Clone, Debug, Default)]
struct ConnectRequest {
    /// Sent as `target_message_vp` in `InitiateContact`.
    target_message_vp: u32,
    /// Monitor page GPAs; `None` sends `MonitorPageGpas::default()`.
    monitor_page: Option<MonitorPageGpas>,
    /// Sent as `client_id` in `InitiateContact2` (Copper and newer).
    client_id: Guid,
}
486
/// Parameters for [`VmbusClientAccess::modify`].
#[derive(Copy, Clone, Debug, Default)]
pub struct ModifyConnectionRequest {
    /// New monitor page GPAs; `None` sends `MonitorPageGpas::default()`.
    pub monitor_page: Option<MonitorPageGpas>,
}
491
492impl From<ModifyConnectionRequest> for protocol::ModifyConnection {
493 fn from(value: ModifyConnectionRequest) -> Self {
494 let monitor_page = value.monitor_page.unwrap_or_default();
495
496 Self {
497 parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
498 child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
499 }
500 }
501}
502
/// Client-side lifecycle state of a channel.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum ChannelState {
    /// Offered by the host and not open. An open request is accepted only in
    /// this state.
    Offered,
    /// An open message was sent; waiting for the host's open result.
    Opening {
        /// Connection ID the channel is being opened with.
        connection_id: u32,
        /// Event flag allocated for redirected interrupts, if any.
        redirected_event_flag: Option<u16>,
        #[inspect(skip)]
        redirected_event: Option<Event>,
        /// Completed or failed when the open result (or a rescind) arrives.
        #[inspect(skip)]
        rpc: FailableRpc<(), OpenOutput>,
    },
    /// Restored; `handle_restore_channel` moves it to `Opened`.
    Restored,
    /// Open on the host.
    Opened {
        /// Connection ID the channel was opened with.
        connection_id: u32,
        /// Event flag allocated for redirected interrupts, if any.
        redirected_event_flag: Option<u16>,
        #[inspect(skip)]
        redirected_event: Option<Event>,
    },
    /// The host rescinded the offer.
    Revoked,
}
532
533impl std::fmt::Display for ChannelState {
534 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535 let s = match self {
536 ChannelState::Opening { .. } => "Opening",
537 ChannelState::Offered => "Offered",
538 ChannelState::Opened { .. } => "Opened",
539 ChannelState::Restored => "Restored",
540 ChannelState::Revoked => "Revoked",
541 };
542 fmt.pad(s)
543 }
544}
545
/// Client-side bookkeeping for one offered channel.
#[derive(Debug, Inspect)]
struct Channel {
    /// The offer as received from the host.
    offer: protocol::OfferChannel,
    /// Notifies the channel's user of a rescind; taken when the rescind
    /// arrives.
    #[inspect(skip)]
    revoke_send: Option<mesh::OneshotSender<()>>,
    state: ChannelState,
    /// Outstanding `ModifyChannel` request, completed by the host's response.
    #[inspect(with = "|x| x.is_some()")]
    modify_response_send: Option<Rpc<(), i32>>,
    /// GPADLs of this channel and their creation/teardown state.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    gpadls: HashMap<GpadlId, GpadlState>,
    /// NOTE(review): presumably set once the channel's user has released the
    /// channel — confirm where it is written.
    is_client_released: bool,
}
559
560impl Channel {
561 fn pending_request(&self) -> Option<&'static str> {
562 if self.modify_response_send.is_some() {
563 return Some("modify");
564 }
565 self.gpadls.iter().find_map(|(_, gpadl)| match gpadl {
566 GpadlState::Offered(_) => Some("creating gpadl"),
567 GpadlState::Created => None,
568 GpadlState::TearingDown { .. } => Some("tearing down gpadl"),
569 })
570 }
571}
572
/// State owned by the spawned VMBus client task.
#[derive(Inspect)]
struct ClientTask {
    #[inspect(flatten)]
    inner: ClientTaskInner,
    /// Channels known to the client, keyed by channel ID.
    channels: ChannelList,
    state: ClientState,
    /// Pending hvsocket connect requests, resolved by a matching offer or a
    /// connect result.
    hvsock_tracker: hvsock::HvsockRequestTracker,
    /// NOTE(review): presumably tracks start/stop requests — confirm in
    /// `run`.
    running: bool,
    /// Outstanding `ModifyConnection` request, completed by the host's
    /// response.
    #[inspect(with = "|x| x.is_some()")]
    modify_request: Option<Rpc<ModifyConnectionRequest, ConnectionState>>,
    /// Incoming messages from the host.
    #[inspect(skip)]
    msg_source: Box<dyn VmbusMessageSource>,
    /// Requests from [`VmbusClient`].
    #[inspect(skip)]
    task_recv: mesh::Receiver<TaskRequest>,
    /// Requests from [`VmbusClientAccess`].
    #[inspect(skip)]
    client_request_recv: mesh::Receiver<ClientRequest>,
}
590
591impl ClientTask {
592 fn handle_initiate_contact(
593 &mut self,
594 rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
595 version: Version,
596 ) {
597 let ClientState::Disconnected = self.state else {
598 tracing::warn!(client_state = %self.state, "invalid client state for InitiateContact");
599 rpc.complete(Err(ConnectError::InvalidState));
600 return;
601 };
602 let feature_flags = if version >= Version::Copper {
603 SUPPORTED_FEATURE_FLAGS
604 } else {
605 FeatureFlags::new()
606 };
607
608 let request = rpc.input();
609
610 tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting");
611 let target_info = protocol::TargetInfo::new()
612 .with_sint(SINT)
613 .with_vtl(VTL)
614 .with_feature_flags(feature_flags.into());
615 let monitor_page = request.monitor_page.unwrap_or_default();
616 let msg = protocol::InitiateContact2 {
617 initiate_contact: protocol::InitiateContact {
618 version_requested: version as u32,
619 target_message_vp: request.target_message_vp,
620 interrupt_page_or_target_info: target_info.into(),
621 parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
622 child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
623 },
624 client_id: request.client_id,
625 };
626
627 self.state = ClientState::Connecting { version, rpc };
628 if version < Version::Copper {
629 self.inner.messages.send(&msg.initiate_contact)
630 } else {
631 self.inner.messages.send(&msg);
632 }
633 }
634
635 fn handle_unload(&mut self, rpc: Rpc<(), ()>) {
636 tracing::debug!(%self.state, "VmBus client disconnecting");
637 self.state = ClientState::Disconnecting {
638 version: self.state.get_version().expect("invalid state for unload"),
639 rpc,
640 };
641
642 self.inner.messages.send(&protocol::Unload {});
643 }
644
645 fn handle_modify(&mut self, request: Rpc<ModifyConnectionRequest, ConnectionState>) {
646 if !matches!(self.state, ClientState::Connected { version, .. }
647 if version.feature_flags.modify_connection())
648 {
649 tracing::warn!("ModifyConnection not supported");
650 request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
651 return;
652 }
653
654 if self.modify_request.is_some() {
655 tracing::warn!("Duplicate ModifyConnection request");
656 request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
657 return;
658 }
659
660 let message = protocol::ModifyConnection::from(*request.input());
661 self.modify_request = Some(request);
662 self.inner.messages.send(&message);
663 }
664
665 fn handle_tl_connect(&mut self, rpc: Rpc<HvsockConnectRequest, Option<OfferInfo>>) {
666 let message = protocol::TlConnectRequest2::from(*rpc.input());
670 self.hvsock_tracker.add_request(rpc);
671 self.inner.messages.send(&message);
672 }
673
674 fn handle_client_request(&mut self, request: ClientRequest) {
675 match request {
676 ClientRequest::Connect(rpc) => {
677 self.handle_initiate_contact(rpc, *SUPPORTED_VERSIONS.last().unwrap());
678 }
679 ClientRequest::Unload(rpc) => {
680 self.handle_unload(rpc);
681 }
682 ClientRequest::Modify(request) => self.handle_modify(request),
683 ClientRequest::HvsockConnect(request) => self.handle_tl_connect(request),
684 }
685 }
686
687 fn handle_version_response(&mut self, msg: protocol::VersionResponse2) {
688 let old_state = std::mem::replace(&mut self.state, ClientState::Disconnected);
689 let ClientState::Connecting { version, rpc } = old_state else {
690 self.state = old_state;
691 tracing::warn!(
692 client_state = %self.state,
693 "invalid client state to handle VersionResponse"
694 );
695 return;
696 };
697 if msg.version_response.version_supported > 0 {
698 if msg.version_response.connection_state != ConnectionState::SUCCESSFUL {
699 rpc.complete(Err(ConnectError::FailedToConnect(
700 msg.version_response.connection_state,
701 )));
702 return;
703 }
704
705 let feature_flags = if version >= Version::Copper {
706 FeatureFlags::from(msg.supported_features)
707 } else {
708 FeatureFlags::new()
709 };
710
711 let version = VersionInfo {
712 version,
713 feature_flags,
714 };
715
716 self.inner.messages.send(&protocol::RequestOffers {});
717 self.state = ClientState::RequestingOffers {
718 version,
719 rpc: rpc.split().1,
720 offers: Vec::new(),
721 };
722 tracing::info!(?version, "VmBus client connected, requesting offers");
723 } else {
724 let index = SUPPORTED_VERSIONS
725 .iter()
726 .position(|v| *v == version)
727 .unwrap();
728
729 if index == 0 {
730 rpc.complete(Err(ConnectError::NoSupportedVersions));
731 return;
732 }
733 let next_version = SUPPORTED_VERSIONS[index - 1];
734 tracing::debug!(
735 version = version as u32,
736 next_version = next_version as u32,
737 "Unsupported version, retrying"
738 );
739 self.handle_initiate_contact(rpc, next_version);
740 }
741 }
742
743 fn create_channel(&mut self, offer: protocol::OfferChannel) -> Result<OfferInfo> {
744 self.create_channel_core(offer, ChannelState::Offered)
745 }
746
747 fn create_channel_core(
748 &mut self,
749 offer: protocol::OfferChannel,
750 state: ChannelState,
751 ) -> Result<OfferInfo> {
752 if self.channels.0.contains_key(&offer.channel_id) {
753 anyhow::bail!("channel {:?} exists", offer.channel_id);
754 }
755 let (request_send, request_recv) = mesh::channel();
756 let (revoke_send, revoke_recv) = mesh::oneshot();
757
758 self.channels.0.insert(
759 offer.channel_id,
760 Channel {
761 revoke_send: Some(revoke_send),
762 offer,
763 state,
764 modify_response_send: None,
765 gpadls: HashMap::new(),
766 is_client_released: false,
767 },
768 );
769
770 self.inner
771 .channel_requests
772 .push(TaggedStream::new(offer.channel_id, request_recv));
773
774 Ok(OfferInfo {
775 offer,
776 revoke_recv,
777 request_send,
778 })
779 }
780
781 fn handle_offer(&mut self, offer: protocol::OfferChannel) {
782 let offer_info = self
783 .create_channel(offer)
784 .expect("channel should not exist");
785
786 tracing::info!(
787 state = %self.state,
788 channel_id = offer.channel_id.0,
789 interface_id = %offer.interface_id,
790 instance_id = %offer.instance_id,
791 subchannel_index = offer.subchannel_index,
792 "received offer");
793
794 if let Some(offer) = self.hvsock_tracker.check_offer(&offer_info.offer) {
795 offer.complete(Some(offer_info));
796 } else {
797 match &mut self.state {
798 ClientState::Connected { offer_send, .. } => {
799 offer_send.send(offer_info);
800 }
801 ClientState::RequestingOffers { offers, .. } => {
802 offers.push(offer_info);
803 }
804 state => unreachable!("invalid client state for OfferChannel: {state}"),
805 }
806 }
807 }
808
809 fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) -> TriedRelease {
810 tracing::info!(state = %self.state, channel_id = rescind.channel_id.0, "received rescind");
811
812 let mut channel = self.channels.get_mut(rescind.channel_id);
813 let event_flag = match std::mem::replace(&mut channel.state, ChannelState::Revoked) {
814 ChannelState::Offered => None,
815 ChannelState::Opening {
816 connection_id: _,
817 redirected_event_flag,
818 redirected_event: _,
819 rpc,
820 } => {
821 rpc.fail(anyhow::anyhow!("channel revoked"));
822 redirected_event_flag
823 }
824 ChannelState::Restored => None,
825 ChannelState::Opened {
826 connection_id: _,
827 redirected_event_flag,
828 redirected_event: _,
829 } => redirected_event_flag,
830 ChannelState::Revoked => {
831 panic!("channel id {:?} already revoked", rescind.channel_id);
832 }
833 };
834 if let Some(event_flag) = event_flag {
835 self.inner.synic.free_event_flag(event_flag);
836 }
837
838 channel.revoke_send.take().unwrap().send(());
840
841 channel.try_release(&mut self.inner.messages)
842 }
843
844 fn handle_offers_delivered(&mut self) {
845 match std::mem::replace(&mut self.state, ClientState::Disconnected) {
846 ClientState::RequestingOffers {
847 version,
848 rpc,
849 offers,
850 } => {
851 tracing::info!(version = ?version, "VmBus client connected, offers delivered");
852 let (offer_send, offer_recv) = mesh::channel();
853 self.state = ClientState::Connected {
854 version,
855 offer_send,
856 };
857 rpc.complete(Ok(ConnectResult {
858 version,
859 offers,
860 offer_recv,
861 }));
862 }
863 state => {
864 tracing::warn!(client_state = %state, "invalid client state for OffersDelivered");
865 self.state = state;
866 }
867 }
868 }
869
870 fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) -> TriedRelease {
871 let mut channel = self.channels.get_mut(request.channel_id);
872 let Some(gpadl_state) = channel.gpadls.get_mut(&request.gpadl_id) else {
873 panic!("GpadlCreated for unknown gpadl {:#x}", request.gpadl_id.0);
874 };
875
876 let rpc = match std::mem::replace(gpadl_state, GpadlState::Created) {
877 GpadlState::Offered(rpc) => rpc,
878 old_state => {
879 panic!(
880 "invalid state {old_state:?} for gpadl {:#x}:{:#x}",
881 request.channel_id.0, request.gpadl_id.0
882 );
883 }
884 };
885
886 let gpadl_created = request.status == protocol::STATUS_SUCCESS;
887 if gpadl_created {
888 rpc.complete(Ok(()));
889 } else {
890 channel.gpadls.remove(&request.gpadl_id).unwrap();
891 rpc.fail(anyhow::anyhow!(
892 "gpadl creation failed: {:#x}",
893 request.status
894 ));
895 };
896 channel.try_release(&mut self.inner.messages)
897 }
898
899 fn handle_open_result(&mut self, result: protocol::OpenResult) {
900 tracing::debug!(
901 channel_id = result.channel_id.0,
902 result = result.status,
903 "received open result"
904 );
905
906 let mut channel = self.channels.get_mut(result.channel_id);
907
908 let channel_opened = result.status == protocol::STATUS_SUCCESS as u32;
909 let old_state = std::mem::replace(&mut channel.state, ChannelState::Offered);
910 let ChannelState::Opening {
911 connection_id,
912 redirected_event_flag,
913 redirected_event,
914 rpc,
915 } = old_state
916 else {
917 tracing::warn!(
918 old_state = ?channel.state,
919 channel_opened,
920 "invalid state for open result"
921 );
922 channel.state = old_state;
923 return;
924 };
925
926 if !channel_opened {
927 if let Some(event_flag) = redirected_event_flag {
928 self.inner.synic.free_event_flag(event_flag);
929 }
930 rpc.fail(anyhow::anyhow!("open failed: {:#x}", result.status));
931 return;
932 }
933
934 channel.state = ChannelState::Opened {
935 connection_id,
936 redirected_event_flag,
937 redirected_event,
938 };
939
940 rpc.complete(Ok(OpenOutput {
941 redirected_event_flag,
942 guest_to_host_signal: self.inner.synic.guest_to_host_interrupt(connection_id),
943 }));
944 }
945
946 fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) -> TriedRelease {
947 let Some(channel_id) = self.inner.teardown_gpadls.remove(&request.gpadl_id) else {
948 panic!("gpadl {:#x} not in teardown list", request.gpadl_id.0);
949 };
950
951 tracing::debug!(
952 gpadl_id = request.gpadl_id.0,
953 channel_id = channel_id.0,
954 "Received GpadlTorndown"
955 );
956
957 let mut channel = self.channels.get_mut(channel_id);
958 let gpadl_state = channel
959 .gpadls
960 .remove(&request.gpadl_id)
961 .expect("gpadl validated above");
962
963 let GpadlState::TearingDown { rpcs } = gpadl_state else {
964 panic!("gpadl should be tearing down if in teardown list, state = {gpadl_state:?}");
965 };
966
967 for rpc in rpcs {
968 rpc.complete(());
969 }
970 channel.try_release(&mut self.inner.messages)
971 }
972
973 fn handle_unload_complete(&mut self) {
974 match std::mem::replace(&mut self.state, ClientState::Disconnected) {
975 ClientState::Disconnecting { version: _, rpc } => {
976 tracing::info!("VmBus client disconnected");
977 rpc.complete(());
978 }
979 state => {
980 tracing::warn!(client_state = %state, "invalid client state for UnloadComplete");
981 }
982 }
983 }
984
985 fn handle_modify_complete(&mut self, response: protocol::ModifyConnectionResponse) {
986 if let Some(request) = self.modify_request.take() {
987 request.complete(response.connection_state)
988 } else {
989 tracing::warn!("Unexpected modify complete request");
990 }
991 }
992
993 fn handle_modify_channel_response(
994 &mut self,
995 response: protocol::ModifyChannelResponse,
996 ) -> TriedRelease {
997 let mut channel = self.channels.get_mut(response.channel_id);
998 let Some(sender) = channel.modify_response_send.take() else {
999 panic!(
1000 "unexpected modify channel response for channel {:#x}",
1001 response.channel_id.0
1002 );
1003 };
1004
1005 sender.complete(response.status);
1006 channel.try_release(&mut self.inner.messages)
1007 }
1008
1009 fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) {
1010 if let Some(rpc) = self.hvsock_tracker.check_result(&response) {
1011 rpc.complete(None);
1012 }
1013 }
1014
1015 fn handle_synic_message(&mut self, data: &[u8]) -> bool {
1017 let msg = Message::parse(data, self.state.get_version()).unwrap();
1018 tracing::trace!(?msg, "received client message from synic");
1019
1020 match msg {
1021 Message::VersionResponse2(version_response, ..) => {
1022 self.handle_version_response(version_response);
1023 }
1024 Message::VersionResponse(version_response, ..) => {
1025 self.handle_version_response(version_response.into());
1026 }
1027 Message::OfferChannel(offer, ..) => {
1028 self.handle_offer(offer);
1029 }
1030 Message::AllOffersDelivered(..) => {
1031 self.handle_offers_delivered();
1032 }
1033 Message::UnloadComplete(..) => {
1034 self.handle_unload_complete();
1035 }
1036 Message::ModifyConnectionResponse(response, ..) => {
1037 self.handle_modify_complete(response);
1038 }
1039 Message::GpadlCreated(gpadl, ..) => {
1040 self.handle_gpadl_created(gpadl);
1041 }
1042 Message::OpenResult(result, ..) => {
1043 self.handle_open_result(result);
1044 }
1045 Message::GpadlTorndown(gpadl, ..) => {
1046 self.handle_gpadl_torndown(gpadl);
1047 }
1048 Message::RescindChannelOffer(rescind, ..) => {
1049 self.handle_rescind(rescind);
1050 }
1051 Message::ModifyChannelResponse(response, ..) => {
1052 self.handle_modify_channel_response(response);
1053 }
1054 Message::TlConnectResult(response, ..) => self.handle_tl_connect_result(response),
1055 Message::CloseReservedChannelResponse(..) => {
1057 todo!("Unsupported message {msg:?}")
1058 }
1059 Message::PauseResponse(..) => {
1060 return false;
1061 }
1062 Message::RequestOffers(..)
1064 | Message::OpenChannel2(..)
1065 | Message::OpenChannel(..)
1066 | Message::CloseChannel(..)
1067 | Message::GpadlHeader(..)
1068 | Message::GpadlBody(..)
1069 | Message::GpadlTeardown(..)
1070 | Message::RelIdReleased(..)
1071 | Message::InitiateContact(..)
1072 | Message::InitiateContact2(..)
1073 | Message::Unload(..)
1074 | Message::OpenReservedChannel(..)
1075 | Message::CloseReservedChannel(..)
1076 | Message::TlConnectRequest2(..)
1077 | Message::TlConnectRequest(..)
1078 | Message::ModifyChannel(..)
1079 | Message::ModifyConnection(..)
1080 | Message::Pause(..)
1081 | Message::Resume(..) => {
1082 unreachable!("Client received server message {msg:?}");
1083 }
1084 }
1085 true
1086 }
1087
1088 fn handle_open_channel(
1089 &mut self,
1090 channel_id: ChannelId,
1091 rpc: FailableRpc<OpenRequest, OpenOutput>,
1092 ) {
1093 let mut channel = self.channels.get_mut(channel_id);
1094 match &channel.state {
1095 ChannelState::Offered => {}
1096 ChannelState::Revoked => {
1097 rpc.fail(anyhow::anyhow!("channel revoked"));
1098 return;
1099 }
1100 state => {
1101 rpc.fail(anyhow::anyhow!("invalid channel state: {}", state));
1102 return;
1103 }
1104 }
1105
1106 tracing::info!(channel_id = channel_id.0, "opening channel on host");
1107 let (request, rpc) = rpc.split();
1108 let open_data = &request.open_data;
1109
1110 let supports_interrupt_redirection =
1111 if let ClientState::Connected { version, .. } = self.state {
1112 version.feature_flags.guest_specified_signal_parameters()
1113 || version.feature_flags.channel_interrupt_redirection()
1114 } else {
1115 false
1116 };
1117
1118 if !supports_interrupt_redirection && open_data.event_flag != channel_id.0 as u16 {
1119 rpc.fail(anyhow::anyhow!(
1120 "host does not support specifying the event flag"
1121 ));
1122 return;
1123 }
1124
1125 let open_channel = protocol::OpenChannel {
1126 channel_id,
1127 open_id: 0,
1128 ring_buffer_gpadl_id: open_data.ring_gpadl_id,
1129 target_vp: open_data.target_vp,
1130 downstream_ring_buffer_page_offset: open_data.ring_offset,
1131 user_data: open_data.user_data,
1132 };
1133
1134 let connection_id = if request.use_vtl2_connection_id {
1135 if !supports_interrupt_redirection {
1136 rpc.fail(anyhow::anyhow!(
1137 "host does not support specfiying the connection ID"
1138 ));
1139 return;
1140 }
1141 protocol::ConnectionId::new(channel_id.0, 2.try_into().unwrap(), 7).0
1142 } else {
1143 open_data.connection_id
1144 };
1145
1146 let mut flags = OpenChannelFlags::new();
1149 let event_flag = if let Some(event) = &request.incoming_event {
1150 if !supports_interrupt_redirection {
1151 rpc.fail(anyhow::anyhow!(
1152 "host does not support redirecting interrupts"
1153 ));
1154 return;
1155 }
1156
1157 flags.set_redirect_interrupt(true);
1158 match self.inner.synic.allocate_event_flag(event) {
1159 Ok(flag) => flag,
1160 Err(err) => {
1161 rpc.fail(err.context("failed to allocate event flag"));
1162 return;
1163 }
1164 }
1165 } else {
1166 open_data.event_flag
1167 };
1168
1169 if supports_interrupt_redirection {
1170 self.inner.messages.send(&protocol::OpenChannel2 {
1171 open_channel,
1172 connection_id,
1173 event_flag,
1174 flags,
1175 });
1176 } else {
1177 self.inner.messages.send(&open_channel);
1178 }
1179
1180 channel.state = ChannelState::Opening {
1181 connection_id,
1182 redirected_event_flag: (request.incoming_event.is_some()).then_some(event_flag),
1183 redirected_event: request.incoming_event,
1184 rpc,
1185 }
1186 }
1187
1188 fn handle_restore_channel(
1189 &mut self,
1190 channel_id: ChannelId,
1191 request: RestoreRequest,
1192 ) -> Result<OpenOutput> {
1193 let mut channel = self.channels.get_mut(channel_id);
1194 if !matches!(channel.state, ChannelState::Restored) {
1195 anyhow::bail!("invalid channel state: {}", channel.state);
1196 }
1197
1198 if request.incoming_event.is_some() != request.redirected_event_flag.is_some() {
1199 anyhow::bail!("incoming event and redirected event flag must both be set or unset");
1200 }
1201
1202 if let Some((flag, event)) = request
1203 .redirected_event_flag
1204 .zip(request.incoming_event.as_ref())
1205 {
1206 self.inner.synic.restore_event_flag(flag, event)?;
1207 }
1208
1209 channel.state = ChannelState::Opened {
1210 connection_id: request.connection_id,
1211 redirected_event_flag: request.redirected_event_flag,
1212 redirected_event: request.incoming_event,
1213 };
1214 Ok(OpenOutput {
1215 redirected_event_flag: request.redirected_event_flag,
1216 guest_to_host_signal: self
1217 .inner
1218 .synic
1219 .guest_to_host_interrupt(request.connection_id),
1220 })
1221 }
1222
1223 fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: FailableRpc<GpadlRequest, ()>) {
1224 let (request, rpc) = rpc.split();
1225 let mut channel = self.channels.get_mut(channel_id);
1226 if channel
1227 .gpadls
1228 .insert(request.id, GpadlState::Offered(rpc))
1229 .is_some()
1230 {
1231 panic!(
1232 "duplicate gpadl ID {:?} for channel {:?}.",
1233 request.id, channel_id
1234 );
1235 }
1236
1237 tracing::trace!(
1238 channel_id = channel_id.0,
1239 gpadl_id = request.id.0,
1240 count = request.count,
1241 len = request.buf.len(),
1242 "received gpadl request"
1243 );
1244
1245 let (first, remaining) = if request.buf.len() > protocol::GpadlHeader::MAX_DATA_VALUES {
1247 request.buf.split_at(protocol::GpadlHeader::MAX_DATA_VALUES)
1248 } else {
1249 (request.buf.as_slice(), [].as_slice())
1250 };
1251
1252 let message = protocol::GpadlHeader {
1253 channel_id,
1254 gpadl_id: request.id,
1255 len: (request.buf.len() * size_of::<u64>())
1256 .try_into()
1257 .expect("Too many GPA values"),
1258 count: request.count,
1259 };
1260
1261 self.inner
1262 .messages
1263 .send_with_data(&message, first.as_bytes());
1264
1265 let message = protocol::GpadlBody {
1267 rsvd: 0,
1268 gpadl_id: request.id,
1269 };
1270 for chunk in remaining.chunks(protocol::GpadlBody::MAX_DATA_VALUES) {
1271 self.inner
1272 .messages
1273 .send_with_data(&message, chunk.as_bytes());
1274 }
1275 }
1276
1277 fn handle_gpadl_teardown(&mut self, channel_id: ChannelId, rpc: Rpc<GpadlId, ()>) {
1278 let (gpadl_id, rpc) = rpc.split();
1279 let mut channel = self.channels.get_mut(channel_id);
1280 let Some(gpadl_state) = channel.gpadls.get_mut(&gpadl_id) else {
1281 tracing::warn!(
1282 gpadl_id = gpadl_id.0,
1283 channel_id = channel_id.0,
1284 "Gpadl teardown for unknown gpadl or revoked channel"
1285 );
1286 return;
1287 };
1288
1289 match gpadl_state {
1290 GpadlState::Offered(_) => {
1291 tracing::warn!(
1292 gpadl_id = gpadl_id.0,
1293 channel_id = channel_id.0,
1294 "gpadl teardown for offered gpadl"
1295 );
1296 }
1297 GpadlState::Created => {
1298 *gpadl_state = GpadlState::TearingDown { rpcs: vec![rpc] };
1299 assert!(
1303 self.inner
1304 .teardown_gpadls
1305 .insert(gpadl_id, channel_id)
1306 .is_none(),
1307 "Gpadl state validated above"
1308 );
1309
1310 self.inner.messages.send(&protocol::GpadlTeardown {
1311 channel_id,
1312 gpadl_id,
1313 });
1314 }
1315 GpadlState::TearingDown { rpcs } => {
1316 rpcs.push(rpc);
1317 }
1318 }
1319 }
1320
1321 fn handle_close_channel(&mut self, channel_id: ChannelId) {
1322 let mut channel = self.channels.get_mut(channel_id);
1323 self.inner.close_channel(channel_id, &mut channel);
1324 }
1325
1326 fn handle_modify_channel(&mut self, channel_id: ChannelId, rpc: Rpc<ModifyRequest, i32>) {
1327 assert!(self.check_version(Version::Iron));
1331 let mut channel = self.channels.get_mut(channel_id);
1332 if channel.modify_response_send.is_some() {
1333 panic!("duplicate channel modify request {channel_id:?}");
1334 }
1335
1336 let (request, response) = rpc.split();
1337 channel.modify_response_send = Some(response);
1338 let payload = match request {
1339 ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel {
1340 channel_id,
1341 target_vp,
1342 },
1343 };
1344
1345 self.inner.messages.send(&payload);
1346 }
1347
1348 fn handle_channel_request(&mut self, channel_id: ChannelId, request: ChannelRequest) {
1349 match request {
1350 ChannelRequest::Open(rpc) => self.handle_open_channel(channel_id, rpc),
1351 ChannelRequest::Restore(rpc) => {
1352 rpc.handle_failable_sync(|request| self.handle_restore_channel(channel_id, request))
1353 }
1354 ChannelRequest::Gpadl(req) => self.handle_gpadl(channel_id, req),
1355 ChannelRequest::TeardownGpadl(req) => self.handle_gpadl_teardown(channel_id, req),
1356 ChannelRequest::Close(req) => {
1357 req.handle_sync(|()| self.handle_close_channel(channel_id))
1358 }
1359 ChannelRequest::Modify(req) => self.handle_modify_channel(channel_id, req),
1360 }
1361 }
1362
1363 async fn handle_task(&mut self, task: TaskRequest) {
1364 match task {
1365 TaskRequest::Inspect(deferred) => {
1366 deferred.inspect(&*self);
1367 }
1368 TaskRequest::Save(rpc) => rpc.handle_sync(|()| self.handle_save()),
1369 TaskRequest::Restore(rpc) => {
1370 rpc.handle_sync(|saved_state| self.handle_restore(saved_state))
1371 }
1372 TaskRequest::PostRestore(rpc) => rpc.handle_sync(|()| self.handle_post_restore()),
1373 TaskRequest::Start => self.handle_start(),
1374 TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
1375 }
1376 }
1377
1378 fn handle_device_removal(&mut self, channel_id: ChannelId) -> TriedRelease {
1380 let mut channel = self.channels.get_mut(channel_id);
1381 channel.is_client_released = true;
1382 if let ChannelState::Opened { .. } = channel.state {
1384 tracing::warn!(
1385 channel_id = channel_id.0,
1386 "Channel dropped without closing first"
1387 );
1388 self.inner.close_channel(channel_id, &mut channel);
1389 }
1390 channel.try_release(&mut self.inner.messages)
1391 }
1392
1393 fn check_version(&self, version: Version) -> bool {
1395 matches!(self.state, ClientState::Connected { version: v, .. } if v.version >= version)
1396 }
1397
1398 fn handle_start(&mut self) {
1399 assert!(!self.running);
1400 self.msg_source.resume_message_stream();
1401 self.inner.messages.resume();
1402 self.running = true;
1403 }
1404
    /// Stops message processing, draining host traffic first.
    ///
    /// Repeatedly processes host messages until no revoked channel has a
    /// pending request, then pauses the message streams and drains until the
    /// message source reports a zero-length read. If a revoked channel has
    /// gained a pending request by then, the streams are resumed and the whole
    /// sequence repeats.
    ///
    /// # Panics
    ///
    /// Panics if the task is not running, or if the message source reaches its
    /// end while a revoked channel is still waiting on a response.
    async fn handle_stop(&mut self) {
        assert!(self.running);

        loop {
            // Keep processing host messages until no revoked channel is still
            // waiting on a response. NOTE(review): presumably this ensures the
            // saved state never has to describe a revoked channel with an
            // outstanding request — confirm.
            while let Some((id, request)) = self.channels.revoked_channel_with_pending_request() {
                tracelimit::info_ratelimited!(
                    channel_id = id.0,
                    request,
                    "waiting for responses for channel"
                );
                assert!(self.process_next_message().await);
            }

            if self.can_pause_resume() {
                // The host supports in-band pause: queue a Pause message
                // (sent on the next flush).
                self.inner.messages.pause();
            } else {
                // No in-band pause support: pause the incoming stream at the
                // source and stop posting without notifying the host.
                self.msg_source.pause_message_stream();
                self.inner.messages.force_pause();
            }

            // Process (and flush) until the message source returns a
            // zero-length read, which `process_next_message` reports as false.
            while self.process_next_message().await {}

            // Messages processed above may have left a revoked channel
            // waiting again; if so, resume and go around once more.
            if self
                .channels
                .revoked_channel_with_pending_request()
                .is_none()
            {
                break;
            }
            if !self.can_pause_resume() {
                self.msg_source.resume_message_stream();
            }
            self.inner.messages.resume();
        }

        tracing::debug!("messages drained");
        self.running = false;
    }
1454
    /// Receives and handles one host message while flushing queued outgoing
    /// messages concurrently.
    ///
    /// Returns `false` if the message source produced a zero-length read;
    /// otherwise returns the result of `handle_synic_message`.
    ///
    /// # Panics
    ///
    /// Panics if reading from the message source fails.
    async fn process_next_message(&mut self) -> bool {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        let recv = self.msg_source.recv(&mut buf);
        // Flush concurrently with the receive; once flushing finishes this
        // future parks forever, so the race below always resolves with the
        // receive result.
        let flush = async {
            self.inner.messages.flush_messages().await;
            std::future::pending().await
        };
        let size = (recv, flush)
            .race()
            .await
            .expect("Fatal error reading messages from synic");
        if size == 0 {
            return false;
        }
        self.handle_synic_message(&buf[..size])
    }
1473
1474 fn can_pause_resume(&self) -> bool {
1483 if let ClientState::Connected { version, .. } = self.state {
1484 version.feature_flags.pause_resume()
1485 } else {
1486 false
1487 }
1488 }
1489
    /// Main loop of the client task.
    ///
    /// Multiplexes control requests, client requests, per-channel requests,
    /// incoming host messages, and flushing of queued outgoing messages. The
    /// loop exits when the task request stream or client request stream ends,
    /// or when every branch is exhausted.
    async fn run(&mut self) {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        loop {
            // Host messages are only read while running. An `OptionFuture`
            // built from `None` is a terminated `FusedFuture`, so `select!`
            // skips disabled branches.
            let mut message_recv =
                OptionFuture::from(self.running.then(|| self.msg_source.recv(&mut buf).fuse()));

            // A non-empty outgoing queue means the host is not accepting
            // messages right now. Flush the queue, and stop taking new client
            // and channel requests (which would only add to it) until it
            // drains. Incoming host messages are still processed in this state.
            let host_backed_up = !self.inner.messages.is_empty();
            let flush_messages = OptionFuture::from(
                (self.running && host_backed_up)
                    .then(|| self.inner.messages.flush_messages().fuse()),
            );

            let mut client_request_recv = OptionFuture::from(
                (self.running && !host_backed_up).then(|| self.client_request_recv.next()),
            );

            let mut channel_requests = OptionFuture::from(
                (self.running && !host_backed_up)
                    .then(|| self.inner.channel_requests.select_next_some()),
            );

            futures::select! { _r = pin!(flush_messages) => {}
                r = self.task_recv.next() => {
                    if let Some(task) = r {
                        self.handle_task(task).await;
                    } else {
                        break;
                    }
                }
                r = client_request_recv => {
                    if let Some(Some(request)) = r {
                        self.handle_client_request(request);
                    } else {
                        break;
                    }
                }
                r = channel_requests => {
                    match r.unwrap() {
                        (id, Some(request)) => self.handle_channel_request(id, request),
                        // The channel's request stream ended: the client
                        // dropped the channel.
                        (id, _) => {
                            self.handle_device_removal(id);
                        }
                    }
                }
                r = message_recv => {
                    match r.unwrap() {
                        Ok(size) => {
                            // A zero-length read is only expected while
                            // stopping (see `handle_stop`), never here.
                            if size == 0 {
                                panic!("Unexpected end of file reading messages from synic.");
                            }

                            self.handle_synic_message(&buf[..size]);
                        }
                        Err(err) => {
                            panic!("Error reading messages from synic: {err:?}");
                        }
                    }
                }
                complete => break,
            }
        }
    }
1562}
1563
1564impl ClientTaskInner {
1565 fn close_channel(&mut self, channel_id: ChannelId, channel: &mut Channel) {
1566 if let ChannelState::Opened {
1567 redirected_event_flag,
1568 ..
1569 } = channel.state
1570 {
1571 if let Some(flag) = redirected_event_flag {
1572 self.synic.free_event_flag(flag);
1573 }
1574 tracing::info!(channel_id = channel_id.0, "closing channel on host");
1575 self.messages.send(&protocol::CloseChannel { channel_id });
1576 channel.state = ChannelState::Offered;
1577 } else {
1578 tracing::warn!(
1579 id = %channel_id.0,
1580 channel_state = %channel.state,
1581 "invalid channel state for close channel"
1582 );
1583 }
1584 }
1585}
1586
/// Client-side lifecycle state of a GPADL on a channel.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum GpadlState {
    /// The GPADL header/body messages have been sent to the host. The
    /// caller's RPC is held here, presumably until the host's creation
    /// result arrives (that transition is handled outside this view).
    Offered(#[inspect(skip)] FailableRpc<(), ()>),
    /// The GPADL exists on the host; teardown may be requested.
    Created,
    /// A `GpadlTeardown` message has been sent; every RPC waiting for the
    /// teardown to finish is collected here.
    TearingDown {
        #[inspect(skip)]
        rpcs: Vec<Rpc<(), ()>>,
    },
}
1600
/// Messages destined for the host.
///
/// Messages are posted directly when possible; otherwise they are buffered
/// in `queued` and posted later by `flush_messages`.
#[derive(Inspect)]
struct OutgoingMessages {
    /// Posts raw messages to the host.
    #[inspect(skip)]
    poster: Box<dyn PollPostMessage>,
    /// Messages waiting to be posted, in send order (inspected as a count).
    #[inspect(with = "|x| x.len()")]
    queued: VecDeque<OutgoingMessage>,
    /// Whether posting is currently allowed.
    state: OutgoingMessageState,
}
1609
/// Posting state of `OutgoingMessages`.
#[derive(Inspect, PartialEq, Eq, Debug)]
enum OutgoingMessageState {
    /// Messages are posted immediately or flushed from the queue.
    Running,
    /// The next flush posts a `Pause` message and then moves to `Paused`.
    SendingPauseMessage,
    /// Nothing is posted; new messages accumulate in the queue.
    Paused,
}
1616
impl OutgoingMessages {
    /// Sends `msg` with no trailing data. See `send_with_data`.
    fn send<T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout>(
        &mut self,
        msg: &T,
    ) {
        self.send_with_data(msg, &[])
    }

    /// Sends `msg` followed by `data`.
    ///
    /// If nothing is queued and posting is allowed, the message is posted
    /// immediately. A no-op waker is used, so a `Pending` result does not
    /// register a wakeup; the message is queued instead. In every other case
    /// the message is appended to the queue, which preserves ordering.
    fn send_with_data<
        T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
    >(
        &mut self,
        msg: &T,
        data: &[u8],
    ) {
        tracing::trace!(typ = ?T::MESSAGE_TYPE, "Sending message to host");
        let msg = OutgoingMessage::with_data(msg, data);
        if self.queued.is_empty() && self.state == OutgoingMessageState::Running {
            let r = self.poster.poll_post_message(
                &mut Context::from_waker(std::task::Waker::noop()),
                protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                1,
                msg.data(),
            );
            if let Poll::Ready(()) = r {
                return;
            }
        }
        tracing::trace!("queueing message");
        self.queued.push_back(msg);
    }

    /// Posts whatever the current state allows.
    ///
    /// - `Running`: posts queued messages in order until the queue is empty.
    /// - `SendingPauseMessage`: posts a `Pause` message, then moves to `Paused`.
    /// - `Paused`: does nothing.
    async fn flush_messages(&mut self) {
        let mut send = async |msg: &OutgoingMessage| {
            poll_fn(|cx| {
                self.poster.poll_post_message(
                    cx,
                    protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                    1,
                    msg.data(),
                )
            })
            .await
        };
        match self.state {
            OutgoingMessageState::Running => {
                while let Some(msg) = self.queued.front() {
                    send(msg).await;
                    tracing::trace!("sent queued message");
                    // Pop only after the post completed, so if this future is
                    // dropped mid-send the message stays at the queue front.
                    self.queued.pop_front();
                }
            }
            OutgoingMessageState::SendingPauseMessage => {
                send(&OutgoingMessage::new(&protocol::Pause)).await;
                tracing::trace!("sent pause message");
                self.state = OutgoingMessageState::Paused;
            }
            OutgoingMessageState::Paused => {}
        }
    }

    /// Requests an in-band pause: the next flush posts a `Pause` message.
    ///
    /// # Panics
    ///
    /// Panics unless currently `Running`.
    fn pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::SendingPauseMessage;
        self.queued
            // Put `Resume` at the front so it is the first message posted
            // after `resume`, ahead of anything queued while paused.
            .push_front(OutgoingMessage::new(&protocol::Resume));
    }

    /// Stops posting immediately without sending a `Pause` message.
    ///
    /// Used by `handle_stop` when the host does not support in-band
    /// pause/resume.
    ///
    /// # Panics
    ///
    /// Panics unless currently `Running`.
    fn force_pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::Paused;
    }

    /// Re-enables posting. Queued messages go out on the next flush.
    ///
    /// # Panics
    ///
    /// Panics unless currently `Paused`.
    fn resume(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Paused);
        self.state = OutgoingMessageState::Running;
    }

    /// Returns true if no messages are waiting to be posted.
    fn is_empty(&self) -> bool {
        self.queued.is_empty()
    }
}
1705
/// Client task state kept apart from the channel list.
///
/// Handlers hold a `ChannelRef` (which borrows the channel list) while
/// mutating these fields.
#[derive(Inspect)]
struct ClientTaskInner {
    /// Messages destined for the host.
    messages: OutgoingMessages,
    /// GPADLs with a teardown in flight, mapped to their owning channel.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")]
    teardown_gpadls: HashMap<GpadlId, ChannelId>,
    /// Request streams from all channels, merged and tagged by channel ID.
    /// The `run` loop treats a `None` item as the client dropping the channel.
    #[inspect(skip)]
    channel_requests: SelectAll<TaggedStream<ChannelId, mesh::Receiver<ChannelRequest>>>,
    /// Event flag bookkeeping and the synic event client.
    synic: SynicState,
}
1715
/// Synic event flag allocation and signaling.
#[derive(Inspect)]
struct SynicState {
    /// Maps/unmaps events to event flags and signals the host.
    #[inspect(skip)]
    event_client: Arc<dyn SynicEventClient>,
    /// `event_flag_state[i]` is true when event flag `i + 1` is allocated
    /// (flag 0 is never handed out).
    #[inspect(iter_by_index)]
    event_flag_state: Vec<bool>,
}
1723
/// All channels currently known to the client, keyed by channel ID.
#[derive(Inspect, Default)]
#[inspect(transparent)]
struct ChannelList(
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] HashMap<ChannelId, Channel>,
);
1729
/// Mutable handle to an existing entry in `ChannelList`.
///
/// Derefs to the `Channel`; `try_release` may consume it to remove the entry.
struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>);
1733
/// Token returned by `ChannelRef::try_release`.
///
/// The `()` field is private, so the token cannot be built outside this
/// module. NOTE(review): presumably used so that functions returning it
/// (e.g. `handle_device_removal`) show at the type level that a release
/// attempt was made — confirm intent.
struct TriedRelease(());
1738
1739impl ChannelRef<'_> {
1740 fn try_release(self, messages: &mut OutgoingMessages) -> TriedRelease {
1744 if self.is_client_released
1745 && matches!(self.state, ChannelState::Revoked)
1746 && self.pending_request().is_none()
1747 {
1748 let channel_id = *self.0.key();
1749 tracelimit::info_ratelimited!(channel_id = channel_id.0, "releasing channel");
1750 messages.send(&protocol::RelIdReleased { channel_id });
1751 self.0.remove();
1752 }
1753 TriedRelease(())
1754 }
1755}
1756
/// Gives shared access to the referenced `Channel`.
impl Deref for ChannelRef<'_> {
    type Target = Channel;

    fn deref(&self) -> &Self::Target {
        self.0.get()
    }
}
1763}
1764
/// Gives mutable access to the referenced `Channel`.
impl DerefMut for ChannelRef<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.get_mut()
    }
}
1770
1771impl ChannelList {
1772 fn revoked_channel_with_pending_request(&self) -> Option<(ChannelId, &'static str)> {
1773 self.0.iter().find_map(|(&id, channel)| {
1774 if !matches!(channel.state, ChannelState::Revoked) {
1775 return None;
1776 }
1777 Some((id, channel.pending_request()?))
1778 })
1779 }
1780
1781 #[track_caller]
1782 fn get_mut(&mut self, channel_id: ChannelId) -> ChannelRef<'_> {
1783 match self.0.entry(channel_id) {
1784 hash_map::Entry::Occupied(entry) => ChannelRef(entry),
1785 hash_map::Entry::Vacant(_) => {
1786 panic!("channel {:?} not found", channel_id);
1787 }
1788 }
1789 }
1790}
1791
1792impl SynicState {
1793 fn guest_to_host_interrupt(&self, connection_id: u32) -> Interrupt {
1794 Interrupt::from_fn({
1795 let event_client = self.event_client.clone();
1796 move || {
1797 if let Err(err) = event_client.signal_event(connection_id, 0) {
1798 tracelimit::warn_ratelimited!(
1799 error = &err as &dyn std::error::Error,
1800 "failed to signal event"
1801 );
1802 }
1803 }
1804 })
1805 }
1806
1807 const MAX_EVENT_FLAGS: u16 = 2047;
1808
1809 fn allocate_event_flag(&mut self, event: &Event) -> Result<u16> {
1810 let i = self
1811 .event_flag_state
1812 .iter()
1813 .position(|&used| !used)
1814 .ok_or(())
1815 .or_else(|()| {
1816 if self.event_flag_state.len() >= Self::MAX_EVENT_FLAGS as usize {
1817 anyhow::bail!("out of event flags");
1818 }
1819 self.event_flag_state.push(false);
1820 Ok(self.event_flag_state.len() - 1)
1821 })?;
1822
1823 let event_flag = (i + 1) as u16;
1824 self.event_client
1825 .map_event(event_flag, event)
1826 .context("failed to map event")?;
1827 self.event_flag_state[i] = true;
1828 Ok(event_flag)
1829 }
1830
1831 fn restore_event_flag(&mut self, flag: u16, event: &Event) -> Result<()> {
1832 let i = (flag as usize)
1833 .checked_sub(1)
1834 .context("invalid event flag")?;
1835 if i >= Self::MAX_EVENT_FLAGS as usize {
1836 anyhow::bail!("invalid event flag");
1837 }
1838 if self.event_flag_state.len() <= i {
1839 self.event_flag_state.resize(i + 1, false);
1840 }
1841 if self.event_flag_state[i] {
1842 anyhow::bail!("event flag already in use");
1843 }
1844 self.event_client
1845 .map_event(flag, event)
1846 .context("failed to map event")?;
1847 self.event_flag_state[i] = true;
1848 Ok(())
1849 }
1850
1851 fn free_event_flag(&mut self, flag: u16) {
1852 let i = flag as usize - 1;
1853 assert!(i < self.event_flag_state.len());
1854 self.event_flag_state[i] = false;
1855 }
1856}
1857
1858#[cfg(test)]
1859mod tests {
1860 use super::*;
1861 use futures_concurrency::future::Join;
1862 use guid::Guid;
1863 use pal_async::DefaultDriver;
1864 use pal_async::async_test;
1865 use pal_async::timer::PolledTimer;
1866 use protocol::TargetInfo;
1867 use std::fmt::Debug;
1868 use std::task::ready;
1869 use std::time::Duration;
1870 use test_with_tracing::test;
1871 use vmbus_core::protocol::MessageHeader;
1872 use vmbus_core::protocol::MessageType;
1873 use vmbus_core::protocol::OfferFlags;
1874 use vmbus_core::protocol::UserDefinedData;
1875 use vmbus_core::protocol::VmbusMessage;
1876 use zerocopy::FromBytes;
1877 use zerocopy::FromZeros;
1878 use zerocopy::Immutable;
1879 use zerocopy::IntoBytes;
1880 use zerocopy::KnownLayout;
1881
1882 const VMBUS_TEST_CLIENT_ID: Guid = guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6");
1883
1884 fn in_msg<T: IntoBytes + Immutable + KnownLayout>(message_type: MessageType, t: T) -> Vec<u8> {
1885 let mut data = Vec::new();
1886 data.extend_from_slice(&message_type.0.to_ne_bytes());
1887 data.extend_from_slice(&0u32.to_ne_bytes());
1888 data.extend_from_slice(t.as_bytes());
1889 data
1890 }
1891
1892 #[track_caller]
1893 fn check_message<T>(msg: OutgoingMessage, chk: T)
1894 where
1895 T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1896 {
1897 check_message_with_data(msg, chk, &[]);
1898 }
1899
1900 #[track_caller]
1901 fn check_message_with_data<T>(msg: OutgoingMessage, chk: T, data: &[u8])
1902 where
1903 T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1904 {
1905 let chk_data = OutgoingMessage::with_data(&chk, data);
1906 if msg.data() != chk_data.data() {
1907 let (header, rest) = MessageHeader::read_from_prefix(msg.data()).unwrap();
1908 assert_eq!(header.message_type(), <T as VmbusMessage>::MESSAGE_TYPE);
1909 let (msg, rest) = T::read_from_prefix(rest).expect("incorrect message size");
1910 if msg.as_bytes() != chk.as_bytes() {
1911 panic!("mismatched messages, expected {:#?}, got {:#?}", chk, msg);
1912 }
1913 if rest != data {
1914 panic!("mismatched data, expected {:#?}, got {:#?}", data, rest);
1915 }
1916 }
1917 }
1918
    /// Test-side view of the host.
    struct TestServer {
        /// Messages the client posted to the host.
        messages: mesh::Receiver<OutgoingMessage>,
        /// Raw host-to-client messages delivered to the client's message source.
        send: mesh::Sender<Vec<u8>>,
    }
1923
    impl TestServer {
        /// Waits for the next message the client posted; `None` if the sender
        /// side is gone.
        async fn next(&mut self) -> Option<OutgoingMessage> {
            self.messages.next().await
        }

        /// Delivers a raw host-to-client message (typically built by `in_msg`).
        fn send(&self, msg: Vec<u8>) {
            self.send.send(msg);
        }

        /// Connects `client` with no channel offers.
        async fn connect(&mut self, client: &mut VmbusClient) -> ConnectResult {
            self.connect_with_channels(client, |_| {}).await
        }

        /// Runs the connect handshake as the host.
        ///
        /// Sequence: accept the initial contact message, reply with a
        /// successful `VersionResponse2` advertising all supported features,
        /// expect `RequestOffers`, run `send_offers`, then send
        /// `ALL_OFFERS_DELIVERED`. Asserts that the client negotiated Copper
        /// with every supported feature flag.
        async fn connect_with_channels(
            &mut self,
            client: &mut VmbusClient,
            send_offers: impl FnOnce(&mut Self),
        ) -> ConnectResult {
            let client_connect = client.connect(0, None, Guid::ZERO);

            let server_connect = async {
                // The contents of the initial contact message are not checked here.
                let _ = self.next().await.unwrap();

                self.send(in_msg(
                    MessageType::VERSION_RESPONSE,
                    protocol::VersionResponse2 {
                        version_response: protocol::VersionResponse {
                            version_supported: 1,
                            connection_state: ConnectionState::SUCCESSFUL,
                            padding: 0,
                            selected_version_or_connection_id: 0,
                        },
                        supported_features: SUPPORTED_FEATURE_FLAGS.into(),
                    },
                ));

                check_message(self.next().await.unwrap(), protocol::RequestOffers {});

                send_offers(self);
                self.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
            };

            // Drive the client and the scripted host side concurrently.
            let (connection, ()) = (client_connect, server_connect).join().await;

            let connection = connection.unwrap();
            assert_eq!(connection.version.version, Version::Copper);
            assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
            connection
        }

        /// Connects with exactly one offered channel and returns its offer.
        async fn get_channel(&mut self, client: &mut VmbusClient) -> OfferInfo {
            let [channel] = self
                .get_channels(client, 1)
                .await
                .offers
                .try_into()
                .unwrap();
            channel
        }

        /// Connects while offering `count` channels with IDs `0..count` and
        /// random interface/instance GUIDs.
        async fn get_channels(&mut self, client: &mut VmbusClient, count: usize) -> ConnectResult {
            self.connect_with_channels(client, |this| {
                for i in 0..count {
                    let offer = protocol::OfferChannel {
                        interface_id: Guid::new_random(),
                        instance_id: Guid::new_random(),
                        rsvd: [0; 4],
                        flags: OfferFlags::new(),
                        mmio_megabytes: 0,
                        user_defined: UserDefinedData::new_zeroed(),
                        subchannel_index: 0,
                        mmio_megabytes_optional: 0,
                        channel_id: ChannelId(i as u32),
                        monitor_id: 0,
                        monitor_allocated: 0,
                        is_dedicated: 0,
                        connection_id: 0,
                    };

                    this.send(in_msg(MessageType::OFFER_CHANNEL, offer));
                }
            })
            .await
        }

        /// Stops the client, acting as a host that supports in-band pause:
        /// expects `Pause` and answers with `PAUSE_RESPONSE`.
        async fn stop_client(&mut self, client: &mut VmbusClient) {
            let client_stop = client.stop();
            let server_stop = async {
                check_message(self.next().await.unwrap(), protocol::Pause);
                self.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
            };
            (client_stop, server_stop).join().await;
        }

        /// Starts the client and expects its first posted message to be `Resume`.
        async fn start_client(&mut self, client: &mut VmbusClient) {
            client.start();
            check_message(self.next().await.unwrap(), protocol::Resume);
        }
    }
2023
    /// `PollPostMessage` implementation that forwards posted messages to the
    /// test server, sometimes after an artificial delay.
    struct TestServerClient {
        /// Delivers posted messages to `TestServer::messages`.
        sender: mesh::Sender<OutgoingMessage>,
        /// Timer used to implement the artificial delay.
        timer: PolledTimer,
        /// When set, posting is blocked until this instant.
        deadline: Option<pal_async::timer::Instant>,
    }
2029
    impl PollPostMessage for TestServerClient {
        /// Posts `msg` to the test server.
        ///
        /// With probability 1/4 per attempt, a 10ms delay is inserted first.
        /// The resulting `Pending` returns exercise the client's queue-and-flush
        /// path for outgoing messages.
        fn poll_post_message(
            &mut self,
            cx: &mut Context<'_>,
            _connection_id: u32,
            _typ: u32,
            msg: &[u8],
        ) -> Poll<()> {
            loop {
                if let Some(deadline) = self.deadline {
                    ready!(self.timer.poll_until(cx, deadline));
                    self.deadline = None;
                }
                let mut b = [0];
                // Random byte decides whether to delay this attempt.
                getrandom::fill(&mut b).unwrap();
                if b[0] % 4 == 0 {
                    self.deadline =
                        Some(pal_async::timer::Instant::now() + Duration::from_millis(10));
                } else {
                    let msg = OutgoingMessage::from_message(msg).unwrap();
                    tracing::info!(
                        msg = ?MessageHeader::read_from_prefix(msg.data()),
                        "sending message"
                    );
                    self.sender.send(msg);
                    break Poll::Ready(());
                }
            }
        }
    }
2064
    /// Synic event client for tests: mapping always succeeds and signaling
    /// always fails.
    struct NoopSynicEvents;
2066
    impl SynicEventClient for NoopSynicEvents {
        /// Accepts every mapping without doing anything.
        fn map_event(&self, _event_flag: u16, _event: &Event) -> std::io::Result<()> {
            Ok(())
        }

        /// Does nothing.
        fn unmap_event(&self, _event_flag: u16) {}

        /// Always fails with `ErrorKind::Unsupported`.
        fn signal_event(&self, _connection_id: u32, _event_flag: u16) -> std::io::Result<()> {
            Err(std::io::ErrorKind::Unsupported.into())
        }
    }
2078
    /// Message source fed by `TestServer::send`.
    struct TestMessageSource {
        /// Raw host-to-client messages.
        msg_recv: mesh::Receiver<Vec<u8>>,
        /// When true, an empty queue yields a zero-length read instead of
        /// `Pending`.
        paused: bool,
    }
2083
2084 impl AsyncRecv for TestMessageSource {
2085 fn poll_recv(
2086 &mut self,
2087 cx: &mut Context<'_>,
2088 mut bufs: &mut [std::io::IoSliceMut<'_>],
2089 ) -> Poll<std::io::Result<usize>> {
2090 let value = match self.msg_recv.poll_recv(cx) {
2091 Poll::Ready(v) => v.unwrap(),
2092 Poll::Pending => {
2093 if self.paused {
2094 return Poll::Ready(Ok(0));
2095 } else {
2096 return Poll::Pending;
2097 }
2098 }
2099 };
2100 let mut remaining = value.as_slice();
2101 let mut total_size = 0;
2102 while !remaining.is_empty() && !bufs.is_empty() {
2103 let size = bufs[0].len().min(remaining.len());
2104 bufs[0][..size].copy_from_slice(&remaining[..size]);
2105 remaining = &remaining[size..];
2106 bufs = &mut bufs[1..];
2107 total_size += size;
2108 }
2109
2110 Ok(total_size).into()
2111 }
2112 }
2113
    impl VmbusMessageSource for TestMessageSource {
        /// Makes an empty queue report a zero-length read (see `poll_recv`).
        fn pause_message_stream(&mut self) {
            self.paused = true;
        }

        /// Restores normal `Pending` behavior for an empty queue.
        fn resume_message_stream(&mut self) {
            self.paused = false;
        }
    }
2123
2124 fn test_init(driver: &DefaultDriver) -> (TestServer, VmbusClient) {
2125 let (msg_send, msg_recv) = mesh::channel();
2126 let (synic_send, synic_recv) = mesh::channel();
2127 let server = TestServer {
2128 messages: synic_recv,
2129 send: msg_send,
2130 };
2131 let mut client = VmbusClientBuilder::new(
2132 NoopSynicEvents,
2133 TestMessageSource {
2134 msg_recv,
2135 paused: false,
2136 },
2137 TestServerClient {
2138 sender: synic_send,
2139 deadline: None,
2140 timer: PolledTimer::new(driver),
2141 },
2142 )
2143 .build(driver);
2144 client.start();
2145 (server, client)
2146 }
2147
    // A connect request must first post InitiateContact2 asking for Copper with
    // every supported feature flag, on SINT 2 / VTL 0.
    #[async_test]
    async fn test_initiate_contact_success(driver: DefaultDriver) {
        let (mut server, client) = test_init(&driver);
        // Keep the RPC receiver alive; the connect itself is never completed here.
        let _recv = client
            .access
            .client_request_send
            .call(ClientRequest::Connect, ConnectRequest::default());
        check_message(
            server.next().await.unwrap(),
            protocol::InitiateContact2 {
                initiate_contact: protocol::InitiateContact {
                    version_requested: Version::Copper as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                        .into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
                ..FromZeros::new_zeroed()
            },
        );
    }
2173
    // Full handshake where the host accepts Copper with all supported
    // features: the client must request offers and report the negotiated
    // version and feature flags.
    #[async_test]
    async fn test_connect_success(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: Version::Copper as u32,
                        target_message_vp: 0,
                        interrupt_page_or_target_info: TargetInfo::new()
                            .with_sint(2)
                            .with_vtl(0)
                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                            .into(),
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    ..FromZeros::new_zeroed()
                },
            );

            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse2 {
                    version_response: protocol::VersionResponse {
                        version_supported: 1,
                        connection_state: ConnectionState::SUCCESSFUL,
                        padding: 0,
                        selected_version_or_connection_id: 0,
                    },
                    supported_features: SUPPORTED_FEATURE_FLAGS.into_bits(),
                },
            ));

            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        let (connection, ()) = (client_connect, server_connect).join().await;
        let connection = connection.unwrap();

        assert_eq!(connection.version.version, Version::Copper);
        assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
    }
2221
    // The negotiated feature flags must be exactly those the host reports, not
    // everything the client requested.
    #[async_test]
    async fn test_feature_flags(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: Version::Copper as u32,
                        target_message_vp: 0,
                        interrupt_page_or_target_info: TargetInfo::new()
                            .with_sint(2)
                            .with_vtl(0)
                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                            .into(),
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    ..FromZeros::new_zeroed()
                },
            );

            server.send(in_msg(
                // Report only the value 2, which the final assertion expects
                // to decode as `channel_interrupt_redirection` alone.
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse2 {
                    version_response: protocol::VersionResponse {
                        version_supported: 1,
                        connection_state: ConnectionState::SUCCESSFUL,
                        padding: 0,
                        selected_version_or_connection_id: 0,
                    },
                    supported_features: 2,
                },
            ));

            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        let (connection, ()) = (client_connect, server_connect).join().await;
        let connection = connection.unwrap();

        assert_eq!(connection.version.version, Version::Copper);
        assert_eq!(
            connection.version.feature_flags,
            FeatureFlags::new().with_channel_interrupt_redirection(true)
        );
    }
2274
    // A client ID in the connect request must be carried in InitiateContact2.
    #[async_test]
    async fn test_client_id(driver: DefaultDriver) {
        let (mut server, client) = test_init(&driver);
        let initiate_contact = ConnectRequest {
            client_id: VMBUS_TEST_CLIENT_ID,
            ..Default::default()
        };
        // Keep the RPC receiver alive; only the first posted message is checked.
        let _recv = client
            .access
            .client_request_send
            .call(ClientRequest::Connect, initiate_contact);

        check_message(
            server.next().await.unwrap(),
            protocol::InitiateContact2 {
                initiate_contact: protocol::InitiateContact {
                    version_requested: Version::Copper as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                        .into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
                client_id: VMBUS_TEST_CLIENT_ID,
            },
        );
    }
2305
    // When the host rejects Copper, the client must fall back to Iron using
    // the plain InitiateContact message with no feature flags, and then report
    // Iron with empty feature flags.
    #[async_test]
    async fn test_version_negotiation(driver: DefaultDriver) {
        let (mut server, mut client) = test_init(&driver);
        let client_connect = client.connect(0, None, Guid::ZERO);

        let server_connect = async {
            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: Version::Copper as u32,
                        target_message_vp: 0,
                        interrupt_page_or_target_info: TargetInfo::new()
                            .with_sint(2)
                            .with_vtl(0)
                            .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
                            .into(),
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    ..FromZeros::new_zeroed()
                },
            );

            // Reject the requested version.
            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse {
                    version_supported: 0,
                    connection_state: ConnectionState::SUCCESSFUL,
                    padding: 0,
                    selected_version_or_connection_id: 0,
                },
            ));

            check_message(
                server.next().await.unwrap(),
                protocol::InitiateContact {
                    version_requested: Version::Iron as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: TargetInfo::new()
                        .with_sint(2)
                        .with_vtl(0)
                        .with_feature_flags(FeatureFlags::new().into())
                        .into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
            );

            // Accept the fallback version.
            server.send(in_msg(
                MessageType::VERSION_RESPONSE,
                protocol::VersionResponse {
                    version_supported: 1,
                    connection_state: ConnectionState::SUCCESSFUL,
                    padding: 0,
                    selected_version_or_connection_id: 0,
                },
            ));

            check_message(server.next().await.unwrap(), protocol::RequestOffers {});
            server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
        };

        let (connection, ()) = (client_connect, server_connect).join().await;
        let connection = connection.unwrap();

        assert_eq!(connection.version.version, Version::Iron);
        assert_eq!(connection.version.feature_flags, FeatureFlags::new());
    }
2375
2376 #[async_test]
2377 async fn test_open_channel_success(driver: DefaultDriver) {
2378 let (mut server, mut client) = test_init(&driver);
2379 let channel = server.get_channel(&mut client).await;
2380
2381 let recv = channel.request_send.call(
2382 ChannelRequest::Open,
2383 OpenRequest {
2384 open_data: OpenData {
2385 target_vp: 0,
2386 ring_offset: 0,
2387 ring_gpadl_id: GpadlId(0),
2388 event_flag: 0,
2389 connection_id: 0,
2390 user_data: UserDefinedData::new_zeroed(),
2391 },
2392 incoming_event: None,
2393 use_vtl2_connection_id: false,
2394 },
2395 );
2396
2397 check_message(
2398 server.next().await.unwrap(),
2399 protocol::OpenChannel2 {
2400 open_channel: protocol::OpenChannel {
2401 channel_id: ChannelId(0),
2402 open_id: 0,
2403 ring_buffer_gpadl_id: GpadlId(0),
2404 target_vp: 0,
2405 downstream_ring_buffer_page_offset: 0,
2406 user_data: UserDefinedData::new_zeroed(),
2407 },
2408 connection_id: 0,
2409 event_flag: 0,
2410 flags: Default::default(),
2411 },
2412 );
2413
2414 server.send(in_msg(
2415 MessageType::OPEN_CHANNEL_RESULT,
2416 protocol::OpenResult {
2417 channel_id: ChannelId(0),
2418 open_id: 0,
2419 status: protocol::STATUS_SUCCESS as u32,
2420 },
2421 ));
2422
2423 recv.await.unwrap().unwrap();
2424 }
2425
2426 #[async_test]
2427 async fn test_open_channel_fail(driver: DefaultDriver) {
2428 let (mut server, mut client) = test_init(&driver);
2429 let channel = server.get_channel(&mut client).await;
2430
2431 let recv = channel.request_send.call(
2432 ChannelRequest::Open,
2433 OpenRequest {
2434 open_data: OpenData {
2435 target_vp: 0,
2436 ring_offset: 0,
2437 ring_gpadl_id: GpadlId(0),
2438 event_flag: 0,
2439 connection_id: 0,
2440 user_data: UserDefinedData::new_zeroed(),
2441 },
2442 incoming_event: None,
2443 use_vtl2_connection_id: false,
2444 },
2445 );
2446
2447 check_message(
2448 server.next().await.unwrap(),
2449 protocol::OpenChannel2 {
2450 open_channel: protocol::OpenChannel {
2451 channel_id: ChannelId(0),
2452 open_id: 0,
2453 ring_buffer_gpadl_id: GpadlId(0),
2454 target_vp: 0,
2455 downstream_ring_buffer_page_offset: 0,
2456 user_data: UserDefinedData::new_zeroed(),
2457 },
2458 connection_id: 0,
2459 event_flag: 0,
2460 flags: Default::default(),
2461 },
2462 );
2463
2464 server.send(in_msg(
2465 MessageType::OPEN_CHANNEL_RESULT,
2466 protocol::OpenResult {
2467 channel_id: ChannelId(0),
2468 open_id: 0,
2469 status: protocol::STATUS_UNSUCCESSFUL as u32,
2470 },
2471 ));
2472
2473 recv.await.unwrap().unwrap_err();
2474 }
2475
2476 #[async_test]
2477 async fn test_modify_channel(driver: DefaultDriver) {
2478 let (mut server, mut client) = test_init(&driver);
2479 let channel = server.get_channel(&mut client).await;
2480
2481 let recv = channel.request_send.call(
2484 ChannelRequest::Modify,
2485 ModifyRequest::TargetVp { target_vp: 1 },
2486 );
2487
2488 check_message(
2489 server.next().await.unwrap(),
2490 protocol::ModifyChannel {
2491 channel_id: ChannelId(0),
2492 target_vp: 1,
2493 },
2494 );
2495
2496 server.send(in_msg(
2497 MessageType::MODIFY_CHANNEL_RESPONSE,
2498 protocol::ModifyChannelResponse {
2499 channel_id: ChannelId(0),
2500 status: protocol::STATUS_SUCCESS,
2501 },
2502 ));
2503
2504 let status = recv.await.unwrap();
2505 assert_eq!(status, protocol::STATUS_SUCCESS);
2506 }
2507
2508 #[async_test]
2509 async fn test_save_restore_connected(driver: DefaultDriver) {
2510 let (mut server, mut client) = test_init(&driver);
2511 server.connect(&mut client).await;
2512 server.stop_client(&mut client).await;
2513 let s0 = client.save().await;
2514 let builder = client.sever().await;
2515 let mut client = builder.build(&driver);
2516 client.restore(s0.clone()).await.unwrap();
2517
2518 let s1 = client.save().await;
2519
2520 assert_eq!(s0, s1);
2521 }
2522
2523 #[async_test]
2524 async fn test_save_restore_connected_with_channel(driver: DefaultDriver) {
2525 let (mut server, mut client) = test_init(&driver);
2526 let c0 = server.get_channel(&mut client).await;
2527 server.stop_client(&mut client).await;
2528 let s0 = client.save().await;
2529 let builder = client.sever().await;
2530 let mut client = builder.build(&driver);
2531 let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2532 let s1 = client.save().await;
2533 assert_eq!(s0, s1);
2534 assert_eq!(connection.offers[0].offer, c0.offer);
2535 }
2536
2537 #[async_test]
2538 async fn test_save_restore_connected_with_revoked_channel(driver: DefaultDriver) {
2539 let (mut server, mut client) = test_init(&driver);
2540 let c0 = server.get_channel(&mut client).await;
2541 server.send(in_msg(
2542 MessageType::RESCIND_CHANNEL_OFFER,
2543 protocol::RescindChannelOffer {
2544 channel_id: ChannelId(0),
2545 },
2546 ));
2547 c0.revoke_recv.await.unwrap();
2548 let rpc = c0.request_send.call(
2549 ChannelRequest::Modify,
2550 ModifyRequest::TargetVp { target_vp: 1 },
2551 );
2552
2553 check_message(
2554 server.next().await.unwrap(),
2555 protocol::ModifyChannel {
2556 channel_id: ChannelId(0),
2557 target_vp: 1,
2558 },
2559 );
2560
2561 let client_stop = client.stop();
2562 let server_stop = async {
2563 server.send(in_msg(
2564 MessageType::MODIFY_CHANNEL_RESPONSE,
2565 protocol::ModifyChannelResponse {
2566 channel_id: ChannelId(0),
2567 status: protocol::STATUS_SUCCESS,
2568 },
2569 ));
2570 check_message(server.next().await.unwrap(), protocol::Pause);
2571 server.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
2572 };
2573 (client_stop, server_stop).join().await;
2574
2575 rpc.await.unwrap();
2576
2577 let s0 = client.save().await;
2578 let builder = client.sever().await;
2579 let mut client = builder.build(&driver);
2580 let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2581 let s1 = client.save().await;
2582 assert_eq!(s0, s1);
2583 assert!(connection.offers.is_empty());
2584 server.start_client(&mut client).await;
2585 check_message(
2586 server.next().await.unwrap(),
2587 protocol::RelIdReleased {
2588 channel_id: ChannelId(0),
2589 },
2590 );
2591 }
2592
2593 #[async_test]
2594 async fn test_connect_fails_on_incorrect_state(driver: DefaultDriver) {
2595 let (mut server, mut client) = test_init(&driver);
2596 server.connect(&mut client).await;
2597 let err = client.connect(0, None, Guid::ZERO).await.unwrap_err();
2598 assert!(matches!(err, ConnectError::InvalidState), "{:?}", err);
2599 }
2600
2601 #[async_test]
2602 async fn test_hot_add_remove(driver: DefaultDriver) {
2603 let (mut server, mut client) = test_init(&driver);
2604
2605 let mut connection = server.connect(&mut client).await;
2606 let offer = protocol::OfferChannel {
2607 interface_id: Guid::new_random(),
2608 instance_id: Guid::new_random(),
2609 rsvd: [0; 4],
2610 flags: OfferFlags::new(),
2611 mmio_megabytes: 0,
2612 user_defined: UserDefinedData::new_zeroed(),
2613 subchannel_index: 0,
2614 mmio_megabytes_optional: 0,
2615 channel_id: ChannelId(5),
2616 monitor_id: 0,
2617 monitor_allocated: 0,
2618 is_dedicated: 0,
2619 connection_id: 0,
2620 };
2621
2622 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2623 let info = connection.offer_recv.next().await.unwrap();
2624
2625 assert_eq!(offer, info.offer);
2626
2627 server.send(in_msg(
2628 MessageType::RESCIND_CHANNEL_OFFER,
2629 protocol::RescindChannelOffer {
2630 channel_id: ChannelId(5),
2631 },
2632 ));
2633
2634 info.revoke_recv.await.unwrap();
2635 drop(info.request_send);
2636
2637 check_message(
2638 server.next().await.unwrap(),
2639 protocol::RelIdReleased {
2640 channel_id: ChannelId(5),
2641 },
2642 );
2643 }
2644
2645 #[async_test]
2646 async fn test_gpadl_success(driver: DefaultDriver) {
2647 let (mut server, mut client) = test_init(&driver);
2648 let channel = server.get_channel(&mut client).await;
2649 let recv = channel.request_send.call(
2650 ChannelRequest::Gpadl,
2651 GpadlRequest {
2652 id: GpadlId(1),
2653 count: 1,
2654 buf: vec![5],
2655 },
2656 );
2657
2658 check_message_with_data(
2659 server.next().await.unwrap(),
2660 protocol::GpadlHeader {
2661 channel_id: ChannelId(0),
2662 gpadl_id: GpadlId(1),
2663 len: 8,
2664 count: 1,
2665 },
2666 0x5u64.as_bytes(),
2667 );
2668
2669 server.send(in_msg(
2670 MessageType::GPADL_CREATED,
2671 protocol::GpadlCreated {
2672 channel_id: ChannelId(0),
2673 gpadl_id: GpadlId(1),
2674 status: protocol::STATUS_SUCCESS,
2675 },
2676 ));
2677
2678 recv.await.unwrap().unwrap();
2679
2680 let rpc = channel
2681 .request_send
2682 .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2683
2684 check_message(
2685 server.next().await.unwrap(),
2686 protocol::GpadlTeardown {
2687 channel_id: ChannelId(0),
2688 gpadl_id: GpadlId(1),
2689 },
2690 );
2691
2692 server.send(in_msg(
2693 MessageType::GPADL_TORNDOWN,
2694 protocol::GpadlTorndown {
2695 gpadl_id: GpadlId(1),
2696 },
2697 ));
2698
2699 rpc.await.unwrap();
2700 }
2701
2702 #[async_test]
2703 async fn test_gpadl_fail(driver: DefaultDriver) {
2704 let (mut server, mut client) = test_init(&driver);
2705 let channel = server.get_channel(&mut client).await;
2706 let recv = channel.request_send.call(
2707 ChannelRequest::Gpadl,
2708 GpadlRequest {
2709 id: GpadlId(1),
2710 count: 1,
2711 buf: vec![7],
2712 },
2713 );
2714
2715 check_message_with_data(
2716 server.next().await.unwrap(),
2717 protocol::GpadlHeader {
2718 channel_id: ChannelId(0),
2719 gpadl_id: GpadlId(1),
2720 len: 8,
2721 count: 1,
2722 },
2723 0x7u64.as_bytes(),
2724 );
2725
2726 server.send(in_msg(
2727 MessageType::GPADL_CREATED,
2728 protocol::GpadlCreated {
2729 channel_id: ChannelId(0),
2730 gpadl_id: GpadlId(1),
2731 status: protocol::STATUS_UNSUCCESSFUL,
2732 },
2733 ));
2734
2735 recv.await.unwrap().unwrap_err();
2736 }
2737
2738 #[async_test]
2739 async fn test_gpadl_with_revoke(driver: DefaultDriver) {
2740 let (mut server, mut client) = test_init(&driver);
2741 let channel = server.get_channel(&mut client).await;
2742 let channel_id = ChannelId(0);
2743 for gpadl_id in [1, 2, 3].map(GpadlId) {
2744 let recv = channel.request_send.call(
2745 ChannelRequest::Gpadl,
2746 GpadlRequest {
2747 id: gpadl_id,
2748 count: 1,
2749 buf: vec![3],
2750 },
2751 );
2752
2753 check_message_with_data(
2754 server.next().await.unwrap(),
2755 protocol::GpadlHeader {
2756 channel_id,
2757 gpadl_id,
2758 len: 8,
2759 count: 1,
2760 },
2761 0x3u64.as_bytes(),
2762 );
2763
2764 server.send(in_msg(
2765 MessageType::GPADL_CREATED,
2766 protocol::GpadlCreated {
2767 channel_id,
2768 gpadl_id,
2769 status: protocol::STATUS_SUCCESS,
2770 },
2771 ));
2772
2773 recv.await.unwrap().unwrap();
2774 }
2775
2776 let rpc = channel
2777 .request_send
2778 .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2779
2780 check_message(
2781 server.next().await.unwrap(),
2782 protocol::GpadlTeardown {
2783 channel_id,
2784 gpadl_id: GpadlId(1),
2785 },
2786 );
2787
2788 server.send(in_msg(
2789 MessageType::RESCIND_CHANNEL_OFFER,
2790 protocol::RescindChannelOffer { channel_id },
2791 ));
2792
2793 let recv = channel.request_send.call_failable(
2794 ChannelRequest::Gpadl,
2795 GpadlRequest {
2796 id: GpadlId(4),
2797 count: 1,
2798 buf: vec![3],
2799 },
2800 );
2801
2802 check_message_with_data(
2803 server.next().await.unwrap(),
2804 protocol::GpadlHeader {
2805 channel_id,
2806 gpadl_id: GpadlId(4),
2807 len: 8,
2808 count: 1,
2809 },
2810 0x3u64.as_bytes(),
2811 );
2812
2813 server.send(in_msg(
2814 MessageType::GPADL_CREATED,
2815 protocol::GpadlCreated {
2816 channel_id,
2817 gpadl_id: GpadlId(4),
2818 status: protocol::STATUS_UNSUCCESSFUL,
2819 },
2820 ));
2821
2822 server.send(in_msg(
2823 MessageType::GPADL_TORNDOWN,
2824 protocol::GpadlTorndown {
2825 gpadl_id: GpadlId(1),
2826 },
2827 ));
2828
2829 rpc.await.unwrap();
2830 recv.await.unwrap_err();
2831
2832 channel.revoke_recv.await.unwrap();
2833
2834 let rpc = channel
2835 .request_send
2836 .call(ChannelRequest::TeardownGpadl, GpadlId(2));
2837 drop(channel.request_send);
2838
2839 check_message(
2840 server.next().await.unwrap(),
2841 protocol::GpadlTeardown {
2842 channel_id,
2843 gpadl_id: GpadlId(2),
2844 },
2845 );
2846
2847 server.send(in_msg(
2848 MessageType::GPADL_TORNDOWN,
2849 protocol::GpadlTorndown {
2850 gpadl_id: GpadlId(2),
2851 },
2852 ));
2853
2854 rpc.await.unwrap();
2855
2856 check_message(
2857 server.next().await.unwrap(),
2858 protocol::RelIdReleased { channel_id },
2859 );
2860 }
2861
2862 #[async_test]
2863 async fn test_modify_connection(driver: DefaultDriver) {
2864 let (mut server, mut client) = test_init(&driver);
2865 server.connect(&mut client).await;
2866 let call = client.access.client_request_send.call(
2867 ClientRequest::Modify,
2868 ModifyConnectionRequest {
2869 monitor_page: Some(MonitorPageGpas {
2870 child_to_parent: 5,
2871 parent_to_child: 6,
2872 }),
2873 },
2874 );
2875
2876 check_message(
2877 server.next().await.unwrap(),
2878 protocol::ModifyConnection {
2879 child_to_parent_monitor_page_gpa: 5,
2880 parent_to_child_monitor_page_gpa: 6,
2881 },
2882 );
2883
2884 server.send(in_msg(
2885 MessageType::MODIFY_CONNECTION_RESPONSE,
2886 protocol::ModifyConnectionResponse {
2887 connection_state: ConnectionState::FAILED_LOW_RESOURCES,
2888 },
2889 ));
2890
2891 let result = call.await.unwrap();
2892 assert_eq!(ConnectionState::FAILED_LOW_RESOURCES, result);
2893 }
2894
2895 #[async_test]
2896 async fn test_hvsock(driver: DefaultDriver) {
2897 let (mut server, mut client) = test_init(&driver);
2898 server.connect(&mut client).await;
2899 let request = HvsockConnectRequest {
2900 service_id: Guid::new_random(),
2901 endpoint_id: Guid::new_random(),
2902 silo_id: Guid::new_random(),
2903 hosted_silo_unaware: false,
2904 };
2905
2906 let resp = client.access().connect_hvsock(request);
2907 check_message(
2908 server.next().await.unwrap(),
2909 protocol::TlConnectRequest2 {
2910 base: protocol::TlConnectRequest {
2911 service_id: request.service_id,
2912 endpoint_id: request.endpoint_id,
2913 },
2914 silo_id: request.silo_id,
2915 },
2916 );
2917
2918 server.send(in_msg(
2920 MessageType::TL_CONNECT_REQUEST_RESULT,
2921 protocol::TlConnectResult {
2922 service_id: request.service_id,
2923 endpoint_id: request.endpoint_id,
2924 status: protocol::STATUS_CONNECTION_REFUSED,
2925 },
2926 ));
2927
2928 let result = resp.await;
2929 assert!(result.is_none());
2930 }
2931
2932 #[async_test]
2933 async fn test_synic_event_flags(driver: DefaultDriver) {
2934 let (mut server, mut client) = test_init(&driver);
2935 let connection = server.get_channels(&mut client, 5).await;
2936 let event = Event::new();
2937
2938 for _ in 0..5 {
2939 for (i, channel) in connection.offers.iter().enumerate() {
2940 let recv = channel.request_send.call(
2941 ChannelRequest::Open,
2942 OpenRequest {
2943 open_data: OpenData {
2944 target_vp: 0,
2945 ring_offset: 0,
2946 ring_gpadl_id: GpadlId(0),
2947 event_flag: 0,
2948 connection_id: 0,
2949 user_data: UserDefinedData::new_zeroed(),
2950 },
2951 incoming_event: Some(event.clone()),
2952 use_vtl2_connection_id: false,
2953 },
2954 );
2955
2956 let expected_event_flag = i as u16 + 1;
2957
2958 check_message(
2959 server.next().await.unwrap(),
2960 protocol::OpenChannel2 {
2961 open_channel: protocol::OpenChannel {
2962 channel_id: channel.offer.channel_id,
2963 open_id: 0,
2964 ring_buffer_gpadl_id: GpadlId(0),
2965 target_vp: 0,
2966 downstream_ring_buffer_page_offset: 0,
2967 user_data: UserDefinedData::new_zeroed(),
2968 },
2969 connection_id: 0,
2970 event_flag: expected_event_flag,
2971 flags: OpenChannelFlags::new().with_redirect_interrupt(true),
2972 },
2973 );
2974
2975 server.send(in_msg(
2976 MessageType::OPEN_CHANNEL_RESULT,
2977 protocol::OpenResult {
2978 channel_id: channel.offer.channel_id,
2979 open_id: 0,
2980 status: protocol::STATUS_SUCCESS as u32,
2981 },
2982 ));
2983
2984 let output = recv.await.unwrap().unwrap();
2985 assert_eq!(output.redirected_event_flag, Some(expected_event_flag));
2986 }
2987
2988 for (i, channel) in connection.offers.iter().enumerate() {
2989 channel
2992 .request_send
2993 .call(ChannelRequest::Close, ())
2994 .await
2995 .unwrap();
2996
2997 check_message(
2998 server.next().await.unwrap(),
2999 protocol::CloseChannel {
3000 channel_id: ChannelId(i as u32),
3001 },
3002 );
3003 }
3004 }
3005 }
3006
3007 #[async_test]
3008 async fn test_revoke(driver: DefaultDriver) {
3009 let (mut server, mut client) = test_init(&driver);
3010 let channel = server.get_channel(&mut client).await;
3011
3012 server.send(in_msg(
3013 MessageType::RESCIND_CHANNEL_OFFER,
3014 protocol::RescindChannelOffer {
3015 channel_id: ChannelId(0),
3016 },
3017 ));
3018
3019 channel.revoke_recv.await.unwrap();
3020
3021 channel
3022 .request_send
3023 .call_failable(
3024 ChannelRequest::Open,
3025 OpenRequest {
3026 open_data: OpenData {
3027 target_vp: 0,
3028 ring_offset: 0,
3029 ring_gpadl_id: GpadlId(0),
3030 event_flag: 0,
3031 connection_id: 0,
3032 user_data: UserDefinedData::new_zeroed(),
3033 },
3034 incoming_event: None,
3035 use_vtl2_connection_id: false,
3036 },
3037 )
3038 .await
3039 .unwrap_err();
3040 }
3041
3042 #[async_test]
3043 #[should_panic(expected = "channel should not exist")]
3044 async fn test_reoffer_in_use_rel_id(driver: DefaultDriver) {
3045 let (mut server, mut client) = test_init(&driver);
3046 let mut connection = server.get_channels(&mut client, 1).await;
3047 let [channel] = connection.offers.try_into().unwrap();
3048
3049 server.send(in_msg(
3050 MessageType::RESCIND_CHANNEL_OFFER,
3051 protocol::RescindChannelOffer {
3052 channel_id: ChannelId(0),
3053 },
3054 ));
3055
3056 channel.revoke_recv.await.unwrap();
3057
3058 let offer = protocol::OfferChannel {
3060 interface_id: Guid::new_random(),
3061 instance_id: Guid::new_random(),
3062 rsvd: [0; 4],
3063 flags: OfferFlags::new(),
3064 mmio_megabytes: 0,
3065 user_defined: UserDefinedData::new_zeroed(),
3066 subchannel_index: 0,
3067 mmio_megabytes_optional: 0,
3068 channel_id: ChannelId(0),
3069 monitor_id: 0,
3070 monitor_allocated: 0,
3071 is_dedicated: 0,
3072 connection_id: 0,
3073 };
3074
3075 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3076
3077 connection.offer_recv.next().await;
3078 }
3079
3080 #[async_test]
3081 async fn test_revoke_release_and_reoffer(driver: DefaultDriver) {
3082 let (mut server, mut client) = test_init(&driver);
3083 let mut connection = server.get_channels(&mut client, 1).await;
3084 let [channel] = connection.offers.try_into().unwrap();
3085
3086 server.send(in_msg(
3087 MessageType::RESCIND_CHANNEL_OFFER,
3088 protocol::RescindChannelOffer {
3089 channel_id: ChannelId(0),
3090 },
3091 ));
3092
3093 channel.revoke_recv.await.unwrap();
3094 drop(channel.request_send);
3095
3096 check_message(
3097 server.next().await.unwrap(),
3098 protocol::RelIdReleased {
3099 channel_id: ChannelId(0),
3100 },
3101 );
3102
3103 let offer = protocol::OfferChannel {
3104 interface_id: Guid::new_random(),
3105 instance_id: Guid::new_random(),
3106 rsvd: [0; 4],
3107 flags: OfferFlags::new(),
3108 mmio_megabytes: 0,
3109 user_defined: UserDefinedData::new_zeroed(),
3110 subchannel_index: 0,
3111 mmio_megabytes_optional: 0,
3112 channel_id: ChannelId(0),
3113 monitor_id: 0,
3114 monitor_allocated: 0,
3115 is_dedicated: 0,
3116 connection_id: 0,
3117 };
3118
3119 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3120
3121 connection.offer_recv.next().await.unwrap();
3122 }
3123
3124 #[async_test]
3125 async fn test_release_revoke_and_reoffer(driver: DefaultDriver) {
3126 let (mut server, mut client) = test_init(&driver);
3127 let mut connection = server.get_channels(&mut client, 1).await;
3128 let [channel] = connection.offers.try_into().unwrap();
3129
3130 let open = channel.request_send.call_failable(
3131 ChannelRequest::Open,
3132 OpenRequest {
3133 open_data: OpenData {
3134 target_vp: 0,
3135 ring_offset: 0,
3136 ring_gpadl_id: GpadlId(0),
3137 event_flag: 0,
3138 connection_id: 0,
3139 user_data: UserDefinedData::new_zeroed(),
3140 },
3141 incoming_event: None,
3142 use_vtl2_connection_id: false,
3143 },
3144 );
3145
3146 let server_open = async {
3147 check_message(
3148 server.next().await.unwrap(),
3149 protocol::OpenChannel2 {
3150 open_channel: protocol::OpenChannel {
3151 channel_id: ChannelId(0),
3152 open_id: 0,
3153 ring_buffer_gpadl_id: GpadlId(0),
3154 target_vp: 0,
3155 downstream_ring_buffer_page_offset: 0,
3156 user_data: UserDefinedData::new_zeroed(),
3157 },
3158 connection_id: 0,
3159 event_flag: 0,
3160 flags: Default::default(),
3161 },
3162 );
3163 server.send(in_msg(
3164 MessageType::OPEN_CHANNEL_RESULT,
3165 protocol::OpenResult {
3166 channel_id: ChannelId(0),
3167 open_id: 0,
3168 status: protocol::STATUS_SUCCESS as u32,
3169 },
3170 ));
3171 };
3172
3173 (open, server_open).join().await.0.unwrap();
3174
3175 drop(channel);
3177
3178 check_message(
3179 server.next().await.unwrap(),
3180 protocol::CloseChannel {
3181 channel_id: ChannelId(0),
3182 },
3183 );
3184
3185 server.send(in_msg(
3186 MessageType::RESCIND_CHANNEL_OFFER,
3187 protocol::RescindChannelOffer {
3188 channel_id: ChannelId(0),
3189 },
3190 ));
3191
3192 check_message(
3194 server.next().await.unwrap(),
3195 protocol::RelIdReleased {
3196 channel_id: ChannelId(0),
3197 },
3198 );
3199
3200 let offer = protocol::OfferChannel {
3201 interface_id: Guid::new_random(),
3202 instance_id: Guid::new_random(),
3203 rsvd: [0; 4],
3204 flags: OfferFlags::new(),
3205 mmio_megabytes: 0,
3206 user_defined: UserDefinedData::new_zeroed(),
3207 subchannel_index: 0,
3208 mmio_megabytes_optional: 0,
3209 channel_id: ChannelId(0),
3210 monitor_id: 0,
3211 monitor_allocated: 0,
3212 is_dedicated: 0,
3213 connection_id: 0,
3214 };
3215
3216 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3217
3218 connection.offer_recv.next().await.unwrap();
3220 }
3221}