1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod hvsock;
8pub mod saved_state;
9
10pub use self::saved_state::SavedState;
11use anyhow::Context as _;
12use anyhow::Result;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures::future::OptionFuture;
16use futures::stream::SelectAll;
17use futures_concurrency::future::Race;
18use guid::Guid;
19use inspect::Inspect;
20use mesh::rpc::FailableRpc;
21use mesh::rpc::Rpc;
22use mesh::rpc::RpcSend;
23use pal_async::task::Spawn;
24use pal_async::task::Task;
25use pal_event::Event;
26use std::collections::HashMap;
27use std::collections::VecDeque;
28use std::collections::hash_map;
29use std::convert::TryInto;
30use std::future::Future;
31use std::future::poll_fn;
32use std::ops::Deref;
33use std::ops::DerefMut;
34use std::pin::pin;
35use std::sync::Arc;
36use std::task::Context;
37use std::task::Poll;
38use thiserror::Error;
39use vmbus_async::async_dgram::AsyncRecv;
40use vmbus_async::async_dgram::AsyncRecvExt;
41use vmbus_channel::bus::GpadlRequest;
42use vmbus_channel::bus::ModifyRequest;
43use vmbus_channel::bus::OpenData;
44use vmbus_channel::gpadl::GpadlId;
45use vmbus_core::HvsockConnectRequest;
46use vmbus_core::OutgoingMessage;
47use vmbus_core::TaggedStream;
48use vmbus_core::VersionInfo;
49use vmbus_core::protocol;
50use vmbus_core::protocol::ChannelId;
51use vmbus_core::protocol::ConnectionState;
52use vmbus_core::protocol::FeatureFlags;
53use vmbus_core::protocol::Message;
54use vmbus_core::protocol::OpenChannelFlags;
55use vmbus_core::protocol::Version;
56use vmcore::interrupt::Interrupt;
57use vmcore::synic::MonitorPageGpas;
58use zerocopy::Immutable;
59use zerocopy::IntoBytes;
60use zerocopy::KnownLayout;
61
/// SINT value written into the `TargetInfo` sent with InitiateContact.
const SINT: u8 = 2;
/// VTL value written into the `TargetInfo` sent with InitiateContact.
const VTL: u8 = 0;
/// Protocol versions the client is willing to negotiate.
///
/// Negotiation starts with the last entry and falls back towards the first
/// each time the server reports the requested version as unsupported.
const SUPPORTED_VERSIONS: &[Version] = &[Version::Iron, Version::Copper];
/// Feature flags requested when negotiating `Version::Copper` or newer; older
/// versions are negotiated with an empty flag set.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true);
71
/// Event-signaling backend used by the client.
///
/// The builder stores an implementation as `Arc<dyn SynicEventClient>`.
/// NOTE(review): the code that calls these methods is outside this section of
/// the file; the per-method descriptions below are inferred from the
/// signatures and should be confirmed against the implementations.
pub trait SynicEventClient: Send + Sync {
    /// Associates `event` with the given `event_flag`.
    fn map_event(&self, event_flag: u16, event: &Event) -> std::io::Result<()>;

    /// Removes the association for `event_flag` (presumably one created by
    /// `map_event`).
    fn unmap_event(&self, event_flag: u16);

    /// Signals `event_flag` on the connection identified by `connection_id`.
    fn signal_event(&self, connection_id: u32, event_flag: u16) -> std::io::Result<()>;
}
83
/// Source of incoming messages for the client task.
///
/// Implementations are datagram receivers (`AsyncRecv`) that can be moved to
/// the client task (`Send`).
pub trait VmbusMessageSource: AsyncRecv + Send {
    /// Hook for pausing the message stream. The default implementation does
    /// nothing.
    fn pause_message_stream(&mut self) {}

    /// Hook for resuming the message stream. The default implementation does
    /// nothing.
    fn resume_message_stream(&mut self) {}
}
93
/// Poll-based interface for posting an outgoing message.
pub trait PollPostMessage: Send {
    /// Attempts to post `msg` with message type `typ` to `connection_id`.
    ///
    /// Follows the usual `Poll` convention: `Poll::Ready(())` reports
    /// completion; `cx` is provided so the implementation can arrange a
    /// wakeup when it returns `Poll::Pending`.
    fn poll_post_message(
        &mut self,
        cx: &mut Context<'_>,
        connection_id: u32,
        typ: u32,
        msg: &[u8],
    ) -> Poll<()>;
}
103
/// Handle to a running vmbus client task, created by
/// [`VmbusClientBuilder::build`].
pub struct VmbusClient {
    /// Sends lifecycle requests (start/stop/save/restore/inspect) to the task.
    task_send: mesh::Sender<TaskRequest>,
    /// Cloneable handle used for connection-level requests.
    access: VmbusClientAccess,
    /// The spawned task; it resolves to the `ClientTask` once its run loop
    /// finishes.
    task: Task<ClientTask>,
}
109
/// Errors returned when connecting to the vmbus server.
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// A connect was requested while the client was not disconnected.
    #[error("invalid state to connect to the server")]
    InvalidState,
    /// The server rejected every version in the client's supported list.
    #[error("no supported protocol versions")]
    NoSupportedVersions,
    /// The server accepted the version but reported a non-successful
    /// connection state.
    #[error("failed to connect to the server: {0:?}")]
    FailedToConnect(ConnectionState),
}
119
/// Cloneable handle for sending connection-level requests to the client task.
#[derive(Clone)]
pub struct VmbusClientAccess {
    /// Channel on which `ClientRequest`s are delivered to the task.
    client_request_send: mesh::Sender<ClientRequest>,
}
124
/// Collects the backends a [`VmbusClient`] needs before its task is spawned.
pub struct VmbusClientBuilder {
    /// Event-signaling backend.
    event_client: Arc<dyn SynicEventClient>,
    /// Source of incoming messages.
    msg_source: Box<dyn VmbusMessageSource>,
    /// Sink used to post outgoing messages.
    msg_client: Box<dyn PollPostMessage>,
}
131
132impl VmbusClientBuilder {
133 pub fn new(
135 event_client: impl SynicEventClient + 'static,
136 msg_source: impl VmbusMessageSource + 'static,
137 msg_client: impl PollPostMessage + 'static,
138 ) -> Self {
139 Self {
140 event_client: Arc::new(event_client),
141 msg_source: Box::new(msg_source),
142 msg_client: Box::new(msg_client),
143 }
144 }
145
146 pub fn build(self, spawner: &impl Spawn) -> VmbusClient {
148 let (task_send, task_recv) = mesh::channel();
149 let (client_request_send, client_request_recv) = mesh::channel();
150
151 let inner = ClientTaskInner {
152 messages: OutgoingMessages {
153 poster: self.msg_client,
154 queued: VecDeque::new(),
155 state: OutgoingMessageState::Paused,
156 },
157 teardown_gpadls: HashMap::new(),
158 channel_requests: SelectAll::new(),
159 synic: SynicState {
160 event_flag_state: Vec::new(),
161 event_client: self.event_client,
162 },
163 };
164
165 let mut task = ClientTask {
166 inner,
167 channels: ChannelList::default(),
168 task_recv,
169 running: false,
170 msg_source: self.msg_source,
171 client_request_recv,
172 state: ClientState::Disconnected,
173 modify_request: None,
174 hvsock_tracker: hvsock::HvsockRequestTracker::new(),
175 };
176
177 let task = spawner.spawn("vmbus client", async move {
178 task.run().await;
179 task
180 });
181
182 VmbusClient {
183 access: VmbusClientAccess {
184 client_request_send,
185 },
186 task_send,
187 task,
188 }
189 }
190}
191
192impl VmbusClient {
193 pub async fn connect(
196 &mut self,
197 target_message_vp: u32,
198 monitor_page: Option<MonitorPageGpas>,
199 client_id: Guid,
200 ) -> Result<ConnectResult, ConnectError> {
201 let request = ConnectRequest {
202 target_message_vp,
203 monitor_page,
204 client_id,
205 };
206
207 self.access
208 .client_request_send
209 .call(ClientRequest::Connect, request)
210 .await
211 .unwrap()
212 }
213
214 pub async fn unload(self) {
215 self.access
216 .client_request_send
217 .call(ClientRequest::Unload, ())
218 .await
219 .unwrap();
220
221 self.sever().await;
222 }
223
224 pub fn access(&self) -> &VmbusClientAccess {
225 &self.access
226 }
227
228 pub fn start(&mut self) {
229 self.task_send.send(TaskRequest::Start);
230 }
231
232 pub async fn stop(&mut self) {
233 self.task_send
234 .call(TaskRequest::Stop, ())
235 .await
236 .expect("Failed to send stop request");
237 }
238
239 pub async fn save(&self) -> SavedState {
240 self.task_send
241 .call(TaskRequest::Save, ())
242 .await
243 .expect("Failed to send save request")
244 }
245
246 pub async fn restore(
247 &mut self,
248 state: SavedState,
249 ) -> Result<Option<ConnectResult>, RestoreError> {
250 self.task_send
251 .call(TaskRequest::Restore, state)
252 .await
253 .expect("Failed to send restore request")
254 }
255
256 pub async fn post_restore(&mut self) {
257 self.task_send
258 .call(TaskRequest::PostRestore, ())
259 .await
260 .expect("Failed to send post-restore request");
261 }
262
263 async fn sever(self) -> VmbusClientBuilder {
264 drop(self.task_send);
265 let task = self.task.await;
266 VmbusClientBuilder {
267 event_client: task.inner.synic.event_client,
268 msg_source: task.msg_source,
269 msg_client: task.inner.messages.poster,
270 }
271 }
272}
273
274impl Inspect for VmbusClient {
275 fn inspect(&self, req: inspect::Request<'_>) {
276 self.task_send.send(TaskRequest::Inspect(req.defer()));
277 }
278}
279
/// The result of a successful connect.
#[derive(Debug)]
pub struct ConnectResult {
    /// The negotiated protocol version and feature flags.
    pub version: VersionInfo,
    /// Offers collected while the client was requesting offers.
    pub offers: Vec<OfferInfo>,
    /// Receives offers that arrive once the client is connected.
    pub offer_recv: mesh::Receiver<OfferInfo>,
}
286
287impl VmbusClientAccess {
288 pub async fn modify(&self, request: ModifyConnectionRequest) -> ConnectionState {
289 self.client_request_send
290 .call(ClientRequest::Modify, request)
291 .await
292 .expect("Failed to send modify request")
293 }
294
295 pub fn connect_hvsock(
296 &self,
297 request: HvsockConnectRequest,
298 ) -> impl Future<Output = Option<OfferInfo>> + use<> {
299 self.client_request_send
300 .call(ClientRequest::HvsockConnect, request)
301 .map(|r| r.ok().flatten())
302 }
303}
304
/// Parameters for opening a channel.
#[derive(Debug)]
pub struct OpenRequest {
    /// Ring buffer, target VP, event flag, connection id and user data for
    /// the open message.
    pub open_data: OpenData,
    /// When set, an event flag is allocated for this event and the open
    /// message asks the host to redirect the channel interrupt to it.
    pub incoming_event: Option<Event>,
    /// When `true`, a connection id derived from the channel id is used
    /// instead of the one in `open_data`.
    pub use_vtl2_connection_id: bool,
}
311
/// Parameters for moving a restored channel back to the opened state.
///
/// `incoming_event` and `redirected_event_flag` must be both set or both
/// unset; the restore fails otherwise.
#[derive(Debug)]
pub struct RestoreRequest {
    /// Event to associate with `redirected_event_flag`.
    pub incoming_event: Option<Event>,
    /// Event flag to restore for `incoming_event`.
    pub redirected_event_flag: Option<u16>,
    /// Connection id used to build the guest-to-host interrupt.
    pub connection_id: u32,
}
320
/// Requests sent on a channel's `request_send` (see `OfferInfo`).
pub enum ChannelRequest {
    /// Open the channel.
    Open(FailableRpc<OpenRequest, OpenOutput>),
    /// Move a restored channel back to the opened state.
    Restore(FailableRpc<RestoreRequest, OpenOutput>),
    /// Close the channel.
    Close(Rpc<(), ()>),
    /// Create a GPADL for the channel.
    Gpadl(FailableRpc<GpadlRequest, ()>),
    /// Tear down the GPADL with the given id.
    TeardownGpadl(Rpc<GpadlId, ()>),
    /// Modify the channel; completes with the status from the response.
    Modify(Rpc<ModifyRequest, i32>),
}
330
/// Output of a successful open or restore of a channel.
#[derive(Debug)]
pub struct OpenOutput {
    /// The event flag in use for the incoming event, if one was requested.
    pub redirected_event_flag: Option<u16>,
    /// Interrupt built for the channel's connection id.
    pub guest_to_host_signal: Interrupt,
}
337
338impl std::fmt::Display for ChannelRequest {
339 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340 let s = match self {
341 ChannelRequest::Open(_) => "Open",
342 ChannelRequest::Close(_) => "Close",
343 ChannelRequest::Restore(_) => "Restore",
344 ChannelRequest::Gpadl(_) => "Gpadl",
345 ChannelRequest::TeardownGpadl(_) => "TeardownGpadl",
346 ChannelRequest::Modify(_) => "Modify",
347 };
348 fmt.pad(s)
349 }
350}
351
/// Errors returned from [`VmbusClient::restore`].
#[derive(Debug, Error)]
pub enum RestoreError {
    /// The saved protocol version is not supported.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// The saved feature flags are not supported.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    /// The same channel id appears more than once.
    #[error("duplicate channel id {0}")]
    DuplicateChannelId(u32),

    /// The same gpadl id appears more than once.
    #[error("duplicate gpadl id {0}")]
    DuplicateGpadlId(u32),

    /// A gpadl refers to a channel id that is not known.
    #[error("gpadl for unknown channel id {0}")]
    GpadlForUnknownChannelId(u32),

    /// A pending message is too large to be valid.
    #[error("invalid pending message")]
    InvalidPendingMessage(#[source] vmbus_core::MessageTooLarge),

    /// A restored channel could not be offered.
    #[error("failed to offer restored channel")]
    OfferFailed(#[source] anyhow::Error),
}
375
/// A channel offer together with the endpoints used to drive that channel.
#[derive(Debug, Inspect)]
pub struct OfferInfo {
    /// The offer message for the channel.
    pub offer: protocol::OfferChannel,
    /// Sends `ChannelRequest`s for this channel to the client task.
    #[inspect(skip)]
    pub request_send: mesh::Sender<ChannelRequest>,
    /// Receives a value when the offer is rescinded.
    #[inspect(skip)]
    pub revoke_recv: mesh::OneshotReceiver<()>,
}
386
/// Connection-level requests delivered through `VmbusClientAccess`.
#[derive(Debug)]
enum ClientRequest {
    /// Start version negotiation with the server.
    Connect(Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>),
    /// Send an unload message; completed when the unload is acknowledged.
    Unload(Rpc<(), ()>),
    /// Send a modify-connection message.
    Modify(Rpc<ModifyConnectionRequest, ConnectionState>),
    /// Send an hvsock connect request; completes with the matching offer, or
    /// `None` when a connect result arrives for the request instead.
    HvsockConnect(Rpc<HvsockConnectRequest, Option<OfferInfo>>),
}
394
395impl std::fmt::Display for ClientRequest {
396 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
397 let s = match self {
398 ClientRequest::Connect(..) => "Connect",
399 ClientRequest::Unload { .. } => "Unload",
400 ClientRequest::Modify(..) => "Modify",
401 ClientRequest::HvsockConnect(..) => "HvsockConnect",
402 };
403 fmt.pad(s)
404 }
405}
406
/// Lifecycle requests sent from `VmbusClient` to the client task.
enum TaskRequest {
    /// Answer a deferred inspection request.
    Inspect(inspect::Deferred),
    /// Produce the task's saved state.
    Save(Rpc<(), SavedState>),
    /// Restore the task from a saved state.
    Restore(Rpc<SavedState, Result<Option<ConnectResult>, RestoreError>>),
    /// Post-restore step; see `VmbusClient::post_restore`.
    PostRestore(Rpc<(), ()>),
    /// Start the task (fire-and-forget).
    Start,
    /// Stop the task; completed as an acknowledgement.
    Stop(Rpc<(), ()>),
}
415
/// Connection state of the client.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ClientState {
    /// Not connected; the only state from which a connect may start.
    Disconnected,
    /// InitiateContact was sent and the version response is pending.
    Connecting {
        /// The version requested in the most recent InitiateContact.
        version: Version,
        /// The connect request, completed when negotiation finishes.
        #[inspect(skip)]
        rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
    },
    /// Connected, with all initial offers delivered.
    Connected {
        /// The negotiated version and feature flags.
        version: VersionInfo,
        /// Forwards offers that arrive after the connect completed.
        #[inspect(skip)]
        offer_send: mesh::Sender<OfferInfo>,
    },
    /// Version negotiated; RequestOffers was sent and offers are being
    /// collected until the server reports they are all delivered.
    RequestingOffers {
        /// The negotiated version and feature flags.
        version: VersionInfo,
        /// The connect request, completed once all offers are delivered.
        #[inspect(skip)]
        rpc: Rpc<(), Result<ConnectResult, ConnectError>>,
        /// Offers received so far.
        #[inspect(skip)]
        offers: Vec<OfferInfo>,
    },
    /// Unload was sent and its completion is pending.
    Disconnecting {
        /// The version that was in use when the unload started.
        version: VersionInfo,
        /// The unload request, completed when the unload is acknowledged.
        #[inspect(skip)]
        rpc: Rpc<(), ()>,
    },
}
451
452impl ClientState {
453 fn get_version(&self) -> Option<VersionInfo> {
454 match self {
455 ClientState::Connected { version, .. } => Some(*version),
456 ClientState::RequestingOffers { version, .. } => Some(*version),
457 ClientState::Disconnecting { version, .. } => Some(*version),
458 ClientState::Disconnected | ClientState::Connecting { .. } => None,
459 }
460 }
461}
462
463impl std::fmt::Display for ClientState {
464 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
465 let s = match self {
466 ClientState::Disconnected => "Disconnected",
467 ClientState::Connecting { .. } => "Connecting",
468 ClientState::Connected { .. } => "Connected",
469 ClientState::RequestingOffers { .. } => "RequestingOffers",
470 ClientState::Disconnecting { .. } => "Disconnecting",
471 };
472 fmt.pad(s)
473 }
474}
475
/// Parameters of a connect request; they are copied into the InitiateContact
/// message.
#[derive(Copy, Clone, Debug, Default)]
struct ConnectRequest {
    /// Value for the message's `target_message_vp` field.
    target_message_vp: u32,
    /// Monitor page addresses; default addresses are sent when `None`.
    monitor_page: Option<MonitorPageGpas>,
    /// Client id sent with the extended InitiateContact message.
    client_id: Guid,
}
482
/// Parameters of a modify-connection request.
#[derive(Copy, Clone, Debug, Default)]
pub struct ModifyConnectionRequest {
    /// New monitor page addresses; default addresses are sent when `None`.
    pub monitor_page: Option<MonitorPageGpas>,
}
487
488impl From<ModifyConnectionRequest> for protocol::ModifyConnection {
489 fn from(value: ModifyConnectionRequest) -> Self {
490 let monitor_page = value.monitor_page.unwrap_or_default();
491
492 Self {
493 parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
494 child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
495 }
496 }
497}
498
/// State of a single channel tracked by the client.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum ChannelState {
    /// The channel has been offered and is not open. This is the state a
    /// newly created channel starts in.
    Offered,
    /// An open message has been sent and the open result is pending.
    Opening {
        /// Connection id that will be used for the guest-to-host interrupt.
        connection_id: u32,
        /// Event flag allocated for the incoming event, if one was supplied.
        redirected_event_flag: Option<u16>,
        /// The incoming event supplied with the open request, if any.
        #[inspect(skip)]
        redirected_event: Option<Event>,
        /// The open request, completed when the open result arrives.
        #[inspect(skip)]
        rpc: FailableRpc<(), OpenOutput>,
    },
    /// The channel is waiting for a restore request; that is the only
    /// request that moves it to `Opened`.
    Restored,
    /// The channel is open.
    Opened {
        /// Connection id used for the guest-to-host interrupt.
        connection_id: u32,
        /// Event flag in use for the incoming event, if any.
        redirected_event_flag: Option<u16>,
        /// The incoming event, if any.
        #[inspect(skip)]
        redirected_event: Option<Event>,
    },
    /// The offer was rescinded by the server.
    Revoked,
}
528
529impl std::fmt::Display for ChannelState {
530 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
531 let s = match self {
532 ChannelState::Opening { .. } => "Opening",
533 ChannelState::Offered => "Offered",
534 ChannelState::Opened { .. } => "Opened",
535 ChannelState::Restored => "Restored",
536 ChannelState::Revoked => "Revoked",
537 };
538 fmt.pad(s)
539 }
540}
541
/// Per-channel bookkeeping kept by the client task.
#[derive(Debug, Inspect)]
struct Channel {
    /// The offer this channel was created from.
    offer: protocol::OfferChannel,
    /// Used to notify the channel's owner when the offer is rescinded; taken
    /// (left as `None`) once the notification has been sent.
    #[inspect(skip)]
    revoke_send: Option<mesh::OneshotSender<()>>,
    /// Current state of the channel.
    state: ChannelState,
    /// The pending modify request, completed when the modify response
    /// arrives.
    #[inspect(with = "|x| x.is_some()")]
    modify_response_send: Option<Rpc<(), i32>>,
    /// GPADLs belonging to this channel, keyed by id.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    gpadls: HashMap<GpadlId, GpadlState>,
    /// Starts out `false`.
    /// NOTE(review): presumably set once the owner has dropped its request
    /// sender — the code that sets it is outside this section; confirm.
    is_client_released: bool,
}
555
556impl Channel {
557 fn pending_request(&self) -> Option<&'static str> {
558 if self.modify_response_send.is_some() {
559 return Some("modify");
560 }
561 self.gpadls.iter().find_map(|(_, gpadl)| match gpadl {
562 GpadlState::Offered(_) => Some("creating gpadl"),
563 GpadlState::Created => None,
564 GpadlState::TearingDown { .. } => Some("tearing down gpadl"),
565 })
566 }
567}
568
/// State owned by the spawned client task.
#[derive(Inspect)]
struct ClientTask {
    /// Outgoing messages, gpadl teardown tracking, channel request streams
    /// and event state.
    #[inspect(flatten)]
    inner: ClientTaskInner,
    /// All channels currently known to the client.
    channels: ChannelList,
    /// Connection state.
    state: ClientState,
    /// Tracks outstanding hvsock connect requests.
    hvsock_tracker: hvsock::HvsockRequestTracker,
    /// Starts out `false` when the task is built.
    running: bool,
    /// The pending modify-connection request, if any; only one may be
    /// outstanding at a time.
    #[inspect(with = "|x| x.is_some()")]
    modify_request: Option<Rpc<ModifyConnectionRequest, ConnectionState>>,
    /// Source of incoming messages.
    #[inspect(skip)]
    msg_source: Box<dyn VmbusMessageSource>,
    /// Receives lifecycle requests from `VmbusClient`.
    #[inspect(skip)]
    task_recv: mesh::Receiver<TaskRequest>,
    /// Receives connection-level requests from `VmbusClientAccess`.
    #[inspect(skip)]
    client_request_recv: mesh::Receiver<ClientRequest>,
}
586
587impl ClientTask {
588 fn handle_initiate_contact(
589 &mut self,
590 rpc: Rpc<ConnectRequest, Result<ConnectResult, ConnectError>>,
591 version: Version,
592 ) {
593 let ClientState::Disconnected = self.state else {
594 tracing::warn!(client_state = %self.state, "invalid client state for InitiateContact");
595 rpc.complete(Err(ConnectError::InvalidState));
596 return;
597 };
598 let feature_flags = if version >= Version::Copper {
599 SUPPORTED_FEATURE_FLAGS
600 } else {
601 FeatureFlags::new()
602 };
603
604 let request = rpc.input();
605
606 tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting");
607 let target_info = protocol::TargetInfo::new()
608 .with_sint(SINT)
609 .with_vtl(VTL)
610 .with_feature_flags(feature_flags.into());
611 let monitor_page = request.monitor_page.unwrap_or_default();
612 let msg = protocol::InitiateContact2 {
613 initiate_contact: protocol::InitiateContact {
614 version_requested: version as u32,
615 target_message_vp: request.target_message_vp,
616 interrupt_page_or_target_info: target_info.into(),
617 parent_to_child_monitor_page_gpa: monitor_page.parent_to_child,
618 child_to_parent_monitor_page_gpa: monitor_page.child_to_parent,
619 },
620 client_id: request.client_id,
621 };
622
623 self.state = ClientState::Connecting { version, rpc };
624 if version < Version::Copper {
625 self.inner.messages.send(&msg.initiate_contact)
626 } else {
627 self.inner.messages.send(&msg);
628 }
629 }
630
631 fn handle_unload(&mut self, rpc: Rpc<(), ()>) {
632 tracing::debug!(%self.state, "VmBus client disconnecting");
633 self.state = ClientState::Disconnecting {
634 version: self.state.get_version().expect("invalid state for unload"),
635 rpc,
636 };
637
638 self.inner.messages.send(&protocol::Unload {});
639 }
640
641 fn handle_modify(&mut self, request: Rpc<ModifyConnectionRequest, ConnectionState>) {
642 if !matches!(self.state, ClientState::Connected { version, .. }
643 if version.feature_flags.modify_connection())
644 {
645 tracing::warn!("ModifyConnection not supported");
646 request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
647 return;
648 }
649
650 if self.modify_request.is_some() {
651 tracing::warn!("Duplicate ModifyConnection request");
652 request.complete(ConnectionState::FAILED_UNKNOWN_FAILURE);
653 return;
654 }
655
656 let message = protocol::ModifyConnection::from(*request.input());
657 self.modify_request = Some(request);
658 self.inner.messages.send(&message);
659 }
660
661 fn handle_tl_connect(&mut self, rpc: Rpc<HvsockConnectRequest, Option<OfferInfo>>) {
662 let message = protocol::TlConnectRequest2::from(*rpc.input());
666 self.hvsock_tracker.add_request(rpc);
667 self.inner.messages.send(&message);
668 }
669
670 fn handle_client_request(&mut self, request: ClientRequest) {
671 match request {
672 ClientRequest::Connect(rpc) => {
673 self.handle_initiate_contact(rpc, *SUPPORTED_VERSIONS.last().unwrap());
674 }
675 ClientRequest::Unload(rpc) => {
676 self.handle_unload(rpc);
677 }
678 ClientRequest::Modify(request) => self.handle_modify(request),
679 ClientRequest::HvsockConnect(request) => self.handle_tl_connect(request),
680 }
681 }
682
683 fn handle_version_response(&mut self, msg: protocol::VersionResponse2) {
684 let old_state = std::mem::replace(&mut self.state, ClientState::Disconnected);
685 let ClientState::Connecting { version, rpc } = old_state else {
686 self.state = old_state;
687 tracing::warn!(
688 client_state = %self.state,
689 "invalid client state to handle VersionResponse"
690 );
691 return;
692 };
693 if msg.version_response.version_supported > 0 {
694 if msg.version_response.connection_state != ConnectionState::SUCCESSFUL {
695 rpc.complete(Err(ConnectError::FailedToConnect(
696 msg.version_response.connection_state,
697 )));
698 return;
699 }
700
701 let feature_flags = if version >= Version::Copper {
702 FeatureFlags::from(msg.supported_features)
703 } else {
704 FeatureFlags::new()
705 };
706
707 let version = VersionInfo {
708 version,
709 feature_flags,
710 };
711
712 self.inner.messages.send(&protocol::RequestOffers {});
713 self.state = ClientState::RequestingOffers {
714 version,
715 rpc: rpc.split().1,
716 offers: Vec::new(),
717 };
718 tracing::info!(?version, "VmBus client connected, requesting offers");
719 } else {
720 let index = SUPPORTED_VERSIONS
721 .iter()
722 .position(|v| *v == version)
723 .unwrap();
724
725 if index == 0 {
726 rpc.complete(Err(ConnectError::NoSupportedVersions));
727 return;
728 }
729 let next_version = SUPPORTED_VERSIONS[index - 1];
730 tracing::debug!(
731 version = version as u32,
732 next_version = next_version as u32,
733 "Unsupported version, retrying"
734 );
735 self.handle_initiate_contact(rpc, next_version);
736 }
737 }
738
739 fn create_channel(&mut self, offer: protocol::OfferChannel) -> Result<OfferInfo> {
740 self.create_channel_core(offer, ChannelState::Offered)
741 }
742
743 fn create_channel_core(
744 &mut self,
745 offer: protocol::OfferChannel,
746 state: ChannelState,
747 ) -> Result<OfferInfo> {
748 if self.channels.0.contains_key(&offer.channel_id) {
749 anyhow::bail!("channel {:?} exists", offer.channel_id);
750 }
751 let (request_send, request_recv) = mesh::channel();
752 let (revoke_send, revoke_recv) = mesh::oneshot();
753
754 self.channels.0.insert(
755 offer.channel_id,
756 Channel {
757 revoke_send: Some(revoke_send),
758 offer,
759 state,
760 modify_response_send: None,
761 gpadls: HashMap::new(),
762 is_client_released: false,
763 },
764 );
765
766 self.inner
767 .channel_requests
768 .push(TaggedStream::new(offer.channel_id, request_recv));
769
770 Ok(OfferInfo {
771 offer,
772 revoke_recv,
773 request_send,
774 })
775 }
776
777 fn handle_offer(&mut self, offer: protocol::OfferChannel) {
778 let offer_info = self
779 .create_channel(offer)
780 .expect("channel should not exist");
781
782 tracing::info!(
783 state = %self.state,
784 channel_id = offer.channel_id.0,
785 interface_id = %offer.interface_id,
786 instance_id = %offer.instance_id,
787 subchannel_index = offer.subchannel_index,
788 "received offer");
789
790 if let Some(offer) = self.hvsock_tracker.check_offer(&offer_info.offer) {
791 offer.complete(Some(offer_info));
792 } else {
793 match &mut self.state {
794 ClientState::Connected { offer_send, .. } => {
795 offer_send.send(offer_info);
796 }
797 ClientState::RequestingOffers { offers, .. } => {
798 offers.push(offer_info);
799 }
800 state => unreachable!("invalid client state for OfferChannel: {state}"),
801 }
802 }
803 }
804
805 fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) -> TriedRelease {
806 tracing::info!(state = %self.state, channel_id = rescind.channel_id.0, "received rescind");
807
808 let mut channel = self.channels.get_mut(rescind.channel_id);
809 let event_flag = match std::mem::replace(&mut channel.state, ChannelState::Revoked) {
810 ChannelState::Offered => None,
811 ChannelState::Opening {
812 connection_id: _,
813 redirected_event_flag,
814 redirected_event: _,
815 rpc,
816 } => {
817 rpc.fail(anyhow::anyhow!("channel revoked"));
818 redirected_event_flag
819 }
820 ChannelState::Restored => None,
821 ChannelState::Opened {
822 connection_id: _,
823 redirected_event_flag,
824 redirected_event: _,
825 } => redirected_event_flag,
826 ChannelState::Revoked => {
827 panic!("channel id {:?} already revoked", rescind.channel_id);
828 }
829 };
830 if let Some(event_flag) = event_flag {
831 self.inner.synic.free_event_flag(event_flag);
832 }
833
834 channel.revoke_send.take().unwrap().send(());
836
837 channel.try_release(&mut self.inner.messages)
838 }
839
840 fn handle_offers_delivered(&mut self) {
841 match std::mem::replace(&mut self.state, ClientState::Disconnected) {
842 ClientState::RequestingOffers {
843 version,
844 rpc,
845 offers,
846 } => {
847 tracing::info!(version = ?version, "VmBus client connected, offers delivered");
848 let (offer_send, offer_recv) = mesh::channel();
849 self.state = ClientState::Connected {
850 version,
851 offer_send,
852 };
853 rpc.complete(Ok(ConnectResult {
854 version,
855 offers,
856 offer_recv,
857 }));
858 }
859 state => {
860 tracing::warn!(client_state = %state, "invalid client state for OffersDelivered");
861 self.state = state;
862 }
863 }
864 }
865
866 fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) -> TriedRelease {
867 let mut channel = self.channels.get_mut(request.channel_id);
868 let Some(gpadl_state) = channel.gpadls.get_mut(&request.gpadl_id) else {
869 panic!("GpadlCreated for unknown gpadl {:#x}", request.gpadl_id.0);
870 };
871
872 let rpc = match std::mem::replace(gpadl_state, GpadlState::Created) {
873 GpadlState::Offered(rpc) => rpc,
874 old_state => {
875 panic!(
876 "invalid state {old_state:?} for gpadl {:#x}:{:#x}",
877 request.channel_id.0, request.gpadl_id.0
878 );
879 }
880 };
881
882 let gpadl_created = request.status == protocol::STATUS_SUCCESS;
883 if gpadl_created {
884 rpc.complete(Ok(()));
885 } else {
886 channel.gpadls.remove(&request.gpadl_id).unwrap();
887 rpc.fail(anyhow::anyhow!(
888 "gpadl creation failed: {:#x}",
889 request.status
890 ));
891 };
892 channel.try_release(&mut self.inner.messages)
893 }
894
895 fn handle_open_result(&mut self, result: protocol::OpenResult) {
896 tracing::debug!(
897 channel_id = result.channel_id.0,
898 result = result.status,
899 "received open result"
900 );
901
902 let mut channel = self.channels.get_mut(result.channel_id);
903
904 let channel_opened = result.status == protocol::STATUS_SUCCESS as u32;
905 let old_state = std::mem::replace(&mut channel.state, ChannelState::Offered);
906 let ChannelState::Opening {
907 connection_id,
908 redirected_event_flag,
909 redirected_event,
910 rpc,
911 } = old_state
912 else {
913 tracing::warn!(
914 old_state = ?channel.state,
915 channel_opened,
916 "invalid state for open result"
917 );
918 channel.state = old_state;
919 return;
920 };
921
922 if !channel_opened {
923 if let Some(event_flag) = redirected_event_flag {
924 self.inner.synic.free_event_flag(event_flag);
925 }
926 rpc.fail(anyhow::anyhow!("open failed: {:#x}", result.status));
927 return;
928 }
929
930 channel.state = ChannelState::Opened {
931 connection_id,
932 redirected_event_flag,
933 redirected_event,
934 };
935
936 rpc.complete(Ok(OpenOutput {
937 redirected_event_flag,
938 guest_to_host_signal: self.inner.synic.guest_to_host_interrupt(connection_id),
939 }));
940 }
941
942 fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) -> TriedRelease {
943 let Some(channel_id) = self.inner.teardown_gpadls.remove(&request.gpadl_id) else {
944 panic!("gpadl {:#x} not in teardown list", request.gpadl_id.0);
945 };
946
947 tracing::debug!(
948 gpadl_id = request.gpadl_id.0,
949 channel_id = channel_id.0,
950 "Received GpadlTorndown"
951 );
952
953 let mut channel = self.channels.get_mut(channel_id);
954 let gpadl_state = channel
955 .gpadls
956 .remove(&request.gpadl_id)
957 .expect("gpadl validated above");
958
959 let GpadlState::TearingDown { rpcs } = gpadl_state else {
960 panic!("gpadl should be tearing down if in teardown list, state = {gpadl_state:?}");
961 };
962
963 for rpc in rpcs {
964 rpc.complete(());
965 }
966 channel.try_release(&mut self.inner.messages)
967 }
968
969 fn handle_unload_complete(&mut self) {
970 match std::mem::replace(&mut self.state, ClientState::Disconnected) {
971 ClientState::Disconnecting { version: _, rpc } => {
972 tracing::info!("VmBus client disconnected");
973 rpc.complete(());
974 }
975 state => {
976 tracing::warn!(client_state = %state, "invalid client state for UnloadComplete");
977 }
978 }
979 }
980
981 fn handle_modify_complete(&mut self, response: protocol::ModifyConnectionResponse) {
982 if let Some(request) = self.modify_request.take() {
983 request.complete(response.connection_state)
984 } else {
985 tracing::warn!("Unexpected modify complete request");
986 }
987 }
988
989 fn handle_modify_channel_response(
990 &mut self,
991 response: protocol::ModifyChannelResponse,
992 ) -> TriedRelease {
993 let mut channel = self.channels.get_mut(response.channel_id);
994 let Some(sender) = channel.modify_response_send.take() else {
995 panic!(
996 "unexpected modify channel response for channel {:#x}",
997 response.channel_id.0
998 );
999 };
1000
1001 sender.complete(response.status);
1002 channel.try_release(&mut self.inner.messages)
1003 }
1004
1005 fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) {
1006 if let Some(rpc) = self.hvsock_tracker.check_result(&response) {
1007 rpc.complete(None);
1008 }
1009 }
1010
1011 fn handle_synic_message(&mut self, data: &[u8]) -> bool {
1013 let msg = Message::parse(data, self.state.get_version()).unwrap();
1014 tracing::trace!(?msg, "received client message from synic");
1015
1016 match msg {
1017 Message::VersionResponse2(version_response, ..) => {
1018 self.handle_version_response(version_response);
1019 }
1020 Message::VersionResponse(version_response, ..) => {
1021 self.handle_version_response(version_response.into());
1022 }
1023 Message::OfferChannel(offer, ..) => {
1024 self.handle_offer(offer);
1025 }
1026 Message::AllOffersDelivered(..) => {
1027 self.handle_offers_delivered();
1028 }
1029 Message::UnloadComplete(..) => {
1030 self.handle_unload_complete();
1031 }
1032 Message::ModifyConnectionResponse(response, ..) => {
1033 self.handle_modify_complete(response);
1034 }
1035 Message::GpadlCreated(gpadl, ..) => {
1036 self.handle_gpadl_created(gpadl);
1037 }
1038 Message::OpenResult(result, ..) => {
1039 self.handle_open_result(result);
1040 }
1041 Message::GpadlTorndown(gpadl, ..) => {
1042 self.handle_gpadl_torndown(gpadl);
1043 }
1044 Message::RescindChannelOffer(rescind, ..) => {
1045 self.handle_rescind(rescind);
1046 }
1047 Message::ModifyChannelResponse(response, ..) => {
1048 self.handle_modify_channel_response(response);
1049 }
1050 Message::TlConnectResult(response, ..) => self.handle_tl_connect_result(response),
1051 Message::CloseReservedChannelResponse(..) => {
1053 todo!("Unsupported message {msg:?}")
1054 }
1055 Message::PauseResponse(..) => {
1056 return false;
1057 }
1058 Message::RequestOffers(..)
1060 | Message::OpenChannel2(..)
1061 | Message::OpenChannel(..)
1062 | Message::CloseChannel(..)
1063 | Message::GpadlHeader(..)
1064 | Message::GpadlBody(..)
1065 | Message::GpadlTeardown(..)
1066 | Message::RelIdReleased(..)
1067 | Message::InitiateContact(..)
1068 | Message::InitiateContact2(..)
1069 | Message::Unload(..)
1070 | Message::OpenReservedChannel(..)
1071 | Message::CloseReservedChannel(..)
1072 | Message::TlConnectRequest2(..)
1073 | Message::TlConnectRequest(..)
1074 | Message::ModifyChannel(..)
1075 | Message::ModifyConnection(..)
1076 | Message::Pause(..)
1077 | Message::Resume(..) => {
1078 unreachable!("Client received server message {msg:?}");
1079 }
1080 }
1081 true
1082 }
1083
1084 fn handle_open_channel(
1085 &mut self,
1086 channel_id: ChannelId,
1087 rpc: FailableRpc<OpenRequest, OpenOutput>,
1088 ) {
1089 let mut channel = self.channels.get_mut(channel_id);
1090 match &channel.state {
1091 ChannelState::Offered => {}
1092 ChannelState::Revoked => {
1093 rpc.fail(anyhow::anyhow!("channel revoked"));
1094 return;
1095 }
1096 state => {
1097 rpc.fail(anyhow::anyhow!("invalid channel state: {}", state));
1098 return;
1099 }
1100 }
1101
1102 tracing::info!(channel_id = channel_id.0, "opening channel on host");
1103 let (request, rpc) = rpc.split();
1104 let open_data = &request.open_data;
1105
1106 let supports_interrupt_redirection =
1107 if let ClientState::Connected { version, .. } = self.state {
1108 version.feature_flags.guest_specified_signal_parameters()
1109 || version.feature_flags.channel_interrupt_redirection()
1110 } else {
1111 false
1112 };
1113
1114 if !supports_interrupt_redirection && open_data.event_flag != channel_id.0 as u16 {
1115 rpc.fail(anyhow::anyhow!(
1116 "host does not support specifying the event flag"
1117 ));
1118 return;
1119 }
1120
1121 let open_channel = protocol::OpenChannel {
1122 channel_id,
1123 open_id: 0,
1124 ring_buffer_gpadl_id: open_data.ring_gpadl_id,
1125 target_vp: open_data.target_vp,
1126 downstream_ring_buffer_page_offset: open_data.ring_offset,
1127 user_data: open_data.user_data,
1128 };
1129
1130 let connection_id = if request.use_vtl2_connection_id {
1131 if !supports_interrupt_redirection {
1132 rpc.fail(anyhow::anyhow!(
1133 "host does not support specfiying the connection ID"
1134 ));
1135 return;
1136 }
1137 protocol::ConnectionId::new(channel_id.0, 2.try_into().unwrap(), 7).0
1138 } else {
1139 open_data.connection_id
1140 };
1141
1142 let mut flags = OpenChannelFlags::new();
1145 let event_flag = if let Some(event) = &request.incoming_event {
1146 if !supports_interrupt_redirection {
1147 rpc.fail(anyhow::anyhow!(
1148 "host does not support redirecting interrupts"
1149 ));
1150 return;
1151 }
1152
1153 flags.set_redirect_interrupt(true);
1154 match self.inner.synic.allocate_event_flag(event) {
1155 Ok(flag) => flag,
1156 Err(err) => {
1157 rpc.fail(err.context("failed to allocate event flag"));
1158 return;
1159 }
1160 }
1161 } else {
1162 open_data.event_flag
1163 };
1164
1165 if supports_interrupt_redirection {
1166 self.inner.messages.send(&protocol::OpenChannel2 {
1167 open_channel,
1168 connection_id,
1169 event_flag,
1170 flags,
1171 });
1172 } else {
1173 self.inner.messages.send(&open_channel);
1174 }
1175
1176 channel.state = ChannelState::Opening {
1177 connection_id,
1178 redirected_event_flag: (request.incoming_event.is_some()).then_some(event_flag),
1179 redirected_event: request.incoming_event,
1180 rpc,
1181 }
1182 }
1183
1184 fn handle_restore_channel(
1185 &mut self,
1186 channel_id: ChannelId,
1187 request: RestoreRequest,
1188 ) -> Result<OpenOutput> {
1189 let mut channel = self.channels.get_mut(channel_id);
1190 if !matches!(channel.state, ChannelState::Restored) {
1191 anyhow::bail!("invalid channel state: {}", channel.state);
1192 }
1193
1194 if request.incoming_event.is_some() != request.redirected_event_flag.is_some() {
1195 anyhow::bail!("incoming event and redirected event flag must both be set or unset");
1196 }
1197
1198 if let Some((flag, event)) = request
1199 .redirected_event_flag
1200 .zip(request.incoming_event.as_ref())
1201 {
1202 self.inner.synic.restore_event_flag(flag, event)?;
1203 }
1204
1205 channel.state = ChannelState::Opened {
1206 connection_id: request.connection_id,
1207 redirected_event_flag: request.redirected_event_flag,
1208 redirected_event: request.incoming_event,
1209 };
1210 Ok(OpenOutput {
1211 redirected_event_flag: request.redirected_event_flag,
1212 guest_to_host_signal: self
1213 .inner
1214 .synic
1215 .guest_to_host_interrupt(request.connection_id),
1216 })
1217 }
1218
1219 fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: FailableRpc<GpadlRequest, ()>) {
1220 let (request, rpc) = rpc.split();
1221 let mut channel = self.channels.get_mut(channel_id);
1222 if channel
1223 .gpadls
1224 .insert(request.id, GpadlState::Offered(rpc))
1225 .is_some()
1226 {
1227 panic!(
1228 "duplicate gpadl ID {:?} for channel {:?}.",
1229 request.id, channel_id
1230 );
1231 }
1232
1233 tracing::trace!(
1234 channel_id = channel_id.0,
1235 gpadl_id = request.id.0,
1236 count = request.count,
1237 len = request.buf.len(),
1238 "received gpadl request"
1239 );
1240
1241 let (first, remaining) = if request.buf.len() > protocol::GpadlHeader::MAX_DATA_VALUES {
1243 request.buf.split_at(protocol::GpadlHeader::MAX_DATA_VALUES)
1244 } else {
1245 (request.buf.as_slice(), [].as_slice())
1246 };
1247
1248 let message = protocol::GpadlHeader {
1249 channel_id,
1250 gpadl_id: request.id,
1251 len: (request.buf.len() * size_of::<u64>())
1252 .try_into()
1253 .expect("Too many GPA values"),
1254 count: request.count,
1255 };
1256
1257 self.inner
1258 .messages
1259 .send_with_data(&message, first.as_bytes());
1260
1261 let message = protocol::GpadlBody {
1263 rsvd: 0,
1264 gpadl_id: request.id,
1265 };
1266 for chunk in remaining.chunks(protocol::GpadlBody::MAX_DATA_VALUES) {
1267 self.inner
1268 .messages
1269 .send_with_data(&message, chunk.as_bytes());
1270 }
1271 }
1272
1273 fn handle_gpadl_teardown(&mut self, channel_id: ChannelId, rpc: Rpc<GpadlId, ()>) {
1274 let (gpadl_id, rpc) = rpc.split();
1275 let mut channel = self.channels.get_mut(channel_id);
1276 let Some(gpadl_state) = channel.gpadls.get_mut(&gpadl_id) else {
1277 tracing::warn!(
1278 gpadl_id = gpadl_id.0,
1279 channel_id = channel_id.0,
1280 "Gpadl teardown for unknown gpadl or revoked channel"
1281 );
1282 return;
1283 };
1284
1285 match gpadl_state {
1286 GpadlState::Offered(_) => {
1287 tracing::warn!(
1288 gpadl_id = gpadl_id.0,
1289 channel_id = channel_id.0,
1290 "gpadl teardown for offered gpadl"
1291 );
1292 }
1293 GpadlState::Created => {
1294 *gpadl_state = GpadlState::TearingDown { rpcs: vec![rpc] };
1295 assert!(
1299 self.inner
1300 .teardown_gpadls
1301 .insert(gpadl_id, channel_id)
1302 .is_none(),
1303 "Gpadl state validated above"
1304 );
1305
1306 self.inner.messages.send(&protocol::GpadlTeardown {
1307 channel_id,
1308 gpadl_id,
1309 });
1310 }
1311 GpadlState::TearingDown { rpcs } => {
1312 rpcs.push(rpc);
1313 }
1314 }
1315 }
1316
1317 fn handle_close_channel(&mut self, channel_id: ChannelId) {
1318 let mut channel = self.channels.get_mut(channel_id);
1319 self.inner.close_channel(channel_id, &mut channel);
1320 }
1321
1322 fn handle_modify_channel(&mut self, channel_id: ChannelId, rpc: Rpc<ModifyRequest, i32>) {
1323 assert!(self.check_version(Version::Iron));
1327 let mut channel = self.channels.get_mut(channel_id);
1328 if channel.modify_response_send.is_some() {
1329 panic!("duplicate channel modify request {channel_id:?}");
1330 }
1331
1332 let (request, response) = rpc.split();
1333 channel.modify_response_send = Some(response);
1334 let payload = match request {
1335 ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel {
1336 channel_id,
1337 target_vp,
1338 },
1339 };
1340
1341 self.inner.messages.send(&payload);
1342 }
1343
1344 fn handle_channel_request(&mut self, channel_id: ChannelId, request: ChannelRequest) {
1345 match request {
1346 ChannelRequest::Open(rpc) => self.handle_open_channel(channel_id, rpc),
1347 ChannelRequest::Restore(rpc) => {
1348 rpc.handle_failable_sync(|request| self.handle_restore_channel(channel_id, request))
1349 }
1350 ChannelRequest::Gpadl(req) => self.handle_gpadl(channel_id, req),
1351 ChannelRequest::TeardownGpadl(req) => self.handle_gpadl_teardown(channel_id, req),
1352 ChannelRequest::Close(req) => {
1353 req.handle_sync(|()| self.handle_close_channel(channel_id))
1354 }
1355 ChannelRequest::Modify(req) => self.handle_modify_channel(channel_id, req),
1356 }
1357 }
1358
1359 async fn handle_task(&mut self, task: TaskRequest) {
1360 match task {
1361 TaskRequest::Inspect(deferred) => {
1362 deferred.inspect(&*self);
1363 }
1364 TaskRequest::Save(rpc) => rpc.handle_sync(|()| self.handle_save()),
1365 TaskRequest::Restore(rpc) => {
1366 rpc.handle_sync(|saved_state| self.handle_restore(saved_state))
1367 }
1368 TaskRequest::PostRestore(rpc) => rpc.handle_sync(|()| self.handle_post_restore()),
1369 TaskRequest::Start => self.handle_start(),
1370 TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
1371 }
1372 }
1373
1374 fn handle_device_removal(&mut self, channel_id: ChannelId) -> TriedRelease {
1376 let mut channel = self.channels.get_mut(channel_id);
1377 channel.is_client_released = true;
1378 if let ChannelState::Opened { .. } = channel.state {
1380 tracing::warn!(
1381 channel_id = channel_id.0,
1382 "Channel dropped without closing first"
1383 );
1384 self.inner.close_channel(channel_id, &mut channel);
1385 }
1386 channel.try_release(&mut self.inner.messages)
1387 }
1388
1389 fn check_version(&self, version: Version) -> bool {
1391 matches!(self.state, ClientState::Connected { version: v, .. } if v.version >= version)
1392 }
1393
    /// Starts (or restarts) the client task.
    ///
    /// Resumes the incoming message stream and the outgoing message queue,
    /// then marks the task as running.
    ///
    /// # Panics
    ///
    /// Panics if the task is already running, or if the outgoing message
    /// queue is not in the `Paused` state.
    fn handle_start(&mut self) {
        assert!(!self.running);
        self.msg_source.resume_message_stream();
        self.inner.messages.resume();
        self.running = true;
    }
1400
    /// Stops the client task, draining incoming messages first.
    ///
    /// The loop repeats until the message stream has been paused while no
    /// revoked channel still has a request pending.
    ///
    /// # Panics
    ///
    /// Panics if the task is not running, or if the message stream ends while
    /// a revoked channel is still waiting for a response.
    async fn handle_stop(&mut self) {
        assert!(self.running);

        loop {
            // Keep processing host messages until no revoked channel is
            // waiting for a response.
            while let Some((id, request)) = self.channels.revoked_channel_with_pending_request() {
                tracelimit::info_ratelimited!(
                    channel_id = id.0,
                    request,
                    "waiting for responses for channel"
                );
                assert!(self.process_next_message().await);
            }

            if self.can_pause_resume() {
                // The host supports pause/resume: arrange for a pause message
                // to be sent on the next flush.
                self.inner.messages.pause();
            } else {
                // No protocol-level pause: pause the message source itself
                // and stop sending without notifying the host.
                self.msg_source.pause_message_stream();
                self.inner.messages.force_pause();
            }

            // Drain until `process_next_message` reports there is nothing
            // more to process.
            while self.process_next_message().await {}

            // Messages handled while draining may have left a revoked channel
            // with a pending request; if so, resume and go around again.
            if self
                .channels
                .revoked_channel_with_pending_request()
                .is_none()
            {
                break;
            }
            if !self.can_pause_resume() {
                self.msg_source.resume_message_stream();
            }
            self.inner.messages.resume();
        }

        tracing::debug!("messages drained");
        self.running = false;
    }
1450
    /// Receives and handles a single message from the host while also
    /// flushing queued outgoing messages.
    ///
    /// Returns `false` if the message source reports a zero-length read;
    /// otherwise returns whatever `handle_synic_message` returns for the
    /// received message.
    ///
    /// # Panics
    ///
    /// Panics if reading from the message source fails.
    async fn process_next_message(&mut self) -> bool {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        let recv = self.msg_source.recv(&mut buf);
        // The flush future never resolves: once the queue is flushed it waits
        // forever, so the race below can only be won by `recv`.
        let flush = async {
            self.inner.messages.flush_messages().await;
            std::future::pending().await
        };
        let size = (recv, flush)
            .race()
            .await
            .expect("Fatal error reading messages from synic");
        if size == 0 {
            return false;
        }
        self.handle_synic_message(&buf[..size])
    }
1469
1470 fn can_pause_resume(&self) -> bool {
1479 if let ClientState::Connected { version, .. } = self.state {
1480 version.feature_flags.pause_resume()
1481 } else {
1482 false
1483 }
1484 }
1485
    /// Main loop of the client task.
    ///
    /// Multiplexes task requests, client requests, per-channel requests,
    /// incoming host messages, and flushing of the outgoing message queue.
    /// The loop ends when the task request stream or the client request
    /// stream ends, or when `select!` reports all futures complete.
    ///
    /// # Panics
    ///
    /// Panics if reading a message from the synic fails or returns a
    /// zero-length message.
    async fn run(&mut self) {
        let mut buf = [0; protocol::MAX_MESSAGE_SIZE];
        loop {
            // Incoming host messages are only received while running.
            let mut message_recv =
                OptionFuture::from(self.running.then(|| self.msg_source.recv(&mut buf).fuse()));

            // A non-empty outgoing queue means earlier messages could not be
            // posted immediately. While that is the case, flush the queue and
            // stop accepting client and channel requests, which would add
            // more outgoing messages. Incoming messages are still handled.
            let host_backed_up = !self.inner.messages.is_empty();
            let mut flush_messages = OptionFuture::from(
                (self.running && host_backed_up)
                    .then(|| self.inner.messages.flush_messages().fuse()),
            );

            let mut client_request_recv = OptionFuture::from(
                (self.running && !host_backed_up).then(|| self.client_request_recv.next()),
            );

            let mut channel_requests = OptionFuture::from(
                (self.running && !host_backed_up)
                    .then(|| self.inner.channel_requests.select_next_some()),
            );

            futures::select! {
                _r = pin!(flush_messages) => {}
                r = self.task_recv.next() => {
                    if let Some(task) = r {
                        self.handle_task(task).await;
                    } else {
                        // The task request stream ended; shut down.
                        break;
                    }
                }
                r = client_request_recv => {
                    if let Some(Some(request)) = r {
                        self.handle_client_request(request);
                    } else {
                        break;
                    }
                }
                r = channel_requests => {
                    match r.unwrap() {
                        (id, Some(request)) => self.handle_channel_request(id, request),
                        // The channel's request stream ended without a
                        // request: treat it as removal of the device.
                        (id, _) => {
                            self.handle_device_removal(id);
                        }
                    }
                }
                r = message_recv => {
                    match r.unwrap() {
                        Ok(size) => {
                            if size == 0 {
                                panic!("Unexpected end of file reading messages from synic.");
                            }

                            self.handle_synic_message(&buf[..size]);
                        }
                        Err(err) => {
                            panic!("Error reading messages from synic: {err:?}");
                        }
                    }
                }
                complete => break,
            }
        }
    }
1558}
1559
1560impl ClientTaskInner {
1561 fn close_channel(&mut self, channel_id: ChannelId, channel: &mut Channel) {
1562 if let ChannelState::Opened {
1563 redirected_event_flag,
1564 ..
1565 } = channel.state
1566 {
1567 if let Some(flag) = redirected_event_flag {
1568 self.synic.free_event_flag(flag);
1569 }
1570 tracing::info!(channel_id = channel_id.0, "closing channel on host");
1571 self.messages.send(&protocol::CloseChannel { channel_id });
1572 channel.state = ChannelState::Offered;
1573 } else {
1574 tracing::warn!(
1575 id = %channel_id.0,
1576 channel_state = %channel.state,
1577 "invalid channel state for close channel"
1578 );
1579 }
1580 }
1581}
1582
/// Tracks the lifecycle of a single GPADL on a channel.
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum GpadlState {
    /// The GPADL has been sent to the host; the RPC is held so the request
    /// can be completed later.
    Offered(#[inspect(skip)] FailableRpc<(), ()>),
    /// The GPADL exists and may be torn down.
    Created,
    /// A teardown message has been sent to the host.
    TearingDown {
        /// Every teardown request waiting on this teardown.
        #[inspect(skip)]
        rpcs: Vec<Rpc<(), ()>>,
    },
}
1596
/// Queue of messages to send to the host, with pause/resume support.
#[derive(Inspect)]
struct OutgoingMessages {
    /// Used to post messages to the host.
    #[inspect(skip)]
    poster: Box<dyn PollPostMessage>,
    /// Messages that could not be posted immediately, in send order.
    #[inspect(with = "|x| x.len()")]
    queued: VecDeque<OutgoingMessage>,
    /// Whether messages may currently be sent.
    state: OutgoingMessageState,
}
1605
/// Send state of [`OutgoingMessages`].
#[derive(Inspect, PartialEq, Eq, Debug)]
enum OutgoingMessageState {
    /// Messages are posted immediately when possible, otherwise queued.
    Running,
    /// A pause was requested; the next flush sends the pause message and then
    /// moves to `Paused`.
    SendingPauseMessage,
    /// Nothing is posted; new messages are only queued.
    Paused,
}
1612
impl OutgoingMessages {
    /// Sends a message with no trailing data. See [`Self::send_with_data`].
    fn send<T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout>(
        &mut self,
        msg: &T,
    ) {
        self.send_with_data(msg, &[])
    }

    /// Sends a message followed by `data`.
    ///
    /// If nothing is queued and the state is `Running`, one attempt is made
    /// to post the message immediately. If that attempt does not complete, or
    /// the fast path does not apply, the message is appended to the queue so
    /// that ordering is preserved.
    fn send_with_data<
        T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
    >(
        &mut self,
        msg: &T,
        data: &[u8],
    ) {
        tracing::trace!(typ = ?T::MESSAGE_TYPE, "Sending message to host");
        let msg = OutgoingMessage::with_data(msg, data);
        if self.queued.is_empty() && self.state == OutgoingMessageState::Running {
            // Single non-blocking attempt: a no-op waker is used because a
            // pending result is handled by queueing rather than by waiting.
            let r = self.poster.poll_post_message(
                &mut Context::from_waker(std::task::Waker::noop()),
                protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                1,
                msg.data(),
            );
            if let Poll::Ready(()) = r {
                return;
            }
        }
        tracing::trace!("queueing message");
        self.queued.push_back(msg);
    }

    /// Flushes according to the current state.
    ///
    /// * `Running`: posts every queued message in order, waiting as needed.
    /// * `SendingPauseMessage`: posts a pause message and moves to `Paused`,
    ///   leaving the queue untouched.
    /// * `Paused`: does nothing.
    async fn flush_messages(&mut self) {
        let mut send = async |msg: &OutgoingMessage| {
            poll_fn(|cx| {
                self.poster.poll_post_message(
                    cx,
                    protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID,
                    1,
                    msg.data(),
                )
            })
            .await
        };
        match self.state {
            OutgoingMessageState::Running => {
                // A message is only removed once it has been posted, so it
                // stays queued if this future is dropped mid-send.
                while let Some(msg) = self.queued.front() {
                    send(msg).await;
                    tracing::trace!("sent queued message");
                    self.queued.pop_front();
                }
            }
            OutgoingMessageState::SendingPauseMessage => {
                send(&OutgoingMessage::new(&protocol::Pause)).await;
                tracing::trace!("sent pause message");
                self.state = OutgoingMessageState::Paused;
            }
            OutgoingMessageState::Paused => {}
        }
    }

    /// Requests a pause: the next flush sends the pause message.
    ///
    /// A resume message is placed at the front of the queue so that it is the
    /// first thing sent once the queue is running again.
    ///
    /// # Panics
    ///
    /// Panics if the state is not `Running`.
    fn pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::SendingPauseMessage;
        self.queued
            .push_front(OutgoingMessage::new(&protocol::Resume));
    }

    /// Moves straight to `Paused` without sending a pause message or queueing
    /// a resume message.
    ///
    /// # Panics
    ///
    /// Panics if the state is not `Running`.
    fn force_pause(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Running);
        self.state = OutgoingMessageState::Paused;
    }

    /// Moves from `Paused` back to `Running`. Queued messages are sent by a
    /// later flush.
    ///
    /// # Panics
    ///
    /// Panics if the state is not `Paused`.
    fn resume(&mut self) {
        assert_eq!(self.state, OutgoingMessageState::Paused);
        self.state = OutgoingMessageState::Running;
    }

    /// Returns `true` if no messages are queued.
    fn is_empty(&self) -> bool {
        self.queued.is_empty()
    }
}
1701
/// Client task state that is kept apart from the channel list so that both
/// can be borrowed mutably at the same time.
#[derive(Inspect)]
struct ClientTaskInner {
    /// Outgoing message queue to the host.
    messages: OutgoingMessages,
    /// GPADLs with a teardown message sent to the host, mapped to their
    /// owning channel.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")]
    teardown_gpadls: HashMap<GpadlId, ChannelId>,
    /// Request streams of all channels, each tagged with its channel ID.
    #[inspect(skip)]
    channel_requests: SelectAll<TaggedStream<ChannelId, mesh::Receiver<ChannelRequest>>>,
    /// Event flag bookkeeping and access to the synic event client.
    synic: SynicState,
}
1711
/// Tracks which event flags are in use and provides access to the synic event
/// client.
#[derive(Inspect)]
struct SynicState {
    /// Client used to map, unmap, and signal events.
    #[inspect(skip)]
    event_client: Arc<dyn SynicEventClient>,
    /// In-use markers for event flags: index `i` describes event flag `i + 1`.
    #[inspect(iter_by_index)]
    event_flag_state: Vec<bool>,
}
1719
/// The set of channels known to the client, keyed by channel ID.
#[derive(Inspect, Default)]
#[inspect(transparent)]
struct ChannelList(
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] HashMap<ChannelId, Channel>,
);
1725
/// A mutable reference to a channel in a [`ChannelList`].
///
/// It wraps the occupied map entry, so the channel can also be removed from
/// the list through it.
struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>);
1729
/// Marker value returned by `ChannelRef::try_release`, indicating that a
/// release of the channel was attempted.
struct TriedRelease(());
1734
1735impl ChannelRef<'_> {
1736 fn try_release(self, messages: &mut OutgoingMessages) -> TriedRelease {
1740 if self.is_client_released
1741 && matches!(self.state, ChannelState::Revoked)
1742 && self.pending_request().is_none()
1743 {
1744 let channel_id = *self.0.key();
1745 tracelimit::info_ratelimited!(channel_id = channel_id.0, "releasing channel");
1746 messages.send(&protocol::RelIdReleased { channel_id });
1747 self.0.remove();
1748 }
1749 TriedRelease(())
1750 }
1751}
1752
/// Gives shared access to the referenced [`Channel`].
impl Deref for ChannelRef<'_> {
    type Target = Channel;

    fn deref(&self) -> &Self::Target {
        self.0.get()
    }
}
1760
/// Gives mutable access to the referenced [`Channel`].
impl DerefMut for ChannelRef<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.get_mut()
    }
}
1766
1767impl ChannelList {
1768 fn revoked_channel_with_pending_request(&self) -> Option<(ChannelId, &'static str)> {
1769 self.0.iter().find_map(|(&id, channel)| {
1770 if !matches!(channel.state, ChannelState::Revoked) {
1771 return None;
1772 }
1773 Some((id, channel.pending_request()?))
1774 })
1775 }
1776
1777 #[track_caller]
1778 fn get_mut(&mut self, channel_id: ChannelId) -> ChannelRef<'_> {
1779 match self.0.entry(channel_id) {
1780 hash_map::Entry::Occupied(entry) => ChannelRef(entry),
1781 hash_map::Entry::Vacant(_) => {
1782 panic!("channel {:?} not found", channel_id);
1783 }
1784 }
1785 }
1786}
1787
1788impl SynicState {
1789 fn guest_to_host_interrupt(&self, connection_id: u32) -> Interrupt {
1790 Interrupt::from_fn({
1791 let event_client = self.event_client.clone();
1792 move || {
1793 if let Err(err) = event_client.signal_event(connection_id, 0) {
1794 tracelimit::warn_ratelimited!(
1795 error = &err as &dyn std::error::Error,
1796 "failed to signal event"
1797 );
1798 }
1799 }
1800 })
1801 }
1802
1803 const MAX_EVENT_FLAGS: u16 = 2047;
1804
1805 fn allocate_event_flag(&mut self, event: &Event) -> Result<u16> {
1806 let i = self
1807 .event_flag_state
1808 .iter()
1809 .position(|&used| !used)
1810 .ok_or(())
1811 .or_else(|()| {
1812 if self.event_flag_state.len() >= Self::MAX_EVENT_FLAGS as usize {
1813 anyhow::bail!("out of event flags");
1814 }
1815 self.event_flag_state.push(false);
1816 Ok(self.event_flag_state.len() - 1)
1817 })?;
1818
1819 let event_flag = (i + 1) as u16;
1820 self.event_client
1821 .map_event(event_flag, event)
1822 .context("failed to map event")?;
1823 self.event_flag_state[i] = true;
1824 Ok(event_flag)
1825 }
1826
1827 fn restore_event_flag(&mut self, flag: u16, event: &Event) -> Result<()> {
1828 let i = (flag as usize)
1829 .checked_sub(1)
1830 .context("invalid event flag")?;
1831 if i >= Self::MAX_EVENT_FLAGS as usize {
1832 anyhow::bail!("invalid event flag");
1833 }
1834 if self.event_flag_state.len() <= i {
1835 self.event_flag_state.resize(i + 1, false);
1836 }
1837 if self.event_flag_state[i] {
1838 anyhow::bail!("event flag already in use");
1839 }
1840 self.event_client
1841 .map_event(flag, event)
1842 .context("failed to map event")?;
1843 self.event_flag_state[i] = true;
1844 Ok(())
1845 }
1846
1847 fn free_event_flag(&mut self, flag: u16) {
1848 let i = flag as usize - 1;
1849 assert!(i < self.event_flag_state.len());
1850 self.event_flag_state[i] = false;
1851 }
1852}
1853
1854#[cfg(test)]
1855mod tests {
1856 use super::*;
1857 use futures_concurrency::future::Join;
1858 use guid::Guid;
1859 use pal_async::DefaultDriver;
1860 use pal_async::async_test;
1861 use pal_async::timer::PolledTimer;
1862 use protocol::TargetInfo;
1863 use std::fmt::Debug;
1864 use std::task::ready;
1865 use std::time::Duration;
1866 use test_with_tracing::test;
1867 use vmbus_core::protocol::MessageHeader;
1868 use vmbus_core::protocol::MessageType;
1869 use vmbus_core::protocol::OfferFlags;
1870 use vmbus_core::protocol::UserDefinedData;
1871 use vmbus_core::protocol::VmbusMessage;
1872 use zerocopy::FromBytes;
1873 use zerocopy::FromZeros;
1874 use zerocopy::Immutable;
1875 use zerocopy::IntoBytes;
1876 use zerocopy::KnownLayout;
1877
    /// Arbitrary client ID used to check that a caller-supplied client ID is
    /// passed through in the initiate-contact message.
    const VMBUS_TEST_CLIENT_ID: Guid = guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6");
1879
1880 fn in_msg<T: IntoBytes + Immutable + KnownLayout>(message_type: MessageType, t: T) -> Vec<u8> {
1881 let mut data = Vec::new();
1882 data.extend_from_slice(&message_type.0.to_ne_bytes());
1883 data.extend_from_slice(&0u32.to_ne_bytes());
1884 data.extend_from_slice(t.as_bytes());
1885 data
1886 }
1887
1888 #[track_caller]
1889 fn check_message<T>(msg: OutgoingMessage, chk: T)
1890 where
1891 T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1892 {
1893 check_message_with_data(msg, chk, &[]);
1894 }
1895
1896 #[track_caller]
1897 fn check_message_with_data<T>(msg: OutgoingMessage, chk: T, data: &[u8])
1898 where
1899 T: IntoBytes + FromBytes + Immutable + KnownLayout + Debug + VmbusMessage,
1900 {
1901 let chk_data = OutgoingMessage::with_data(&chk, data);
1902 if msg.data() != chk_data.data() {
1903 let (header, rest) = MessageHeader::read_from_prefix(msg.data()).unwrap();
1904 assert_eq!(header.message_type(), <T as VmbusMessage>::MESSAGE_TYPE);
1905 let (msg, rest) = T::read_from_prefix(rest).expect("incorrect message size");
1906 if msg.as_bytes() != chk.as_bytes() {
1907 panic!("mismatched messages, expected {:#?}, got {:#?}", chk, msg);
1908 }
1909 if rest != data {
1910 panic!("mismatched data, expected {:#?}, got {:#?}", data, rest);
1911 }
1912 }
1913 }
1914
    /// Test stand-in for the host side of the connection.
    struct TestServer {
        /// Messages posted by the client under test.
        messages: mesh::Receiver<OutgoingMessage>,
        /// Raw messages to deliver to the client under test.
        send: mesh::Sender<Vec<u8>>,
    }
1919
    impl TestServer {
        /// Waits for the next message sent by the client.
        async fn next(&mut self) -> Option<OutgoingMessage> {
            self.messages.next().await
        }

        /// Delivers a raw message to the client.
        fn send(&self, msg: Vec<u8>) {
            self.send.send(msg);
        }

        /// Connects the client without offering any channels.
        async fn connect(&mut self, client: &mut VmbusClient) -> ConnectResult {
            self.connect_with_channels(client, |_| {}).await
        }

        /// Drives a full connect handshake, playing the host side.
        ///
        /// `send_offers` is called after the client requests offers and before
        /// the all-offers-delivered message, so it can inject channel offers.
        /// The connection is asserted to be `Copper` with all supported
        /// feature flags.
        async fn connect_with_channels(
            &mut self,
            client: &mut VmbusClient,
            send_offers: impl FnOnce(&mut Self),
        ) -> ConnectResult {
            let client_connect = client.connect(0, None, Guid::ZERO);

            let server_connect = async {
                // The first message from the client is not checked here.
                let _ = self.next().await.unwrap();

                self.send(in_msg(
                    MessageType::VERSION_RESPONSE,
                    protocol::VersionResponse2 {
                        version_response: protocol::VersionResponse {
                            version_supported: 1,
                            connection_state: ConnectionState::SUCCESSFUL,
                            padding: 0,
                            selected_version_or_connection_id: 0,
                        },
                        supported_features: SUPPORTED_FEATURE_FLAGS.into(),
                    },
                ));

                check_message(self.next().await.unwrap(), protocol::RequestOffers {});

                send_offers(self);
                self.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
            };

            let (connection, ()) = (client_connect, server_connect).join().await;

            let connection = connection.unwrap();
            assert_eq!(connection.version.version, Version::Copper);
            assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
            connection
        }

        /// Connects the client with exactly one offered channel and returns
        /// that channel's offer.
        async fn get_channel(&mut self, client: &mut VmbusClient) -> OfferInfo {
            let [channel] = self
                .get_channels(client, 1)
                .await
                .offers
                .try_into()
                .unwrap();
            channel
        }

        /// Connects the client while offering `count` channels with random
        /// interface/instance IDs and channel IDs `0..count`.
        async fn get_channels(&mut self, client: &mut VmbusClient, count: usize) -> ConnectResult {
            self.connect_with_channels(client, |this| {
                for i in 0..count {
                    let offer = protocol::OfferChannel {
                        interface_id: Guid::new_random(),
                        instance_id: Guid::new_random(),
                        rsvd: [0; 4],
                        flags: OfferFlags::new(),
                        mmio_megabytes: 0,
                        user_defined: UserDefinedData::new_zeroed(),
                        subchannel_index: 0,
                        mmio_megabytes_optional: 0,
                        channel_id: ChannelId(i as u32),
                        monitor_id: 0,
                        monitor_allocated: 0,
                        is_dedicated: 0,
                        connection_id: 0,
                    };

                    this.send(in_msg(MessageType::OFFER_CHANNEL, offer));
                }
            })
            .await
        }

        /// Stops the client, expecting a pause message from it and answering
        /// with a pause response.
        async fn stop_client(&mut self, client: &mut VmbusClient) {
            let client_stop = client.stop();
            let server_stop = async {
                check_message(self.next().await.unwrap(), protocol::Pause);
                self.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
            };
            (client_stop, server_stop).join().await;
        }

        /// Starts the client and expects it to send a resume message.
        async fn start_client(&mut self, client: &mut VmbusClient) {
            client.start();
            check_message(self.next().await.unwrap(), protocol::Resume);
        }
    }
2019
    /// Test message poster that forwards posted messages to a [`TestServer`],
    /// randomly delaying some of them.
    struct TestServerClient {
        /// Receives each posted message.
        sender: mesh::Sender<OutgoingMessage>,
        /// Timer used to wait out an injected delay.
        timer: PolledTimer,
        /// When set, posting is blocked until this instant.
        deadline: Option<pal_async::timer::Instant>,
    }
2025
2026 impl PollPostMessage for TestServerClient {
2027 fn poll_post_message(
2028 &mut self,
2029 cx: &mut Context<'_>,
2030 _connection_id: u32,
2031 _typ: u32,
2032 msg: &[u8],
2033 ) -> Poll<()> {
2034 loop {
2035 if let Some(deadline) = self.deadline {
2036 ready!(self.timer.poll_until(cx, deadline));
2037 self.deadline = None;
2038 }
2039 let mut b = [0];
2044 getrandom::fill(&mut b).unwrap();
2045 if b[0] % 4 == 0 {
2046 self.deadline =
2047 Some(pal_async::timer::Instant::now() + Duration::from_millis(10));
2048 } else {
2049 let msg = OutgoingMessage::from_message(msg).unwrap();
2050 tracing::info!(
2051 msg = ?MessageHeader::read_from_prefix(msg.data()),
2052 "sending message"
2053 );
2054 self.sender.send(msg);
2055 break Poll::Ready(());
2056 }
2057 }
2058 }
2059 }
2060
    /// Synic event client for tests that performs no real event operations.
    struct NoopSynicEvents;
2062
    impl SynicEventClient for NoopSynicEvents {
        /// Accepts every mapping without doing anything.
        fn map_event(&self, _event_flag: u16, _event: &Event) -> std::io::Result<()> {
            Ok(())
        }

        /// Does nothing.
        fn unmap_event(&self, _event_flag: u16) {}

        /// Always fails with `Unsupported`.
        fn signal_event(&self, _connection_id: u32, _event_flag: u16) -> std::io::Result<()> {
            Err(std::io::ErrorKind::Unsupported.into())
        }
    }
2074
    /// Test message source that yields messages sent by the [`TestServer`].
    struct TestMessageSource {
        /// Raw messages from the test server.
        msg_recv: mesh::Receiver<Vec<u8>>,
        /// While set, an empty channel is reported as a zero-length read
        /// instead of pending.
        paused: bool,
    }
2079
2080 impl AsyncRecv for TestMessageSource {
2081 fn poll_recv(
2082 &mut self,
2083 cx: &mut Context<'_>,
2084 mut bufs: &mut [std::io::IoSliceMut<'_>],
2085 ) -> Poll<std::io::Result<usize>> {
2086 let value = match self.msg_recv.poll_recv(cx) {
2087 Poll::Ready(v) => v.unwrap(),
2088 Poll::Pending => {
2089 if self.paused {
2090 return Poll::Ready(Ok(0));
2091 } else {
2092 return Poll::Pending;
2093 }
2094 }
2095 };
2096 let mut remaining = value.as_slice();
2097 let mut total_size = 0;
2098 while !remaining.is_empty() && !bufs.is_empty() {
2099 let size = bufs[0].len().min(remaining.len());
2100 bufs[0][..size].copy_from_slice(&remaining[..size]);
2101 remaining = &remaining[size..];
2102 bufs = &mut bufs[1..];
2103 total_size += size;
2104 }
2105
2106 Ok(total_size).into()
2107 }
2108 }
2109
    impl VmbusMessageSource for TestMessageSource {
        /// Marks the source as paused, so an empty channel reads as
        /// zero-length.
        fn pause_message_stream(&mut self) {
            self.paused = true;
        }

        /// Clears the paused marker, so an empty channel reads as pending.
        fn resume_message_stream(&mut self) {
            self.paused = false;
        }
    }
2119
2120 fn test_init(driver: &DefaultDriver) -> (TestServer, VmbusClient) {
2121 let (msg_send, msg_recv) = mesh::channel();
2122 let (synic_send, synic_recv) = mesh::channel();
2123 let server = TestServer {
2124 messages: synic_recv,
2125 send: msg_send,
2126 };
2127 let mut client = VmbusClientBuilder::new(
2128 NoopSynicEvents,
2129 TestMessageSource {
2130 msg_recv,
2131 paused: false,
2132 },
2133 TestServerClient {
2134 sender: synic_send,
2135 deadline: None,
2136 timer: PolledTimer::new(driver),
2137 },
2138 )
2139 .build(driver);
2140 client.start();
2141 (server, client)
2142 }
2143
2144 #[async_test]
2145 async fn test_initiate_contact_success(driver: DefaultDriver) {
2146 let (mut server, client) = test_init(&driver);
2147 let _recv = client
2148 .access
2149 .client_request_send
2150 .call(ClientRequest::Connect, ConnectRequest::default());
2151 check_message(
2152 server.next().await.unwrap(),
2153 protocol::InitiateContact2 {
2154 initiate_contact: protocol::InitiateContact {
2155 version_requested: Version::Copper as u32,
2156 target_message_vp: 0,
2157 interrupt_page_or_target_info: TargetInfo::new()
2158 .with_sint(2)
2159 .with_vtl(0)
2160 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2161 .into(),
2162 parent_to_child_monitor_page_gpa: 0,
2163 child_to_parent_monitor_page_gpa: 0,
2164 },
2165 ..FromZeros::new_zeroed()
2166 },
2167 );
2168 }
2169
2170 #[async_test]
2171 async fn test_connect_success(driver: DefaultDriver) {
2172 let (mut server, mut client) = test_init(&driver);
2173 let client_connect = client.connect(0, None, Guid::ZERO);
2174
2175 let server_connect = async {
2176 check_message(
2177 server.next().await.unwrap(),
2178 protocol::InitiateContact2 {
2179 initiate_contact: protocol::InitiateContact {
2180 version_requested: Version::Copper as u32,
2181 target_message_vp: 0,
2182 interrupt_page_or_target_info: TargetInfo::new()
2183 .with_sint(2)
2184 .with_vtl(0)
2185 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2186 .into(),
2187 parent_to_child_monitor_page_gpa: 0,
2188 child_to_parent_monitor_page_gpa: 0,
2189 },
2190 ..FromZeros::new_zeroed()
2191 },
2192 );
2193
2194 server.send(in_msg(
2195 MessageType::VERSION_RESPONSE,
2196 protocol::VersionResponse2 {
2197 version_response: protocol::VersionResponse {
2198 version_supported: 1,
2199 connection_state: ConnectionState::SUCCESSFUL,
2200 padding: 0,
2201 selected_version_or_connection_id: 0,
2202 },
2203 supported_features: SUPPORTED_FEATURE_FLAGS.into_bits(),
2204 },
2205 ));
2206
2207 check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2208 server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2209 };
2210
2211 let (connection, ()) = (client_connect, server_connect).join().await;
2212 let connection = connection.unwrap();
2213
2214 assert_eq!(connection.version.version, Version::Copper);
2215 assert_eq!(connection.version.feature_flags, SUPPORTED_FEATURE_FLAGS);
2216 }
2217
2218 #[async_test]
2219 async fn test_feature_flags(driver: DefaultDriver) {
2220 let (mut server, mut client) = test_init(&driver);
2221 let client_connect = client.connect(0, None, Guid::ZERO);
2222
2223 let server_connect = async {
2224 check_message(
2225 server.next().await.unwrap(),
2226 protocol::InitiateContact2 {
2227 initiate_contact: protocol::InitiateContact {
2228 version_requested: Version::Copper as u32,
2229 target_message_vp: 0,
2230 interrupt_page_or_target_info: TargetInfo::new()
2231 .with_sint(2)
2232 .with_vtl(0)
2233 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2234 .into(),
2235 parent_to_child_monitor_page_gpa: 0,
2236 child_to_parent_monitor_page_gpa: 0,
2237 },
2238 ..FromZeros::new_zeroed()
2239 },
2240 );
2241
2242 server.send(in_msg(
2245 MessageType::VERSION_RESPONSE,
2246 protocol::VersionResponse2 {
2247 version_response: protocol::VersionResponse {
2248 version_supported: 1,
2249 connection_state: ConnectionState::SUCCESSFUL,
2250 padding: 0,
2251 selected_version_or_connection_id: 0,
2252 },
2253 supported_features: 2,
2254 },
2255 ));
2256
2257 check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2258 server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2259 };
2260
2261 let (connection, ()) = (client_connect, server_connect).join().await;
2262 let connection = connection.unwrap();
2263
2264 assert_eq!(connection.version.version, Version::Copper);
2265 assert_eq!(
2266 connection.version.feature_flags,
2267 FeatureFlags::new().with_channel_interrupt_redirection(true)
2268 );
2269 }
2270
2271 #[async_test]
2272 async fn test_client_id(driver: DefaultDriver) {
2273 let (mut server, client) = test_init(&driver);
2274 let initiate_contact = ConnectRequest {
2275 client_id: VMBUS_TEST_CLIENT_ID,
2276 ..Default::default()
2277 };
2278 let _recv = client
2279 .access
2280 .client_request_send
2281 .call(ClientRequest::Connect, initiate_contact);
2282
2283 check_message(
2284 server.next().await.unwrap(),
2285 protocol::InitiateContact2 {
2286 initiate_contact: protocol::InitiateContact {
2287 version_requested: Version::Copper as u32,
2288 target_message_vp: 0,
2289 interrupt_page_or_target_info: TargetInfo::new()
2290 .with_sint(2)
2291 .with_vtl(0)
2292 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2293 .into(),
2294 parent_to_child_monitor_page_gpa: 0,
2295 child_to_parent_monitor_page_gpa: 0,
2296 },
2297 client_id: VMBUS_TEST_CLIENT_ID,
2298 },
2299 );
2300 }
2301
2302 #[async_test]
2303 async fn test_version_negotiation(driver: DefaultDriver) {
2304 let (mut server, mut client) = test_init(&driver);
2305 let client_connect = client.connect(0, None, Guid::ZERO);
2306
2307 let server_connect = async {
2308 check_message(
2309 server.next().await.unwrap(),
2310 protocol::InitiateContact2 {
2311 initiate_contact: protocol::InitiateContact {
2312 version_requested: Version::Copper as u32,
2313 target_message_vp: 0,
2314 interrupt_page_or_target_info: TargetInfo::new()
2315 .with_sint(2)
2316 .with_vtl(0)
2317 .with_feature_flags(SUPPORTED_FEATURE_FLAGS.into())
2318 .into(),
2319 parent_to_child_monitor_page_gpa: 0,
2320 child_to_parent_monitor_page_gpa: 0,
2321 },
2322 ..FromZeros::new_zeroed()
2323 },
2324 );
2325
2326 server.send(in_msg(
2327 MessageType::VERSION_RESPONSE,
2328 protocol::VersionResponse {
2329 version_supported: 0,
2330 connection_state: ConnectionState::SUCCESSFUL,
2331 padding: 0,
2332 selected_version_or_connection_id: 0,
2333 },
2334 ));
2335
2336 check_message(
2337 server.next().await.unwrap(),
2338 protocol::InitiateContact {
2339 version_requested: Version::Iron as u32,
2340 target_message_vp: 0,
2341 interrupt_page_or_target_info: TargetInfo::new()
2342 .with_sint(2)
2343 .with_vtl(0)
2344 .with_feature_flags(FeatureFlags::new().into())
2345 .into(),
2346 parent_to_child_monitor_page_gpa: 0,
2347 child_to_parent_monitor_page_gpa: 0,
2348 },
2349 );
2350
2351 server.send(in_msg(
2352 MessageType::VERSION_RESPONSE,
2353 protocol::VersionResponse {
2354 version_supported: 1,
2355 connection_state: ConnectionState::SUCCESSFUL,
2356 padding: 0,
2357 selected_version_or_connection_id: 0,
2358 },
2359 ));
2360
2361 check_message(server.next().await.unwrap(), protocol::RequestOffers {});
2362 server.send(in_msg(MessageType::ALL_OFFERS_DELIVERED, [0x00]));
2363 };
2364
2365 let (connection, ()) = (client_connect, server_connect).join().await;
2366 let connection = connection.unwrap();
2367
2368 assert_eq!(connection.version.version, Version::Iron);
2369 assert_eq!(connection.version.feature_flags, FeatureFlags::new());
2370 }
2371
2372 #[async_test]
2373 async fn test_open_channel_success(driver: DefaultDriver) {
2374 let (mut server, mut client) = test_init(&driver);
2375 let channel = server.get_channel(&mut client).await;
2376
2377 let recv = channel.request_send.call(
2378 ChannelRequest::Open,
2379 OpenRequest {
2380 open_data: OpenData {
2381 target_vp: 0,
2382 ring_offset: 0,
2383 ring_gpadl_id: GpadlId(0),
2384 event_flag: 0,
2385 connection_id: 0,
2386 user_data: UserDefinedData::new_zeroed(),
2387 },
2388 incoming_event: None,
2389 use_vtl2_connection_id: false,
2390 },
2391 );
2392
2393 check_message(
2394 server.next().await.unwrap(),
2395 protocol::OpenChannel2 {
2396 open_channel: protocol::OpenChannel {
2397 channel_id: ChannelId(0),
2398 open_id: 0,
2399 ring_buffer_gpadl_id: GpadlId(0),
2400 target_vp: 0,
2401 downstream_ring_buffer_page_offset: 0,
2402 user_data: UserDefinedData::new_zeroed(),
2403 },
2404 connection_id: 0,
2405 event_flag: 0,
2406 flags: Default::default(),
2407 },
2408 );
2409
2410 server.send(in_msg(
2411 MessageType::OPEN_CHANNEL_RESULT,
2412 protocol::OpenResult {
2413 channel_id: ChannelId(0),
2414 open_id: 0,
2415 status: protocol::STATUS_SUCCESS as u32,
2416 },
2417 ));
2418
2419 recv.await.unwrap().unwrap();
2420 }
2421
2422 #[async_test]
2423 async fn test_open_channel_fail(driver: DefaultDriver) {
2424 let (mut server, mut client) = test_init(&driver);
2425 let channel = server.get_channel(&mut client).await;
2426
2427 let recv = channel.request_send.call(
2428 ChannelRequest::Open,
2429 OpenRequest {
2430 open_data: OpenData {
2431 target_vp: 0,
2432 ring_offset: 0,
2433 ring_gpadl_id: GpadlId(0),
2434 event_flag: 0,
2435 connection_id: 0,
2436 user_data: UserDefinedData::new_zeroed(),
2437 },
2438 incoming_event: None,
2439 use_vtl2_connection_id: false,
2440 },
2441 );
2442
2443 check_message(
2444 server.next().await.unwrap(),
2445 protocol::OpenChannel2 {
2446 open_channel: protocol::OpenChannel {
2447 channel_id: ChannelId(0),
2448 open_id: 0,
2449 ring_buffer_gpadl_id: GpadlId(0),
2450 target_vp: 0,
2451 downstream_ring_buffer_page_offset: 0,
2452 user_data: UserDefinedData::new_zeroed(),
2453 },
2454 connection_id: 0,
2455 event_flag: 0,
2456 flags: Default::default(),
2457 },
2458 );
2459
2460 server.send(in_msg(
2461 MessageType::OPEN_CHANNEL_RESULT,
2462 protocol::OpenResult {
2463 channel_id: ChannelId(0),
2464 open_id: 0,
2465 status: protocol::STATUS_UNSUCCESSFUL as u32,
2466 },
2467 ));
2468
2469 recv.await.unwrap().unwrap_err();
2470 }
2471
2472 #[async_test]
2473 async fn test_modify_channel(driver: DefaultDriver) {
2474 let (mut server, mut client) = test_init(&driver);
2475 let channel = server.get_channel(&mut client).await;
2476
2477 let recv = channel.request_send.call(
2480 ChannelRequest::Modify,
2481 ModifyRequest::TargetVp { target_vp: 1 },
2482 );
2483
2484 check_message(
2485 server.next().await.unwrap(),
2486 protocol::ModifyChannel {
2487 channel_id: ChannelId(0),
2488 target_vp: 1,
2489 },
2490 );
2491
2492 server.send(in_msg(
2493 MessageType::MODIFY_CHANNEL_RESPONSE,
2494 protocol::ModifyChannelResponse {
2495 channel_id: ChannelId(0),
2496 status: protocol::STATUS_SUCCESS,
2497 },
2498 ));
2499
2500 let status = recv.await.unwrap();
2501 assert_eq!(status, protocol::STATUS_SUCCESS);
2502 }
2503
2504 #[async_test]
2505 async fn test_save_restore_connected(driver: DefaultDriver) {
2506 let (mut server, mut client) = test_init(&driver);
2507 server.connect(&mut client).await;
2508 server.stop_client(&mut client).await;
2509 let s0 = client.save().await;
2510 let builder = client.sever().await;
2511 let mut client = builder.build(&driver);
2512 client.restore(s0.clone()).await.unwrap();
2513
2514 let s1 = client.save().await;
2515
2516 assert_eq!(s0, s1);
2517 }
2518
2519 #[async_test]
2520 async fn test_save_restore_connected_with_channel(driver: DefaultDriver) {
2521 let (mut server, mut client) = test_init(&driver);
2522 let c0 = server.get_channel(&mut client).await;
2523 server.stop_client(&mut client).await;
2524 let s0 = client.save().await;
2525 let builder = client.sever().await;
2526 let mut client = builder.build(&driver);
2527 let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2528 let s1 = client.save().await;
2529 assert_eq!(s0, s1);
2530 assert_eq!(connection.offers[0].offer, c0.offer);
2531 }
2532
2533 #[async_test]
2534 async fn test_save_restore_connected_with_revoked_channel(driver: DefaultDriver) {
2535 let (mut server, mut client) = test_init(&driver);
2536 let c0 = server.get_channel(&mut client).await;
2537 server.send(in_msg(
2538 MessageType::RESCIND_CHANNEL_OFFER,
2539 protocol::RescindChannelOffer {
2540 channel_id: ChannelId(0),
2541 },
2542 ));
2543 c0.revoke_recv.await.unwrap();
2544 let rpc = c0.request_send.call(
2545 ChannelRequest::Modify,
2546 ModifyRequest::TargetVp { target_vp: 1 },
2547 );
2548
2549 check_message(
2550 server.next().await.unwrap(),
2551 protocol::ModifyChannel {
2552 channel_id: ChannelId(0),
2553 target_vp: 1,
2554 },
2555 );
2556
2557 let client_stop = client.stop();
2558 let server_stop = async {
2559 server.send(in_msg(
2560 MessageType::MODIFY_CHANNEL_RESPONSE,
2561 protocol::ModifyChannelResponse {
2562 channel_id: ChannelId(0),
2563 status: protocol::STATUS_SUCCESS,
2564 },
2565 ));
2566 check_message(server.next().await.unwrap(), protocol::Pause);
2567 server.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse));
2568 };
2569 (client_stop, server_stop).join().await;
2570
2571 rpc.await.unwrap();
2572
2573 let s0 = client.save().await;
2574 let builder = client.sever().await;
2575 let mut client = builder.build(&driver);
2576 let connection = client.restore(s0.clone()).await.unwrap().unwrap();
2577 let s1 = client.save().await;
2578 assert_eq!(s0, s1);
2579 assert!(connection.offers.is_empty());
2580 server.start_client(&mut client).await;
2581 check_message(
2582 server.next().await.unwrap(),
2583 protocol::RelIdReleased {
2584 channel_id: ChannelId(0),
2585 },
2586 );
2587 }
2588
2589 #[async_test]
2590 async fn test_connect_fails_on_incorrect_state(driver: DefaultDriver) {
2591 let (mut server, mut client) = test_init(&driver);
2592 server.connect(&mut client).await;
2593 let err = client.connect(0, None, Guid::ZERO).await.unwrap_err();
2594 assert!(matches!(err, ConnectError::InvalidState), "{:?}", err);
2595 }
2596
2597 #[async_test]
2598 async fn test_hot_add_remove(driver: DefaultDriver) {
2599 let (mut server, mut client) = test_init(&driver);
2600
2601 let mut connection = server.connect(&mut client).await;
2602 let offer = protocol::OfferChannel {
2603 interface_id: Guid::new_random(),
2604 instance_id: Guid::new_random(),
2605 rsvd: [0; 4],
2606 flags: OfferFlags::new(),
2607 mmio_megabytes: 0,
2608 user_defined: UserDefinedData::new_zeroed(),
2609 subchannel_index: 0,
2610 mmio_megabytes_optional: 0,
2611 channel_id: ChannelId(5),
2612 monitor_id: 0,
2613 monitor_allocated: 0,
2614 is_dedicated: 0,
2615 connection_id: 0,
2616 };
2617
2618 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
2619 let info = connection.offer_recv.next().await.unwrap();
2620
2621 assert_eq!(offer, info.offer);
2622
2623 server.send(in_msg(
2624 MessageType::RESCIND_CHANNEL_OFFER,
2625 protocol::RescindChannelOffer {
2626 channel_id: ChannelId(5),
2627 },
2628 ));
2629
2630 info.revoke_recv.await.unwrap();
2631 drop(info.request_send);
2632
2633 check_message(
2634 server.next().await.unwrap(),
2635 protocol::RelIdReleased {
2636 channel_id: ChannelId(5),
2637 },
2638 );
2639 }
2640
2641 #[async_test]
2642 async fn test_gpadl_success(driver: DefaultDriver) {
2643 let (mut server, mut client) = test_init(&driver);
2644 let channel = server.get_channel(&mut client).await;
2645 let recv = channel.request_send.call(
2646 ChannelRequest::Gpadl,
2647 GpadlRequest {
2648 id: GpadlId(1),
2649 count: 1,
2650 buf: vec![5],
2651 },
2652 );
2653
2654 check_message_with_data(
2655 server.next().await.unwrap(),
2656 protocol::GpadlHeader {
2657 channel_id: ChannelId(0),
2658 gpadl_id: GpadlId(1),
2659 len: 8,
2660 count: 1,
2661 },
2662 0x5u64.as_bytes(),
2663 );
2664
2665 server.send(in_msg(
2666 MessageType::GPADL_CREATED,
2667 protocol::GpadlCreated {
2668 channel_id: ChannelId(0),
2669 gpadl_id: GpadlId(1),
2670 status: protocol::STATUS_SUCCESS,
2671 },
2672 ));
2673
2674 recv.await.unwrap().unwrap();
2675
2676 let rpc = channel
2677 .request_send
2678 .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2679
2680 check_message(
2681 server.next().await.unwrap(),
2682 protocol::GpadlTeardown {
2683 channel_id: ChannelId(0),
2684 gpadl_id: GpadlId(1),
2685 },
2686 );
2687
2688 server.send(in_msg(
2689 MessageType::GPADL_TORNDOWN,
2690 protocol::GpadlTorndown {
2691 gpadl_id: GpadlId(1),
2692 },
2693 ));
2694
2695 rpc.await.unwrap();
2696 }
2697
2698 #[async_test]
2699 async fn test_gpadl_fail(driver: DefaultDriver) {
2700 let (mut server, mut client) = test_init(&driver);
2701 let channel = server.get_channel(&mut client).await;
2702 let recv = channel.request_send.call(
2703 ChannelRequest::Gpadl,
2704 GpadlRequest {
2705 id: GpadlId(1),
2706 count: 1,
2707 buf: vec![7],
2708 },
2709 );
2710
2711 check_message_with_data(
2712 server.next().await.unwrap(),
2713 protocol::GpadlHeader {
2714 channel_id: ChannelId(0),
2715 gpadl_id: GpadlId(1),
2716 len: 8,
2717 count: 1,
2718 },
2719 0x7u64.as_bytes(),
2720 );
2721
2722 server.send(in_msg(
2723 MessageType::GPADL_CREATED,
2724 protocol::GpadlCreated {
2725 channel_id: ChannelId(0),
2726 gpadl_id: GpadlId(1),
2727 status: protocol::STATUS_UNSUCCESSFUL,
2728 },
2729 ));
2730
2731 recv.await.unwrap().unwrap_err();
2732 }
2733
2734 #[async_test]
2735 async fn test_gpadl_with_revoke(driver: DefaultDriver) {
2736 let (mut server, mut client) = test_init(&driver);
2737 let channel = server.get_channel(&mut client).await;
2738 let channel_id = ChannelId(0);
2739 for gpadl_id in [1, 2, 3].map(GpadlId) {
2740 let recv = channel.request_send.call(
2741 ChannelRequest::Gpadl,
2742 GpadlRequest {
2743 id: gpadl_id,
2744 count: 1,
2745 buf: vec![3],
2746 },
2747 );
2748
2749 check_message_with_data(
2750 server.next().await.unwrap(),
2751 protocol::GpadlHeader {
2752 channel_id,
2753 gpadl_id,
2754 len: 8,
2755 count: 1,
2756 },
2757 0x3u64.as_bytes(),
2758 );
2759
2760 server.send(in_msg(
2761 MessageType::GPADL_CREATED,
2762 protocol::GpadlCreated {
2763 channel_id,
2764 gpadl_id,
2765 status: protocol::STATUS_SUCCESS,
2766 },
2767 ));
2768
2769 recv.await.unwrap().unwrap();
2770 }
2771
2772 let rpc = channel
2773 .request_send
2774 .call(ChannelRequest::TeardownGpadl, GpadlId(1));
2775
2776 check_message(
2777 server.next().await.unwrap(),
2778 protocol::GpadlTeardown {
2779 channel_id,
2780 gpadl_id: GpadlId(1),
2781 },
2782 );
2783
2784 server.send(in_msg(
2785 MessageType::RESCIND_CHANNEL_OFFER,
2786 protocol::RescindChannelOffer { channel_id },
2787 ));
2788
2789 let recv = channel.request_send.call_failable(
2790 ChannelRequest::Gpadl,
2791 GpadlRequest {
2792 id: GpadlId(4),
2793 count: 1,
2794 buf: vec![3],
2795 },
2796 );
2797
2798 check_message_with_data(
2799 server.next().await.unwrap(),
2800 protocol::GpadlHeader {
2801 channel_id,
2802 gpadl_id: GpadlId(4),
2803 len: 8,
2804 count: 1,
2805 },
2806 0x3u64.as_bytes(),
2807 );
2808
2809 server.send(in_msg(
2810 MessageType::GPADL_CREATED,
2811 protocol::GpadlCreated {
2812 channel_id,
2813 gpadl_id: GpadlId(4),
2814 status: protocol::STATUS_UNSUCCESSFUL,
2815 },
2816 ));
2817
2818 server.send(in_msg(
2819 MessageType::GPADL_TORNDOWN,
2820 protocol::GpadlTorndown {
2821 gpadl_id: GpadlId(1),
2822 },
2823 ));
2824
2825 rpc.await.unwrap();
2826 recv.await.unwrap_err();
2827
2828 channel.revoke_recv.await.unwrap();
2829
2830 let rpc = channel
2831 .request_send
2832 .call(ChannelRequest::TeardownGpadl, GpadlId(2));
2833 drop(channel.request_send);
2834
2835 check_message(
2836 server.next().await.unwrap(),
2837 protocol::GpadlTeardown {
2838 channel_id,
2839 gpadl_id: GpadlId(2),
2840 },
2841 );
2842
2843 server.send(in_msg(
2844 MessageType::GPADL_TORNDOWN,
2845 protocol::GpadlTorndown {
2846 gpadl_id: GpadlId(2),
2847 },
2848 ));
2849
2850 rpc.await.unwrap();
2851
2852 check_message(
2853 server.next().await.unwrap(),
2854 protocol::RelIdReleased { channel_id },
2855 );
2856 }
2857
2858 #[async_test]
2859 async fn test_modify_connection(driver: DefaultDriver) {
2860 let (mut server, mut client) = test_init(&driver);
2861 server.connect(&mut client).await;
2862 let call = client.access.client_request_send.call(
2863 ClientRequest::Modify,
2864 ModifyConnectionRequest {
2865 monitor_page: Some(MonitorPageGpas {
2866 child_to_parent: 5,
2867 parent_to_child: 6,
2868 }),
2869 },
2870 );
2871
2872 check_message(
2873 server.next().await.unwrap(),
2874 protocol::ModifyConnection {
2875 child_to_parent_monitor_page_gpa: 5,
2876 parent_to_child_monitor_page_gpa: 6,
2877 },
2878 );
2879
2880 server.send(in_msg(
2881 MessageType::MODIFY_CONNECTION_RESPONSE,
2882 protocol::ModifyConnectionResponse {
2883 connection_state: ConnectionState::FAILED_LOW_RESOURCES,
2884 },
2885 ));
2886
2887 let result = call.await.unwrap();
2888 assert_eq!(ConnectionState::FAILED_LOW_RESOURCES, result);
2889 }
2890
2891 #[async_test]
2892 async fn test_hvsock(driver: DefaultDriver) {
2893 let (mut server, mut client) = test_init(&driver);
2894 server.connect(&mut client).await;
2895 let request = HvsockConnectRequest {
2896 service_id: Guid::new_random(),
2897 endpoint_id: Guid::new_random(),
2898 silo_id: Guid::new_random(),
2899 hosted_silo_unaware: false,
2900 };
2901
2902 let resp = client.access().connect_hvsock(request);
2903 check_message(
2904 server.next().await.unwrap(),
2905 protocol::TlConnectRequest2 {
2906 base: protocol::TlConnectRequest {
2907 service_id: request.service_id,
2908 endpoint_id: request.endpoint_id,
2909 },
2910 silo_id: request.silo_id,
2911 },
2912 );
2913
2914 server.send(in_msg(
2916 MessageType::TL_CONNECT_REQUEST_RESULT,
2917 protocol::TlConnectResult {
2918 service_id: request.service_id,
2919 endpoint_id: request.endpoint_id,
2920 status: protocol::STATUS_CONNECTION_REFUSED,
2921 },
2922 ));
2923
2924 let result = resp.await;
2925 assert!(result.is_none());
2926 }
2927
2928 #[async_test]
2929 async fn test_synic_event_flags(driver: DefaultDriver) {
2930 let (mut server, mut client) = test_init(&driver);
2931 let connection = server.get_channels(&mut client, 5).await;
2932 let event = Event::new();
2933
2934 for _ in 0..5 {
2935 for (i, channel) in connection.offers.iter().enumerate() {
2936 let recv = channel.request_send.call(
2937 ChannelRequest::Open,
2938 OpenRequest {
2939 open_data: OpenData {
2940 target_vp: 0,
2941 ring_offset: 0,
2942 ring_gpadl_id: GpadlId(0),
2943 event_flag: 0,
2944 connection_id: 0,
2945 user_data: UserDefinedData::new_zeroed(),
2946 },
2947 incoming_event: Some(event.clone()),
2948 use_vtl2_connection_id: false,
2949 },
2950 );
2951
2952 let expected_event_flag = i as u16 + 1;
2953
2954 check_message(
2955 server.next().await.unwrap(),
2956 protocol::OpenChannel2 {
2957 open_channel: protocol::OpenChannel {
2958 channel_id: channel.offer.channel_id,
2959 open_id: 0,
2960 ring_buffer_gpadl_id: GpadlId(0),
2961 target_vp: 0,
2962 downstream_ring_buffer_page_offset: 0,
2963 user_data: UserDefinedData::new_zeroed(),
2964 },
2965 connection_id: 0,
2966 event_flag: expected_event_flag,
2967 flags: OpenChannelFlags::new().with_redirect_interrupt(true),
2968 },
2969 );
2970
2971 server.send(in_msg(
2972 MessageType::OPEN_CHANNEL_RESULT,
2973 protocol::OpenResult {
2974 channel_id: channel.offer.channel_id,
2975 open_id: 0,
2976 status: protocol::STATUS_SUCCESS as u32,
2977 },
2978 ));
2979
2980 let output = recv.await.unwrap().unwrap();
2981 assert_eq!(output.redirected_event_flag, Some(expected_event_flag));
2982 }
2983
2984 for (i, channel) in connection.offers.iter().enumerate() {
2985 channel
2988 .request_send
2989 .call(ChannelRequest::Close, ())
2990 .await
2991 .unwrap();
2992
2993 check_message(
2994 server.next().await.unwrap(),
2995 protocol::CloseChannel {
2996 channel_id: ChannelId(i as u32),
2997 },
2998 );
2999 }
3000 }
3001 }
3002
3003 #[async_test]
3004 async fn test_revoke(driver: DefaultDriver) {
3005 let (mut server, mut client) = test_init(&driver);
3006 let channel = server.get_channel(&mut client).await;
3007
3008 server.send(in_msg(
3009 MessageType::RESCIND_CHANNEL_OFFER,
3010 protocol::RescindChannelOffer {
3011 channel_id: ChannelId(0),
3012 },
3013 ));
3014
3015 channel.revoke_recv.await.unwrap();
3016
3017 channel
3018 .request_send
3019 .call_failable(
3020 ChannelRequest::Open,
3021 OpenRequest {
3022 open_data: OpenData {
3023 target_vp: 0,
3024 ring_offset: 0,
3025 ring_gpadl_id: GpadlId(0),
3026 event_flag: 0,
3027 connection_id: 0,
3028 user_data: UserDefinedData::new_zeroed(),
3029 },
3030 incoming_event: None,
3031 use_vtl2_connection_id: false,
3032 },
3033 )
3034 .await
3035 .unwrap_err();
3036 }
3037
3038 #[async_test]
3039 #[should_panic(expected = "channel should not exist")]
3040 async fn test_reoffer_in_use_rel_id(driver: DefaultDriver) {
3041 let (mut server, mut client) = test_init(&driver);
3042 let mut connection = server.get_channels(&mut client, 1).await;
3043 let [channel] = connection.offers.try_into().unwrap();
3044
3045 server.send(in_msg(
3046 MessageType::RESCIND_CHANNEL_OFFER,
3047 protocol::RescindChannelOffer {
3048 channel_id: ChannelId(0),
3049 },
3050 ));
3051
3052 channel.revoke_recv.await.unwrap();
3053
3054 let offer = protocol::OfferChannel {
3056 interface_id: Guid::new_random(),
3057 instance_id: Guid::new_random(),
3058 rsvd: [0; 4],
3059 flags: OfferFlags::new(),
3060 mmio_megabytes: 0,
3061 user_defined: UserDefinedData::new_zeroed(),
3062 subchannel_index: 0,
3063 mmio_megabytes_optional: 0,
3064 channel_id: ChannelId(0),
3065 monitor_id: 0,
3066 monitor_allocated: 0,
3067 is_dedicated: 0,
3068 connection_id: 0,
3069 };
3070
3071 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3072
3073 connection.offer_recv.next().await;
3074 }
3075
3076 #[async_test]
3077 async fn test_revoke_release_and_reoffer(driver: DefaultDriver) {
3078 let (mut server, mut client) = test_init(&driver);
3079 let mut connection = server.get_channels(&mut client, 1).await;
3080 let [channel] = connection.offers.try_into().unwrap();
3081
3082 server.send(in_msg(
3083 MessageType::RESCIND_CHANNEL_OFFER,
3084 protocol::RescindChannelOffer {
3085 channel_id: ChannelId(0),
3086 },
3087 ));
3088
3089 channel.revoke_recv.await.unwrap();
3090 drop(channel.request_send);
3091
3092 check_message(
3093 server.next().await.unwrap(),
3094 protocol::RelIdReleased {
3095 channel_id: ChannelId(0),
3096 },
3097 );
3098
3099 let offer = protocol::OfferChannel {
3100 interface_id: Guid::new_random(),
3101 instance_id: Guid::new_random(),
3102 rsvd: [0; 4],
3103 flags: OfferFlags::new(),
3104 mmio_megabytes: 0,
3105 user_defined: UserDefinedData::new_zeroed(),
3106 subchannel_index: 0,
3107 mmio_megabytes_optional: 0,
3108 channel_id: ChannelId(0),
3109 monitor_id: 0,
3110 monitor_allocated: 0,
3111 is_dedicated: 0,
3112 connection_id: 0,
3113 };
3114
3115 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3116
3117 connection.offer_recv.next().await.unwrap();
3118 }
3119
3120 #[async_test]
3121 async fn test_release_revoke_and_reoffer(driver: DefaultDriver) {
3122 let (mut server, mut client) = test_init(&driver);
3123 let mut connection = server.get_channels(&mut client, 1).await;
3124 let [channel] = connection.offers.try_into().unwrap();
3125
3126 let open = channel.request_send.call_failable(
3127 ChannelRequest::Open,
3128 OpenRequest {
3129 open_data: OpenData {
3130 target_vp: 0,
3131 ring_offset: 0,
3132 ring_gpadl_id: GpadlId(0),
3133 event_flag: 0,
3134 connection_id: 0,
3135 user_data: UserDefinedData::new_zeroed(),
3136 },
3137 incoming_event: None,
3138 use_vtl2_connection_id: false,
3139 },
3140 );
3141
3142 let server_open = async {
3143 check_message(
3144 server.next().await.unwrap(),
3145 protocol::OpenChannel2 {
3146 open_channel: protocol::OpenChannel {
3147 channel_id: ChannelId(0),
3148 open_id: 0,
3149 ring_buffer_gpadl_id: GpadlId(0),
3150 target_vp: 0,
3151 downstream_ring_buffer_page_offset: 0,
3152 user_data: UserDefinedData::new_zeroed(),
3153 },
3154 connection_id: 0,
3155 event_flag: 0,
3156 flags: Default::default(),
3157 },
3158 );
3159 server.send(in_msg(
3160 MessageType::OPEN_CHANNEL_RESULT,
3161 protocol::OpenResult {
3162 channel_id: ChannelId(0),
3163 open_id: 0,
3164 status: protocol::STATUS_SUCCESS as u32,
3165 },
3166 ));
3167 };
3168
3169 (open, server_open).join().await.0.unwrap();
3170
3171 drop(channel);
3173
3174 check_message(
3175 server.next().await.unwrap(),
3176 protocol::CloseChannel {
3177 channel_id: ChannelId(0),
3178 },
3179 );
3180
3181 server.send(in_msg(
3182 MessageType::RESCIND_CHANNEL_OFFER,
3183 protocol::RescindChannelOffer {
3184 channel_id: ChannelId(0),
3185 },
3186 ));
3187
3188 check_message(
3190 server.next().await.unwrap(),
3191 protocol::RelIdReleased {
3192 channel_id: ChannelId(0),
3193 },
3194 );
3195
3196 let offer = protocol::OfferChannel {
3197 interface_id: Guid::new_random(),
3198 instance_id: Guid::new_random(),
3199 rsvd: [0; 4],
3200 flags: OfferFlags::new(),
3201 mmio_megabytes: 0,
3202 user_defined: UserDefinedData::new_zeroed(),
3203 subchannel_index: 0,
3204 mmio_megabytes_optional: 0,
3205 channel_id: ChannelId(0),
3206 monitor_id: 0,
3207 monitor_allocated: 0,
3208 is_dedicated: 0,
3209 connection_id: 0,
3210 };
3211
3212 server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
3213
3214 connection.offer_recv.next().await.unwrap();
3216 }
3217}