1mod saved_state;
5
6use crate::Guid;
7use crate::SINT;
8use crate::SynicMessage;
9use crate::monitor::AssignedMonitors;
10use crate::protocol::Version;
11use hvdef::Vtl;
12use inspect::Inspect;
13pub use saved_state::RestoreError;
14pub use saved_state::SavedState;
15pub use saved_state::SavedStateData;
16use slab::Slab;
17use std::cmp::min;
18use std::collections::VecDeque;
19use std::collections::hash_map::Entry;
20use std::collections::hash_map::HashMap;
21use std::fmt::Display;
22use std::ops::Index;
23use std::ops::IndexMut;
24use std::task::Poll;
25use std::task::ready;
26use std::time::Duration;
27use thiserror::Error;
28use vmbus_channel::bus::ChannelType;
29use vmbus_channel::bus::GpadlRequest;
30use vmbus_channel::bus::OfferKey;
31use vmbus_channel::bus::OfferParams;
32use vmbus_channel::bus::OpenData;
33use vmbus_channel::bus::RestoredGpadl;
34use vmbus_core::HvsockConnectRequest;
35use vmbus_core::HvsockConnectResult;
36use vmbus_core::MaxVersionInfo;
37use vmbus_core::OutgoingMessage;
38use vmbus_core::VersionInfo;
39use vmbus_core::protocol;
40use vmbus_core::protocol::ChannelId;
41use vmbus_core::protocol::ConnectionId;
42use vmbus_core::protocol::FeatureFlags;
43use vmbus_core::protocol::GpadlId;
44use vmbus_core::protocol::Message;
45use vmbus_core::protocol::OfferFlags;
46use vmbus_core::protocol::UserDefinedData;
47use vmbus_ring::gparange;
48use vmcore::monitor::MonitorId;
49use vmcore::synic::MonitorInfo;
50use vmcore::synic::MonitorPageGpas;
51use zerocopy::FromZeros;
52use zerocopy::Immutable;
53use zerocopy::IntoBytes;
54use zerocopy::KnownLayout;
55
56#[derive(Debug, Error)]
58pub enum ChannelError {
59 #[error("unknown channel ID")]
60 UnknownChannelId,
61 #[error("unknown GPADL ID")]
62 UnknownGpadlId,
63 #[error("parse error")]
64 ParseError(#[from] protocol::ParseError),
65 #[error("invalid gpa range")]
66 InvalidGpaRange(#[source] gparange::Error),
67 #[error("duplicate GPADL ID")]
68 DuplicateGpadlId,
69 #[error("GPADL is already complete")]
70 GpadlAlreadyComplete,
71 #[error("GPADL channel ID mismatch")]
72 WrongGpadlChannelId,
73 #[error("trying to open an open channel")]
74 ChannelAlreadyOpen,
75 #[error("trying to close a closed channel")]
76 ChannelNotOpen,
77 #[error("invalid GPADL state for operation")]
78 InvalidGpadlState,
79 #[error("invalid channel state for operation")]
80 InvalidChannelState,
81 #[error("channel ID has already been released")]
82 ChannelReleased,
83 #[error("channel offers have already been sent")]
84 OffersAlreadySent,
85 #[error("invalid operation on reserved channel")]
86 ChannelReserved,
87 #[error("invalid operation on non-reserved channel")]
88 ChannelNotReserved,
89 #[error("received untrusted message for trusted connection")]
90 UntrustedMessage,
91 #[error("received a non-resuming message while paused")]
92 Paused,
93}
94
95#[derive(Debug, Error)]
96pub enum OfferError {
97 #[error("the channel ID {} is not valid for this operation", (.0).0)]
98 InvalidChannelId(ChannelId),
99 #[error("the channel ID {} is already in use", (.0).0)]
100 ChannelIdInUse(ChannelId),
101 #[error("offer {0} already exists")]
102 AlreadyExists(OfferKey),
103 #[error("specified resources do not match those of the existing saved or revoked offer")]
104 IncompatibleResources,
105 #[error("too many channels have been offered")]
106 TooManyChannels,
107 #[error("mismatched monitor ID from saved state; expected {0:?}, actual {1:?}")]
108 MismatchedMonitorId(Option<MonitorId>, MonitorId),
109}
110
/// Opaque handle naming one offered channel in the server's channel list.
///
/// The wrapped value is the slot index in the list's backing storage.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct OfferId(usize);
114
115type IncompleteGpadlMap = HashMap<GpadlId, OfferId>;
116
117type GpadlMap = HashMap<(GpadlId, OfferId), Gpadl>;
118
119pub struct Server {
121 state: ConnectionState,
122 channels: ChannelList,
123 assigned_channels: AssignedChannels,
124 assigned_monitors: AssignedMonitors,
125 gpadls: GpadlMap,
126 incomplete_gpadls: IncompleteGpadlMap,
127 child_connection_id: u32,
128 max_version: Option<MaxVersionInfo>,
129 delayed_max_version: Option<MaxVersionInfo>,
130 pending_messages: PendingMessages,
133}
134
135pub struct ServerWithNotifier<'a, T> {
136 inner: &'a mut Server,
137 notifier: &'a mut T,
138}
139
140impl<T> Drop for ServerWithNotifier<'_, T> {
141 fn drop(&mut self) {
142 self.inner.validate();
143 }
144}
145
146impl<T: Notifier> Inspect for ServerWithNotifier<'_, T> {
147 fn inspect(&self, req: inspect::Request<'_>) {
148 let mut resp = req.respond();
149 let (state, info, next_action) = match &self.inner.state {
150 ConnectionState::Disconnected => ("disconnected", None, None),
151 ConnectionState::Connecting { info, .. } => ("connecting", Some(info), None),
152 ConnectionState::Connected(info) => (
153 if info.offers_sent {
154 "connected"
155 } else {
156 "negotiated"
157 },
158 Some(info),
159 None,
160 ),
161 ConnectionState::Disconnecting { next_action, .. } => {
162 ("disconnecting", None, Some(next_action))
163 }
164 };
165
166 resp.field("connection_info", info);
167 let next_action = next_action.map(|a| match a {
168 ConnectionAction::None => "disconnect",
169 ConnectionAction::Reset => "reset",
170 ConnectionAction::SendUnloadComplete => "unload",
171 ConnectionAction::Reconnect { .. } => "reconnect",
172 ConnectionAction::SendFailedVersionResponse => "send_version_response",
173 });
174 resp.field("state", state)
175 .field("next_action", next_action)
176 .field(
177 "assigned_monitors_bitmap",
178 format_args!("{:x}", self.inner.assigned_monitors.bitmap()),
179 )
180 .child("channels", |req| {
181 let mut resp = req.respond();
182 self.inner
183 .channels
184 .inspect(self.notifier, self.inner.get_version(), &mut resp);
185 for ((gpadl_id, offer_id), gpadl) in &self.inner.gpadls {
186 let channel = &self.inner.channels[*offer_id];
187 resp.field(
188 &channel_inspect_path(
189 &channel.offer,
190 format_args!("/gpadls/{}", gpadl_id.0),
191 ),
192 gpadl,
193 );
194 }
195 });
196 }
197}
198
199#[derive(Debug, Copy, Clone, Inspect)]
200struct ConnectionInfo {
201 version: VersionInfo,
202 trusted: bool,
205 offers_sent: bool,
206 interrupt_page: Option<u64>,
207 monitor_page: Option<MonitorPageGpas>,
208 target_message_vp: u32,
209 modifying: bool,
210 client_id: Guid,
211 paused: bool,
212}
213
214#[derive(Debug)]
216enum ConnectionState {
217 Disconnected,
218 Disconnecting {
219 next_action: ConnectionAction,
220 modify_sent: bool,
221 },
222 Connecting {
223 info: ConnectionInfo,
224 next_action: ConnectionAction,
225 },
226 Connected(ConnectionInfo),
227}
228
229impl ConnectionState {
230 fn check_version(&self, min_version: Version) -> bool {
232 matches!(self, ConnectionState::Connected(info) if info.version.version >= min_version)
233 }
234
235 fn check_feature_flags(&self, flags: impl Fn(FeatureFlags) -> bool) -> bool {
238 matches!(self, ConnectionState::Connected(info) if flags(info.version.feature_flags))
239 }
240
241 fn get_version(&self) -> Option<VersionInfo> {
242 if let ConnectionState::Connected(info) = self {
243 Some(info.version)
244 } else {
245 None
246 }
247 }
248
249 fn is_trusted(&self) -> bool {
250 match self {
251 ConnectionState::Connected(info) => info.trusted,
252 ConnectionState::Connecting { info, .. } => info.trusted,
253 _ => false,
254 }
255 }
256
257 fn is_paused(&self) -> bool {
258 if let ConnectionState::Connected(info) = self {
259 info.paused
260 } else {
261 false
262 }
263 }
264}
265
266#[derive(Debug, Copy, Clone)]
267enum ConnectionAction {
268 None,
269 Reset,
270 SendUnloadComplete,
271 Reconnect {
272 initiate_contact: InitiateContactRequest,
273 },
274 SendFailedVersionResponse,
275}
276
277#[derive(PartialEq, Eq, Debug, Copy, Clone)]
278pub enum MonitorPageRequest {
279 None,
280 Some(MonitorPageGpas),
281 Invalid,
282}
283
284#[derive(PartialEq, Eq, Debug, Copy, Clone)]
285pub struct InitiateContactRequest {
286 pub version_requested: u32,
287 pub target_message_vp: u32,
288 pub monitor_page: MonitorPageRequest,
289 pub target_sint: u8,
290 pub target_vtl: u8,
291 pub feature_flags: u32,
292 pub interrupt_page: Option<u64>,
293 pub client_id: Guid,
294 pub trusted: bool,
295}
296
297#[derive(Debug, Copy, Clone)]
298pub struct OpenRequest {
299 pub open_id: u32,
300 pub ring_buffer_gpadl_id: GpadlId,
301 pub target_vp: u32,
302 pub downstream_ring_buffer_page_offset: u32,
303 pub user_data: UserDefinedData,
304 pub guest_specified_interrupt_info: Option<SignalInfo>,
305 pub flags: protocol::OpenChannelFlags,
306}
307
/// A three-way change to a setting: keep it, clear it, or replace it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Update<T: std::fmt::Debug + Copy + Clone> {
    /// Leave the current value alone.
    Unchanged,
    /// Clear the current value.
    Reset,
    /// Replace the current value.
    Set(T),
}

impl<T: std::fmt::Debug + Copy + Clone> From<Option<T>> for Update<T> {
    /// Maps `Some(v)` to `Set(v)` and `None` to `Reset`; this conversion never
    /// produces `Unchanged`.
    fn from(value: Option<T>) -> Self {
        value.map_or(Self::Reset, Self::Set)
    }
}
323
324#[derive(Debug, Copy, Clone, Eq, PartialEq)]
325pub struct ModifyConnectionRequest {
326 pub version: Option<u32>,
327 pub monitor_page: Update<MonitorPageGpas>,
328 pub interrupt_page: Update<u64>,
329 pub target_message_vp: Option<u32>,
330 pub notify_relay: bool,
331}
332
333impl Default for ModifyConnectionRequest {
335 fn default() -> Self {
336 Self {
337 version: None,
338 monitor_page: Update::Unchanged,
339 interrupt_page: Update::Unchanged,
340 target_message_vp: None,
341 notify_relay: true,
342 }
343 }
344}
345
346impl From<protocol::ModifyConnection> for ModifyConnectionRequest {
347 fn from(value: protocol::ModifyConnection) -> Self {
348 let monitor_page = if value.parent_to_child_monitor_page_gpa != 0 {
349 Update::Set(MonitorPageGpas {
350 parent_to_child: value.parent_to_child_monitor_page_gpa,
351 child_to_parent: value.child_to_parent_monitor_page_gpa,
352 })
353 } else {
354 Update::Reset
355 };
356
357 Self {
358 monitor_page,
359 ..Default::default()
360 }
361 }
362}
363
364#[derive(Debug, Copy, Clone)]
366pub enum ModifyConnectionResponse {
367 Supported(protocol::ConnectionState, FeatureFlags),
372 Unsupported,
375}
376
/// Whether an open channel has a modification in flight.
#[derive(Clone, Copy, Debug)]
pub enum ModifyState {
    /// No modification is in progress.
    NotModifying,
    /// A modification is in progress; `pending_target_vp` holds a queued
    /// target VP, if any.
    Modifying { pending_target_vp: Option<u32> },
}

impl ModifyState {
    /// Whether a modification is in progress.
    pub fn is_modifying(&self) -> bool {
        match self {
            ModifyState::NotModifying => false,
            ModifyState::Modifying { .. } => true,
        }
    }
}
388
/// Signal parameters a guest may specify when opening a channel.
#[derive(Clone, Copy, Debug)]
pub struct SignalInfo {
    /// Event flag to signal.
    pub event_flag: u16,
    /// Connection ID to signal on.
    pub connection_id: u32,
}
394
/// Where a channel stands with respect to restoring from saved state.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RestoreState {
    /// Not part of restored state; asking for its restore open parameters
    /// reports a missing channel.
    New,
    /// Present in saved state and still awaiting restoration.
    Restoring,
    /// NOTE(review): presumably saved state with no matching offer; asking
    /// for restore open parameters treats this as unreachable.
    Unmatched,
    /// Restoration has finished; a second restore is reported as an error.
    Restored,
}
409
410#[derive(Debug, Clone)]
412enum ChannelState {
413 ClientReleased,
417
418 Closed,
420
421 Opening {
424 request: OpenRequest,
425 reserved_state: Option<ReservedState>,
426 },
427
428 Open {
430 params: OpenRequest,
431 modify_state: ModifyState,
432 reserved_state: Option<ReservedState>,
433 },
434
435 Closing {
437 params: OpenRequest,
438 reserved_state: Option<ReservedState>,
439 },
440
441 ClosingReopen {
444 params: OpenRequest,
445 request: OpenRequest,
446 },
447
448 Revoked,
450
451 Reoffered,
454
455 ClosingClientRelease,
458
459 OpeningClientRelease,
462}
463
464impl ChannelState {
465 fn is_released(&self) -> bool {
468 match self {
469 ChannelState::Closed
470 | ChannelState::Opening { .. }
471 | ChannelState::Open { .. }
472 | ChannelState::Closing { .. }
473 | ChannelState::ClosingReopen { .. }
474 | ChannelState::Revoked
475 | ChannelState::Reoffered => false,
476
477 ChannelState::ClientReleased
478 | ChannelState::ClosingClientRelease
479 | ChannelState::OpeningClientRelease => true,
480 }
481 }
482
483 fn is_revoked(&self) -> bool {
485 match self {
486 ChannelState::Revoked | ChannelState::Reoffered => true,
487
488 ChannelState::ClientReleased
489 | ChannelState::Closed
490 | ChannelState::Opening { .. }
491 | ChannelState::Open { .. }
492 | ChannelState::Closing { .. }
493 | ChannelState::ClosingReopen { .. }
494 | ChannelState::ClosingClientRelease
495 | ChannelState::OpeningClientRelease => false,
496 }
497 }
498
499 fn is_reserved(&self) -> bool {
500 match self {
501 ChannelState::Open {
503 reserved_state: Some(_),
504 ..
505 }
506 | ChannelState::Opening {
507 reserved_state: Some(_),
508 ..
509 }
510 | ChannelState::Closing {
511 reserved_state: Some(_),
512 ..
513 } => true,
514
515 ChannelState::Opening { .. }
516 | ChannelState::Open { .. }
517 | ChannelState::Closing { .. }
518 | ChannelState::ClientReleased
519 | ChannelState::Closed
520 | ChannelState::ClosingReopen { .. }
521 | ChannelState::Revoked
522 | ChannelState::Reoffered
523 | ChannelState::ClosingClientRelease
524 | ChannelState::OpeningClientRelease => false,
525 }
526 }
527}
528
529impl Display for ChannelState {
530 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
531 let state = match self {
532 Self::ClientReleased => "ClientReleased",
533 Self::Closed => "Closed",
534 Self::Opening { .. } => "Opening",
535 Self::Open { .. } => "Open",
536 Self::Closing { .. } => "Closing",
537 Self::ClosingReopen { .. } => "ClosingReopen",
538 Self::Revoked => "Revoked",
539 Self::Reoffered => "Reoffered",
540 Self::ClosingClientRelease => "ClosingClientRelease",
541 Self::OpeningClientRelease => "OpeningClientRelease",
542 };
543 write!(f, "{}", state)
544 }
545}
546
547#[derive(Debug, Clone, Default, mesh::MeshPayload)]
549pub enum MnfUsage {
550 #[default]
552 Disabled,
553 Enabled { latency: Duration },
555 Relayed { monitor_id: u8 },
558}
559
560impl MnfUsage {
561 pub fn is_enabled(&self) -> bool {
562 matches!(self, Self::Enabled { .. })
563 }
564
565 pub fn is_relayed(&self) -> bool {
566 matches!(self, Self::Relayed { .. })
567 }
568
569 pub fn enabled_and_then<T>(&self, f: impl FnOnce(Duration) -> Option<T>) -> Option<T> {
570 if let Self::Enabled { latency } = self {
571 f(*latency)
572 } else {
573 None
574 }
575 }
576}
577
578impl From<Option<Duration>> for MnfUsage {
579 fn from(value: Option<Duration>) -> Self {
580 match value {
581 None => Self::Disabled,
582 Some(latency) => Self::Enabled { latency },
583 }
584 }
585}
586
587#[derive(Debug, Clone, Default, mesh::MeshPayload)]
588pub struct OfferParamsInternal {
589 pub interface_name: String,
591 pub instance_id: Guid,
592 pub interface_id: Guid,
593 pub mmio_megabytes: u16,
594 pub mmio_megabytes_optional: u16,
595 pub subchannel_index: u16,
596 pub use_mnf: MnfUsage,
597 pub offer_order: Option<u32>,
598 pub flags: OfferFlags,
599 pub user_defined: UserDefinedData,
600}
601
602impl OfferParamsInternal {
603 pub fn key(&self) -> OfferKey {
605 OfferKey {
606 interface_id: self.interface_id,
607 instance_id: self.instance_id,
608 subchannel_index: self.subchannel_index,
609 }
610 }
611}
612
613impl From<OfferParams> for OfferParamsInternal {
614 fn from(value: OfferParams) -> Self {
615 let mut user_defined = UserDefinedData::new_zeroed();
616
617 let mut flags = OfferFlags::new()
620 .with_confidential_ring_buffer(true)
621 .with_confidential_external_memory(value.allow_confidential_external_memory);
622
623 match value.channel_type {
624 ChannelType::Device { pipe_packets } => {
625 if pipe_packets {
626 flags.set_named_pipe_mode(true);
627 user_defined.as_pipe_params_mut().pipe_type = protocol::PipeType::MESSAGE;
628 }
629 }
630 ChannelType::Interface {
631 user_defined: interface_user_defined,
632 } => {
633 flags.set_enumerate_device_interface(true);
634 user_defined = interface_user_defined;
635 }
636 ChannelType::Pipe { message_mode } => {
637 flags.set_enumerate_device_interface(true);
638 flags.set_named_pipe_mode(true);
639 user_defined.as_pipe_params_mut().pipe_type = if message_mode {
640 protocol::PipeType::MESSAGE
641 } else {
642 protocol::PipeType::BYTE
643 };
644 }
645 ChannelType::HvSocket {
646 is_connect,
647 is_for_container,
648 silo_id,
649 } => {
650 flags.set_enumerate_device_interface(true);
651 flags.set_tlnpi_provider(true);
652 flags.set_named_pipe_mode(true);
653 *user_defined.as_hvsock_params_mut() = protocol::HvsockUserDefinedParameters::new(
654 is_connect,
655 is_for_container,
656 silo_id,
657 );
658 }
659 };
660
661 Self {
662 interface_name: value.interface_name,
663 instance_id: value.instance_id,
664 interface_id: value.interface_id,
665 mmio_megabytes: value.mmio_megabytes,
666 mmio_megabytes_optional: value.mmio_megabytes_optional,
667 subchannel_index: value.subchannel_index,
668 use_mnf: value.mnf_interrupt_latency.into(),
669 offer_order: value.offer_order,
670 user_defined,
671 flags,
672 }
673 }
674}
675
676#[derive(Debug, Copy, Clone, Inspect, PartialEq, Eq)]
677pub struct ConnectionTarget {
678 pub vp: u32,
679 pub sint: u8,
680}
681
682#[derive(Debug, Copy, Clone, PartialEq, Eq)]
683pub enum MessageTarget {
684 Default,
685 ReservedChannel(OfferId, ConnectionTarget),
686 Custom(ConnectionTarget),
687}
688
689impl MessageTarget {
690 pub fn for_offer(offer_id: OfferId, reserved_state: &Option<ReservedState>) -> Self {
691 if let Some(state) = reserved_state {
692 Self::ReservedChannel(offer_id, state.target)
693 } else {
694 Self::Default
695 }
696 }
697}
698
699#[derive(Debug, Copy, Clone)]
700pub struct ReservedState {
701 version: VersionInfo,
702 target: ConnectionTarget,
703}
704
705#[derive(Debug)]
707struct Channel {
708 info: Option<OfferedInfo>,
709 offer: OfferParamsInternal,
710 state: ChannelState,
711 restore_state: RestoreState,
712}
713
714#[derive(Debug, Copy, Clone)]
715struct OfferedInfo {
716 channel_id: ChannelId,
717 connection_id: u32,
718 monitor_id: Option<MonitorId>,
719}
720
721impl Channel {
722 fn inspect_state(&self, resp: &mut inspect::Response<'_>) {
723 let mut target_vp = None;
724 let mut event_flag = None;
725 let mut connection_id = None;
726 let mut reserved_target = None;
727 let state = match &self.state {
728 ChannelState::ClientReleased => "client_released",
729 ChannelState::Closed => "closed",
730 ChannelState::Opening { reserved_state, .. } => {
731 reserved_target = reserved_state.map(|state| state.target);
732 "opening"
733 }
734 ChannelState::Open {
735 params,
736 reserved_state,
737 ..
738 } => {
739 target_vp = Some(params.target_vp);
740 if let Some(id) = params.guest_specified_interrupt_info {
741 event_flag = Some(id.event_flag);
742 connection_id = Some(id.connection_id);
743 }
744 reserved_target = reserved_state.map(|state| state.target);
745 "open"
746 }
747 ChannelState::Closing { reserved_state, .. } => {
748 reserved_target = reserved_state.map(|state| state.target);
749 "closing"
750 }
751 ChannelState::ClosingReopen { .. } => "closing_reopen",
752 ChannelState::Revoked => "revoked",
753 ChannelState::Reoffered => "reoffered",
754 ChannelState::ClosingClientRelease => "closing_client_release",
755 ChannelState::OpeningClientRelease => "opening_client_release",
756 };
757 let restore_state = match self.restore_state {
758 RestoreState::New => "new",
759 RestoreState::Restoring => "restoring",
760 RestoreState::Restored => "restored",
761 RestoreState::Unmatched => "unmatched",
762 };
763 if let Some(info) = &self.info {
764 resp.field("channel_id", info.channel_id.0)
765 .field("offered_connection_id", info.connection_id)
766 .field("monitor_id", info.monitor_id.map(|id| id.0));
767 }
768 resp.field("state", state)
769 .field("restore_state", restore_state)
770 .field("interface_name", self.offer.interface_name.clone())
771 .display("instance_id", &self.offer.instance_id)
772 .display("interface_id", &self.offer.interface_id)
773 .field("mmio_megabytes", self.offer.mmio_megabytes)
774 .field("target_vp", target_vp)
775 .field("guest_specified_event_flag", event_flag)
776 .field("guest_specified_connection_id", connection_id)
777 .field("reserved_connection_target", reserved_target)
778 .binary("offer_flags", self.offer.flags.into_bits());
779 }
780
781 fn handled_monitor_info(&self) -> Option<MonitorInfo> {
791 self.offer.use_mnf.enabled_and_then(|latency| {
792 if self.state.is_reserved() {
793 None
794 } else {
795 self.info.and_then(|info| {
796 info.monitor_id.map(|monitor_id| MonitorInfo {
797 monitor_id,
798 latency,
799 })
800 })
801 }
802 })
803 }
804
805 fn prepare_channel(
808 &mut self,
809 offer_id: OfferId,
810 assigned_channels: &mut AssignedChannels,
811 assigned_monitors: &mut AssignedMonitors,
812 ) {
813 assert!(self.info.is_none());
814
815 let entry = assigned_channels
817 .allocate()
818 .expect("there are enough channel IDs for everything in ChannelList");
819
820 let channel_id = entry.id();
821 entry.insert(offer_id);
822 let connection_id = ConnectionId::new(channel_id.0, assigned_channels.vtl, SINT);
823
824 let monitor_id = match self.offer.use_mnf {
829 MnfUsage::Enabled { .. } => {
830 let monitor_id = assigned_monitors.assign_monitor();
831 if monitor_id.is_none() {
832 tracelimit::warn_ratelimited!("Out of monitor IDs.");
833 }
834
835 monitor_id
836 }
837 MnfUsage::Relayed { monitor_id } => Some(MonitorId(monitor_id)),
838 MnfUsage::Disabled => None,
839 };
840
841 self.info = Some(OfferedInfo {
842 channel_id,
843 connection_id: connection_id.0,
844 monitor_id,
845 });
846 }
847
848 fn release_channel(
850 &mut self,
851 offer_id: OfferId,
852 assigned_channels: &mut AssignedChannels,
853 assigned_monitors: &mut AssignedMonitors,
854 ) {
855 if let Some(info) = self.info.take() {
856 assigned_channels.free(info.channel_id, offer_id);
857
858 if let Some(monitor_id) = info.monitor_id {
860 if self.offer.use_mnf.is_enabled() {
861 assigned_monitors.release_monitor(monitor_id);
862 }
863 }
864 }
865 }
866}
867
868#[derive(Debug)]
869struct AssignedChannels {
870 assignments: Vec<Option<OfferId>>,
871 vtl: Vtl,
872 reserved_offset: usize,
873 count_in_reserved_range: usize,
875}
876
877impl AssignedChannels {
878 fn new(vtl: Vtl, channel_id_offset: u16) -> Self {
879 Self {
880 assignments: vec![None; MAX_CHANNELS],
881 vtl,
882 reserved_offset: channel_id_offset as usize,
883 count_in_reserved_range: 0,
884 }
885 }
886
887 fn allowable_channel_count(&self) -> usize {
888 MAX_CHANNELS - self.reserved_offset + self.count_in_reserved_range
889 }
890
891 fn get(&self, channel_id: ChannelId) -> Option<OfferId> {
892 self.assignments
893 .get(Self::index(channel_id))
894 .copied()
895 .flatten()
896 }
897
898 fn set(&mut self, channel_id: ChannelId) -> Result<AssignmentEntry<'_>, OfferError> {
899 let index = Self::index(channel_id);
900 if self
901 .assignments
902 .get(index)
903 .ok_or(OfferError::InvalidChannelId(channel_id))?
904 .is_some()
905 {
906 return Err(OfferError::ChannelIdInUse(channel_id));
907 }
908 Ok(AssignmentEntry { list: self, index })
909 }
910
911 fn allocate(&mut self) -> Option<AssignmentEntry<'_>> {
912 let index = self.reserved_offset
913 + self.assignments[self.reserved_offset..]
914 .iter()
915 .position(|x| x.is_none())?;
916 Some(AssignmentEntry { list: self, index })
917 }
918
919 fn free(&mut self, channel_id: ChannelId, offer_id: OfferId) {
920 let index = Self::index(channel_id);
921 let slot = &mut self.assignments[index];
922 assert_eq!(slot.take(), Some(offer_id));
923 if index < self.reserved_offset {
924 self.count_in_reserved_range -= 1;
925 }
926 }
927
928 fn index(channel_id: ChannelId) -> usize {
929 channel_id.0.wrapping_sub(1) as usize
930 }
931}
932
933struct AssignmentEntry<'a> {
934 list: &'a mut AssignedChannels,
935 index: usize,
936}
937
938impl AssignmentEntry<'_> {
939 pub fn id(&self) -> ChannelId {
940 ChannelId(self.index as u32 + 1)
941 }
942
943 pub fn insert(self, offer_id: OfferId) {
944 assert!(
945 self.list.assignments[self.index]
946 .replace(offer_id)
947 .is_none()
948 );
949
950 if self.index < self.list.reserved_offset {
951 self.list.count_in_reserved_range += 1;
952 }
953 }
954}
955
956struct ChannelList {
957 channels: Slab<Channel>,
958}
959
960fn channel_inspect_path(offer: &OfferParamsInternal, suffix: std::fmt::Arguments<'_>) -> String {
961 if offer.subchannel_index == 0 {
962 format!("{}{}", offer.instance_id, suffix)
963 } else {
964 format!(
965 "{}/subchannels/{}{}",
966 offer.instance_id, offer.subchannel_index, suffix
967 )
968 }
969}
970
971impl ChannelList {
972 fn inspect(
973 &self,
974 notifier: &impl Notifier,
975 version: Option<VersionInfo>,
976 resp: &mut inspect::Response<'_>,
977 ) {
978 for (offer_id, channel) in self.iter() {
979 resp.child(
980 &channel_inspect_path(&channel.offer, format_args!("")),
981 |req| {
982 let mut resp = req.respond();
983 channel.inspect_state(&mut resp);
984
985 resp.merge(inspect::adhoc(|req| {
989 if !matches!(channel.state, ChannelState::Revoked) {
990 notifier.inspect(version, offer_id, req);
991 }
992 }));
993 },
994 );
995 }
996 }
997}
998
/// Maximum number of channels that can hold channel IDs at once. The
/// assignment table has this many slots, covering channel IDs
/// `1..=MAX_CHANNELS`.
pub const MAX_CHANNELS: usize = 2047;
1002
1003impl ChannelList {
1004 fn new() -> Self {
1005 Self {
1006 channels: Slab::new(),
1007 }
1008 }
1009
1010 fn len(&self) -> usize {
1012 self.channels.len()
1013 }
1014
1015 fn offer(&mut self, new_channel: Channel) -> OfferId {
1017 OfferId(self.channels.insert(new_channel))
1018 }
1019
1020 fn remove(&mut self, offer_id: OfferId) {
1022 let channel = self.channels.remove(offer_id.0);
1023 assert!(channel.info.is_none());
1024 }
1025
1026 fn get_by_channel_id_mut(
1028 &mut self,
1029 assigned_channels: &AssignedChannels,
1030 channel_id: ChannelId,
1031 ) -> Result<(OfferId, &mut Channel), ChannelError> {
1032 let offer_id = assigned_channels
1033 .get(channel_id)
1034 .ok_or(ChannelError::UnknownChannelId)?;
1035 let channel = &mut self[offer_id];
1036 if channel.state.is_released() {
1037 return Err(ChannelError::ChannelReleased);
1038 }
1039 assert_eq!(
1040 channel.info.as_ref().map(|info| info.channel_id),
1041 Some(channel_id)
1042 );
1043 Ok((offer_id, channel))
1044 }
1045
1046 fn get_by_channel_id(
1048 &self,
1049 assigned_channels: &AssignedChannels,
1050 channel_id: ChannelId,
1051 ) -> Result<(OfferId, &Channel), ChannelError> {
1052 let offer_id = assigned_channels
1053 .get(channel_id)
1054 .ok_or(ChannelError::UnknownChannelId)?;
1055 let channel = &self[offer_id];
1056 if channel.state.is_released() {
1057 return Err(ChannelError::ChannelReleased);
1058 }
1059 assert_eq!(
1060 channel.info.as_ref().map(|info| info.channel_id),
1061 Some(channel_id)
1062 );
1063 Ok((offer_id, channel))
1064 }
1065
1066 fn get_by_key_mut(&mut self, key: &OfferKey) -> Option<(OfferId, &mut Channel)> {
1069 for (offer_id, channel) in self.iter_mut() {
1070 if channel.offer.instance_id == key.instance_id
1071 && channel.offer.interface_id == key.interface_id
1072 && channel.offer.subchannel_index == key.subchannel_index
1073 {
1074 return Some((offer_id, channel));
1075 }
1076 }
1077 None
1078 }
1079
1080 fn iter(&self) -> impl Iterator<Item = (OfferId, &Channel)> {
1082 self.channels
1083 .iter()
1084 .map(|(id, channel)| (OfferId(id), channel))
1085 }
1086
1087 fn iter_mut(&mut self) -> impl Iterator<Item = (OfferId, &mut Channel)> {
1089 self.channels
1090 .iter_mut()
1091 .map(|(id, channel)| (OfferId(id), channel))
1092 }
1093
1094 fn retain<F>(&mut self, mut f: F)
1096 where
1097 F: FnMut(OfferId, &mut Channel) -> bool,
1098 {
1099 self.channels.retain(|id, channel| {
1100 let retain = f(OfferId(id), channel);
1101 if !retain {
1102 assert!(channel.info.is_none());
1103 }
1104 retain
1105 })
1106 }
1107}
1108
1109impl Index<OfferId> for ChannelList {
1110 type Output = Channel;
1111
1112 fn index(&self, offer_id: OfferId) -> &Self::Output {
1113 &self.channels[offer_id.0]
1114 }
1115}
1116
1117impl IndexMut<OfferId> for ChannelList {
1118 fn index_mut(&mut self, offer_id: OfferId) -> &mut Self::Output {
1119 &mut self.channels[offer_id.0]
1120 }
1121}
1122
1123#[derive(Debug, Inspect)]
1125struct Gpadl {
1126 count: u16,
1127 #[inspect(skip)]
1128 buf: Vec<u64>,
1129 state: GpadlState,
1130}
1131
1132#[derive(Debug, Copy, Clone, PartialEq, Eq, Inspect)]
1133enum GpadlState {
1134 InProgress,
1136 Offered,
1138 OfferedTearingDown,
1141 Accepted,
1143 TearingDown,
1145}
1146
1147impl Gpadl {
1148 fn new(count: u16, len: usize) -> Self {
1151 Self {
1152 state: GpadlState::InProgress,
1153 count,
1154 buf: Vec::with_capacity(len),
1155 }
1156 }
1157
1158 fn append(&mut self, data: &[u8]) -> Result<bool, ChannelError> {
1160 if self.state == GpadlState::InProgress {
1161 let buf = &mut self.buf;
1162 let len = min(data.len() & !7, (buf.capacity() - buf.len()) * 8);
1167 let data = &data[..len];
1168 let start = buf.len();
1169 buf.resize(buf.len() + data.len() / 8, 0);
1170 buf[start..].as_mut_bytes().copy_from_slice(data);
1171 Ok(if buf.len() == buf.capacity() {
1172 gparange::MultiPagedRangeBuf::<Vec<u64>>::validate(self.count as usize, buf)
1173 .map_err(ChannelError::InvalidGpaRange)?;
1174 self.state = GpadlState::Offered;
1175 true
1176 } else {
1177 false
1178 })
1179 } else {
1180 Err(ChannelError::GpadlAlreadyComplete)
1181 }
1182 }
1183}
1184
1185#[derive(Debug, Copy, Clone)]
1187pub struct OpenParams {
1188 pub open_data: OpenData,
1189 pub connection_id: u32,
1190 pub event_flag: u16,
1191 pub monitor_info: Option<MonitorInfo>,
1192 pub flags: protocol::OpenChannelFlags,
1193 pub reserved_target: Option<ConnectionTarget>,
1194 pub channel_id: ChannelId,
1195}
1196
1197impl OpenParams {
1198 fn from_request(
1199 info: &OfferedInfo,
1200 request: &OpenRequest,
1201 monitor_info: Option<MonitorInfo>,
1202 reserved_target: Option<ConnectionTarget>,
1203 ) -> Self {
1204 let (event_flag, connection_id) = if let Some(id) = request.guest_specified_interrupt_info {
1207 (id.event_flag, id.connection_id)
1208 } else {
1209 (info.channel_id.0 as u16, info.connection_id)
1210 };
1211
1212 Self {
1213 open_data: OpenData {
1214 target_vp: request.target_vp,
1215 ring_offset: request.downstream_ring_buffer_page_offset,
1216 ring_gpadl_id: request.ring_buffer_gpadl_id,
1217 user_data: request.user_data,
1218 event_flag,
1219 connection_id,
1220 },
1221 connection_id,
1222 event_flag,
1223 monitor_info,
1224 flags: request.flags.with_unused(0),
1225 reserved_target,
1226 channel_id: info.channel_id,
1227 }
1228 }
1229}
1230
1231#[derive(Debug)]
1233pub enum Action {
1234 Open(OpenParams, VersionInfo),
1235 Close,
1236 Gpadl(GpadlId, u16, Vec<u64>),
1237 TeardownGpadl {
1238 gpadl_id: GpadlId,
1239 post_restore: bool,
1240 },
1241 Modify {
1242 target_vp: u32,
1243 },
1244}
1245
/// Protocol versions this server is able to negotiate, listed from `V1`
/// through `Copper`.
// NOTE(review): negotiation code presumably depends on this ordering; confirm
// before reordering or inserting entries.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::Win7,
    Version::Win8,
    Version::Win8_1,
    Version::Win10,
    Version::Win10Rs3_0,
    Version::Win10Rs3_1,
    Version::Win10Rs4,
    Version::Win10Rs5,
    Version::Iron,
    Version::Copper,
];
1260
1261const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
1264 .with_guest_specified_signal_parameters(true)
1265 .with_channel_interrupt_redirection(true)
1266 .with_modify_connection(true)
1267 .with_client_id(true)
1268 .with_pause_resume(true);
1269
1270pub trait Notifier: Send {
1272 fn notify(&mut self, offer_id: OfferId, action: Action);
1274
1275 fn forward_unhandled(&mut self, request: InitiateContactRequest);
1277
1278 fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()>;
1284
1285 fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
1287 let _ = (version, offer_id, req);
1288 }
1289
1290 #[must_use]
1293 fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool;
1294
1295 fn notify_hvsock(&mut self, request: &HvsockConnectRequest);
1297
1298 fn reset_complete(&mut self);
1300
1301 fn unload_complete(&mut self);
1303}
1304
impl Server {
    /// Creates a disconnected server with no channels or GPADLs.
    ///
    /// `vtl` is the VTL this server handles (InitiateContact requests for
    /// other VTLs are forwarded to the notifier); together with
    /// `channel_id_offset` it configures channel ID assignment.
    /// `child_connection_id` is reported in version responses to guests that
    /// negotiate Win10 RS3.1 or later.
    pub fn new(vtl: Vtl, child_connection_id: u32, channel_id_offset: u16) -> Self {
        Server {
            state: ConnectionState::Disconnected,
            channels: ChannelList::new(),
            assigned_channels: AssignedChannels::new(vtl, channel_id_offset),
            assigned_monitors: AssignedMonitors::new(),
            gpadls: Default::default(),
            incomplete_gpadls: Default::default(),
            child_connection_id,
            max_version: None,
            delayed_max_version: None,
            pending_messages: PendingMessages(VecDeque::new()),
        }
    }

    /// Pairs the server with `notifier` so that operations that must report to
    /// the owner can run. Channel invariants are checked first (debug builds
    /// only).
    pub fn with_notifier<'a, T: Notifier>(
        &'a mut self,
        notifier: &'a mut T,
    ) -> ServerWithNotifier<'a, T> {
        self.validate();
        ServerWithNotifier {
            inner: self,
            notifier,
        }
    }

    /// Debug-only consistency check: a channel has assigned `info` exactly
    /// when its state is not a released state.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if any channel violates the invariant.
    fn validate(&self) {
        // The whole loop is compiled out in release builds.
        #[cfg(debug_assertions)]
        for (_, channel) in self.channels.iter() {
            let should_have_info = !channel.state.is_released();
            if channel.info.is_some() != should_have_info {
                panic!("channel invariant violation: {channel:?}");
            }
        }
    }

    /// Caps the protocol version and feature flags that connections may
    /// negotiate.
    ///
    /// With `delay` set, the cap is only stored; it is applied when the guest
    /// next completes an unload. Otherwise it applies immediately (to the next
    /// version negotiation).
    pub fn set_compatibility_version(&mut self, version: MaxVersionInfo, delay: bool) {
        if delay {
            self.delayed_max_version = Some(version)
        } else {
            tracing::info!(?version, "Limiting VmBus connections to version");
            self.max_version = Some(version);
        }
    }

    /// Returns the GPADLs recorded for `offer_id`, in the form handed back to a
    /// device during restore.
    ///
    /// GPADLs still being assembled (`InProgress`) or being torn down by the
    /// device (`TearingDown`) are omitted. `accepted` is true only for GPADLs
    /// in the `Accepted` state.
    pub fn channel_gpadls(&self, offer_id: OfferId) -> Vec<RestoredGpadl> {
        self.gpadls
            .iter()
            .filter_map(|(&(gpadl_id, gpadl_offer_id), gpadl)| {
                if offer_id != gpadl_offer_id {
                    return None;
                }
                let accepted = match gpadl.state {
                    GpadlState::Offered | GpadlState::OfferedTearingDown => false,
                    GpadlState::Accepted => true,
                    GpadlState::InProgress | GpadlState::TearingDown => return None,
                };
                Some(RestoredGpadl {
                    request: GpadlRequest {
                        id: gpadl_id,
                        count: gpadl.count,
                        buf: gpadl.buf.clone(),
                    },
                    accepted,
                })
            })
            .collect()
    }

    /// Returns the negotiated protocol version, if the connection state has one.
    pub fn get_version(&self) -> Option<VersionInfo> {
        self.state.get_version()
    }

    /// Returns the open parameters a device should use when restoring channel
    /// `offer_id` as open.
    ///
    /// # Errors
    ///
    /// - `MissingChannel` if the channel is not in the saved state, has no
    ///   assigned info, or was released/reoffered.
    /// - `AlreadyRestored` if the channel was already restored.
    /// - `MismatchedOpenState` if the saved channel is closed.
    ///
    /// # Panics
    ///
    /// Panics if the channel is still unmatched or is in a revoked or
    /// client-release transition state.
    pub fn get_restore_open_params(&self, offer_id: OfferId) -> Result<OpenParams, RestoreError> {
        let channel = &self.channels[offer_id];

        match channel.restore_state {
            // Not part of the saved state, so there is nothing open to restore.
            RestoreState::New => {
                return Err(RestoreError::MissingChannel(channel.offer.key()));
            }
            RestoreState::Restoring => {}
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        let (request, reserved_state) = match channel.state {
            ChannelState::Closed => {
                return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
            }
            // A closing channel reports its open parameters without a reserved target.
            ChannelState::Closing { params, .. } | ChannelState::ClosingReopen { params, .. } => {
                (params, None)
            }
            ChannelState::Opening {
                request,
                reserved_state,
            } => (request, reserved_state),
            ChannelState::Open {
                params,
                reserved_state,
                ..
            } => (params, reserved_state),
            ChannelState::ClientReleased | ChannelState::Reoffered => {
                return Err(RestoreError::MissingChannel(channel.offer.key()));
            }
            ChannelState::Revoked
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        };

        Ok(OpenParams::from_request(
            &info,
            &request,
            channel.handled_monitor_info(),
            reserved_state.map(|state| state.target),
        ))
    }

    /// Returns true if messages are queued and the connection is not paused.
    pub fn has_pending_messages(&self) -> bool {
        !self.pending_messages.0.is_empty() && !self.state.is_paused()
    }

    /// Sends queued messages through `send`, oldest first.
    ///
    /// Returns `Pending` as soon as `send` does; that message stays queued and
    /// is retried on the next call. Returns `Ready` once the queue is empty, or
    /// immediately while the connection is paused (nothing is sent).
    pub fn poll_flush_pending_messages(
        &mut self,
        mut send: impl FnMut(&OutgoingMessage) -> Poll<()>,
    ) -> Poll<()> {
        if !self.state.is_paused() {
            while let Some(message) = self.pending_messages.0.front() {
                ready!(send(message));
                // Only dequeue after the message was accepted by `send`.
                self.pending_messages.0.pop_front();
            }
        }

        Poll::Ready(())
    }
}
1456
1457impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
    /// Restores channel `offer_id` from saved state, given whether the device
    /// reports the channel as open (`open`).
    ///
    /// The saved channel state is reconciled with `open`:
    /// - device open: a pending close is re-sent to the device as
    ///   `Action::Close`. A pending open is completed by sending a successful
    ///   open result to the guest, and the channel becomes `Open`.
    /// - device closed: a pending close completes immediately. A pending open,
    ///   or a close that was to be followed by a reopen, is re-sent to the
    ///   device as `Action::Open`.
    ///
    /// The channel's handled monitor ID, if any, is claimed again.
    ///
    /// # Errors
    ///
    /// - `MissingChannel` if `open` is set but the channel is not in the saved
    ///   state, if it has no assigned info, or if it was released (or
    ///   reoffered, when `open` is set).
    /// - `AlreadyRestored` if the channel was already restored.
    /// - `DuplicateMonitorId` if its monitor ID is already claimed.
    /// - `MismatchedOpenState` if the saved open/closed state contradicts `open`.
    ///
    /// # Panics
    ///
    /// Panics if the channel is still unmatched or is in a revoked or
    /// client-release transition state, or if an open must be re-sent while no
    /// protocol version is negotiated.
    pub fn restore_channel(&mut self, offer_id: OfferId, open: bool) -> Result<(), RestoreError> {
        let channel = &mut self.inner.channels[offer_id];

        match channel.restore_state {
            RestoreState::New => {
                // The channel is not in the saved state. That is only a problem
                // if the device expects it to be open.
                if open {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                } else {
                    return Ok(());
                }
            }
            RestoreState::Restoring => {}
            // NOTE(review): presumably unreachable because `offer_channel` moves a
            // matched saved channel to `Restoring` before returning its offer ID.
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        // Re-claim the monitor ID recorded for this channel.
        if let Some(monitor_info) = channel.handled_monitor_info() {
            if !self
                .inner
                .assigned_monitors
                .claim_monitor(monitor_info.monitor_id)
            {
                return Err(RestoreError::DuplicateMonitorId(monitor_info.monitor_id.0));
            }
        }

        if open {
            match channel.state {
                ChannelState::Closed => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                // The guest asked to close, but the device is still open: resend the close.
                ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                    self.notifier.notify(offer_id, Action::Close);
                }
                // The device finished opening; report success to the guest now.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused())
                        .send_open_result(
                            info.channel_id,
                            &request,
                            protocol::STATUS_SUCCESS,
                            MessageTarget::for_offer(offer_id, &reserved_state),
                        );
                    channel.state = ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    };
                }
                ChannelState::Open { .. } => {}
                ChannelState::ClientReleased | ChannelState::Reoffered => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            };
        } else {
            match channel.state {
                ChannelState::Closed => {}
                // Nothing to restore for a channel waiting to be offered again.
                ChannelState::Reoffered => {}
                // The device already closed the channel; finish the close.
                ChannelState::Closing { .. } => {
                    channel.state = ChannelState::Closed;
                }
                // The close finished on the device side; start the requested reopen.
                ChannelState::ClosingReopen { request, .. } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                None,
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                    channel.state = ChannelState::Opening {
                        request,
                        reserved_state: None,
                    };
                }
                // The device is not open, so ask it to open again.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                reserved_state.map(|state| state.target),
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                }
                ChannelState::Open { .. } => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                ChannelState::ClientReleased => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            }
        }

        channel.restore_state = RestoreState::Restored;
        Ok(())
    }
1596
1597 pub fn revoke_unclaimed_channels(&mut self) {
1600 for (offer_id, channel) in self.inner.channels.iter_mut() {
1601 match channel.restore_state {
1602 RestoreState::Restored => {
1603 }
1605 RestoreState::New => {
1606 if let ConnectionState::Connected(info) = &self.inner.state {
1611 if matches!(channel.state, ChannelState::ClientReleased) {
1612 channel.prepare_channel(
1613 offer_id,
1614 &mut self.inner.assigned_channels,
1615 &mut self.inner.assigned_monitors,
1616 );
1617 channel.state = ChannelState::Closed;
1618 self.inner
1619 .pending_messages
1620 .sender(self.notifier, self.inner.state.is_paused())
1621 .send_offer(channel, info.version);
1622 }
1623 }
1624 }
1625 RestoreState::Restoring => {
1626 let retain = revoke(
1630 self.inner
1631 .pending_messages
1632 .sender(self.notifier, self.inner.state.is_paused()),
1633 offer_id,
1634 channel,
1635 &mut self.inner.gpadls,
1636 );
1637 assert!(retain, "channel has not been released");
1638 channel.state = ChannelState::Reoffered;
1639 }
1640 RestoreState::Unmatched => {
1641 let retain = revoke(
1644 self.inner
1645 .pending_messages
1646 .sender(self.notifier, self.inner.state.is_paused()),
1647 offer_id,
1648 channel,
1649 &mut self.inner.gpadls,
1650 );
1651 assert!(retain, "channel has not been released");
1652 }
1653 }
1654 }
1655
1656 for (&(gpadl_id, offer_id), gpadl) in self.inner.gpadls.iter_mut() {
1658 match gpadl.state {
1659 GpadlState::InProgress | GpadlState::Accepted => {}
1660 GpadlState::Offered => {
1661 self.notifier.notify(
1662 offer_id,
1663 Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
1664 );
1665 }
1666 GpadlState::TearingDown => {
1667 self.notifier.notify(
1668 offer_id,
1669 Action::TeardownGpadl {
1670 gpadl_id,
1671 post_restore: true,
1672 },
1673 );
1674 }
1675 GpadlState::OfferedTearingDown => unreachable!(),
1676 }
1677 }
1678
1679 self.check_disconnected();
1680 }
1681
1682 pub fn reset(&mut self) {
1687 assert!(!self.is_resetting());
1688 if self.request_disconnect(ConnectionAction::Reset) {
1689 self.complete_reset();
1690 }
1691 }
1692
1693 fn complete_reset(&mut self) {
1694 for (_, channel) in self.inner.channels.iter_mut() {
1696 channel.restore_state = RestoreState::New;
1697 }
1698 self.inner.pending_messages.0.clear();
1699 self.notifier.reset_complete();
1700 }
1701
1702 pub fn offer_channel(&mut self, offer: OfferParamsInternal) -> Result<OfferId, OfferError> {
1704 if let Some((offer_id, channel)) = self.inner.channels.get_by_key_mut(&offer.key()) {
1706 if channel.restore_state != RestoreState::Unmatched
1710 && !matches!(channel.state, ChannelState::Revoked)
1711 {
1712 return Err(OfferError::AlreadyExists(offer.key()));
1713 }
1714
1715 let info = channel.info.expect("assigned");
1716 if channel.restore_state == RestoreState::Unmatched {
1717 tracing::debug!(
1718 offer_id = offer_id.0,
1719 key = %channel.offer.key(),
1720 "matched channel"
1721 );
1722
1723 assert!(!matches!(channel.state, ChannelState::Revoked));
1724 channel.restore_state = RestoreState::Restoring;
1728
1729 if let MnfUsage::Relayed { monitor_id } = offer.use_mnf {
1732 if info.monitor_id != Some(MonitorId(monitor_id)) {
1733 return Err(OfferError::MismatchedMonitorId(
1734 info.monitor_id,
1735 MonitorId(monitor_id),
1736 ));
1737 }
1738 }
1739 } else {
1740 channel.state = ChannelState::Reoffered;
1744 tracing::info!(?offer_id, key = %channel.offer.key(), "channel marked for reoffer");
1745 }
1746
1747 channel.offer = offer;
1748 return Ok(offer_id);
1749 }
1750
1751 let mut connected_version = None;
1752 let state = match self.inner.state {
1753 ConnectionState::Connected(ConnectionInfo {
1754 offers_sent: true,
1755 version,
1756 ..
1757 }) => {
1758 connected_version = Some(version);
1759 ChannelState::Closed
1760 }
1761 ConnectionState::Connected(ConnectionInfo {
1762 offers_sent: false, ..
1763 })
1764 | ConnectionState::Connecting { .. }
1765 | ConnectionState::Disconnecting { .. }
1766 | ConnectionState::Disconnected => ChannelState::ClientReleased,
1767 };
1768
1769 if self.inner.channels.len() >= self.inner.assigned_channels.allowable_channel_count() {
1771 return Err(OfferError::TooManyChannels);
1772 }
1773
1774 let key = offer.key();
1775 let confidential_ring_buffer = offer.flags.confidential_ring_buffer();
1776 let confidential_external_memory = offer.flags.confidential_external_memory();
1777 let channel = Channel {
1778 info: None,
1779 offer,
1780 state,
1781 restore_state: RestoreState::New,
1782 };
1783
1784 let offer_id = self.inner.channels.offer(channel);
1785 if let Some(version) = connected_version {
1786 let channel = &mut self.inner.channels[offer_id];
1787 channel.prepare_channel(
1788 offer_id,
1789 &mut self.inner.assigned_channels,
1790 &mut self.inner.assigned_monitors,
1791 );
1792
1793 self.inner
1794 .pending_messages
1795 .sender(self.notifier, self.inner.state.is_paused())
1796 .send_offer(channel, version);
1797 }
1798
1799 tracing::info!(?offer_id, %key, confidential_ring_buffer, confidential_external_memory, "new channel");
1800 Ok(offer_id)
1801 }
1802
1803 pub fn revoke_channel(&mut self, offer_id: OfferId) {
1805 let channel = &mut self.inner.channels[offer_id];
1806 let retain = revoke(
1807 self.inner
1808 .pending_messages
1809 .sender(self.notifier, self.inner.state.is_paused()),
1810 offer_id,
1811 channel,
1812 &mut self.inner.gpadls,
1813 );
1814 if !retain {
1815 self.inner.channels.remove(offer_id);
1816 }
1817
1818 self.check_disconnected();
1819 }
1820
    /// Completes a device's open of channel `offer_id` with status `result`
    /// (non-negative means success).
    ///
    /// For a channel in `Opening`, the result is sent to the guest and the
    /// channel becomes `Open` on success or `Closed` on failure. If the guest
    /// released the channel while the open was in flight
    /// (`OpeningClientRelease`), a successful open is immediately followed by
    /// a close request to the device. A failed one marks the channel released
    /// and re-checks a pending disconnect. Completions in any other state are
    /// logged and ignored.
    pub fn open_complete(&mut self, offer_id: OfferId, result: i32) {
        tracing::debug!(offer_id = offer_id.0, result, "open complete");

        let channel = &mut self.inner.channels[offer_id];
        match channel.state {
            ChannelState::Opening {
                request,
                reserved_state,
            } => {
                let channel_id = channel.info.expect("assigned").channel_id;
                if result >= 0 {
                    tracelimit::info_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        result,
                        "opened channel"
                    );
                } else {
                    // Open failures are logged at error level.
                    tracelimit::error_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        result,
                        "failed to open channel"
                    );
                }

                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_open_result(
                        channel_id,
                        &request,
                        result,
                        MessageTarget::for_offer(offer_id, &reserved_state),
                    );
                channel.state = if result >= 0 {
                    ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    }
                } else {
                    ChannelState::Closed
                };
            }
            ChannelState::OpeningClientRelease => {
                tracing::info!(
                    offer_id = offer_id.0,
                    result,
                    "opened channel (client released)"
                );

                if result >= 0 {
                    // The guest no longer wants the channel; undo the open.
                    channel.state = ChannelState::ClosingClientRelease;
                    self.notifier.notify(offer_id, Action::Close);
                } else {
                    channel.state = ChannelState::ClientReleased;
                    // This may have been the last channel blocking a disconnect.
                    self.check_disconnected();
                }
            }

            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::ClosingClientRelease => {
                tracing::error!(?offer_id, state = ?channel.state, "invalid open complete")
            }
        }
    }
1896
1897 fn are_channels_reset(&self, include_reserved: bool) -> bool {
1900 self.inner.gpadls.keys().all(|(_, offer_id)| {
1901 !include_reserved && self.inner.channels[*offer_id].state.is_reserved()
1902 }) && self.inner.channels.iter().all(|(_, channel)| {
1903 matches!(channel.state, ChannelState::ClientReleased)
1904 || (!include_reserved && channel.state.is_reserved())
1905 })
1906 }
1907
1908 fn check_disconnected(&mut self) {
1912 match self.inner.state {
1913 ConnectionState::Disconnecting {
1914 next_action,
1915 modify_sent: false,
1916 } => {
1917 if self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)) {
1918 self.inner.state = ConnectionState::Disconnecting {
1919 next_action,
1920 modify_sent: true,
1921 };
1922
1923 self.notifier
1925 .modify_connection(ModifyConnectionRequest {
1926 monitor_page: Update::Reset,
1927 interrupt_page: Update::Reset,
1928 ..Default::default()
1929 })
1930 .expect("resetting state should not fail");
1931 }
1932 }
1933 ConnectionState::Disconnecting {
1934 modify_sent: true, ..
1935 }
1936 | ConnectionState::Disconnected
1937 | ConnectionState::Connected { .. }
1938 | ConnectionState::Connecting { .. } => (),
1939 }
1940 }
1941
1942 fn is_resetting(&self) -> bool {
1945 matches!(
1946 &self.inner.state,
1947 ConnectionState::Connecting {
1948 next_action: ConnectionAction::Reset,
1949 ..
1950 } | ConnectionState::Disconnecting {
1951 next_action: ConnectionAction::Reset,
1952 ..
1953 }
1954 )
1955 }
1956
    /// Completes a device's close of channel `offer_id`.
    ///
    /// - A closing reserved channel becomes `Closed`. If the guest is still
    ///   connected, a CloseReservedChannelResponse is sent. Otherwise the
    ///   channel is released and removed if `client_release_channel` says so.
    /// - Any other closing channel becomes `Closed`.
    /// - A channel the guest released while it was closing becomes
    ///   `ClientReleased`, and a pending disconnect is re-checked.
    /// - A channel the guest asked to reopen while closing is opened again.
    ///
    /// Completions in any other state are logged and ignored.
    pub fn close_complete(&mut self, offer_id: OfferId) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::info!(offer_id = offer_id.0, "closed channel");
        match channel.state {
            ChannelState::Closing {
                reserved_state: Some(reserved_state),
                ..
            } => {
                channel.state = ChannelState::Closed;
                if matches!(self.inner.state, ConnectionState::Connected { .. }) {
                    let channel_id = channel.info.expect("assigned").channel_id;
                    self.send_close_reserved_channel_response(
                        channel_id,
                        offer_id,
                        reserved_state.target,
                    );
                } else {
                    // Not connected, so there is nobody to answer; release the
                    // channel instead (a `true` result means it can be removed).
                    if Self::client_release_channel(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                        &mut self.inner.assigned_channels,
                        &mut self.inner.assigned_monitors,
                        None,
                    ) {
                        self.inner.channels.remove(offer_id);
                    }
                }
            }
            ChannelState::Closing { .. } => {
                channel.state = ChannelState::Closed;
            }
            ChannelState::ClosingClientRelease => {
                channel.state = ChannelState::ClientReleased;
                // This may have been the last channel blocking a disconnect.
                self.check_disconnected();
            }
            ChannelState::ClosingReopen { request, .. } => {
                channel.state = ChannelState::Closed;
                self.open_channel(offer_id, &request, None);
            }

            ChannelState::Closed
            | ChannelState::ClientReleased
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::OpeningClientRelease => {
                tracing::error!(?offer_id, state = ?channel.state, "invalid close complete")
            }
        }
    }
2015
2016 fn send_close_reserved_channel_response(
2017 &mut self,
2018 channel_id: ChannelId,
2019 offer_id: OfferId,
2020 target: ConnectionTarget,
2021 ) {
2022 self.sender().send_message_with_target(
2023 &protocol::CloseReservedChannelResponse { channel_id },
2024 MessageTarget::ReservedChannel(offer_id, target),
2025 );
2026 }
2027
    /// Decodes a guest InitiateContact message into an `InitiateContactRequest`
    /// and processes it with `initiate_contact`.
    ///
    /// `input` uses the extended InitiateContact2 layout. `includes_client_id`
    /// says whether the message was long enough to actually contain the client
    /// ID field. Each field is honored only for the protocol versions that
    /// define it; otherwise a default is used.
    ///
    /// # Errors
    ///
    /// Returns a `MessageTooSmall` parse error if the requested feature flags
    /// include `client_id` but the message does not carry a client ID.
    fn handle_initiate_contact(
        &mut self,
        input: &protocol::InitiateContact2,
        message: &SynicMessage,
        includes_client_id: bool,
    ) -> Result<(), ChannelError> {
        // For newer versions this field carries target info rather than an
        // interrupt page.
        let target_info =
            protocol::TargetInfo::from(input.initiate_contact.interrupt_page_or_target_info);

        // A target SINT may only be chosen by multiclient-capable connections
        // (Win10 RS3.1+); everyone else uses the default SINT.
        let target_sint = if message.multiclient
            && input.initiate_contact.version_requested >= Version::Win10Rs3_1 as u32
        {
            target_info.sint()
        } else {
            SINT
        };

        // Likewise a target VTL, from Win10 RS4 on; otherwise VTL 0.
        let target_vtl = if message.multiclient
            && input.initiate_contact.version_requested >= Version::Win10Rs4 as u32
        {
            target_info.vtl()
        } else {
            0
        };

        // Feature flags exist only from Copper on.
        let feature_flags = if input.initiate_contact.version_requested >= Version::Copper as u32 {
            target_info.feature_flags()
        } else {
            0
        };

        // The target VP for messages is honored from Win8.1 on; older versions use VP 0.
        let target_message_vp =
            if input.initiate_contact.version_requested >= Version::Win8_1 as u32 {
                input.initiate_contact.target_message_vp
            } else {
                0
            };

        // Before Win8 the shared field is an interrupt page GPA (zero meaning none).
        let interrupt_page = (input.initiate_contact.version_requested < Version::Win8 as u32
            && input.initiate_contact.interrupt_page_or_target_info != 0)
            .then_some(input.initiate_contact.interrupt_page_or_target_info);

        // Both monitor page GPAs must be provided, or neither.
        let monitor_page = if (input.initiate_contact.parent_to_child_monitor_page_gpa == 0)
            != (input.initiate_contact.child_to_parent_monitor_page_gpa == 0)
        {
            MonitorPageRequest::Invalid
        } else if input.initiate_contact.parent_to_child_monitor_page_gpa != 0 {
            MonitorPageRequest::Some(MonitorPageGpas {
                parent_to_child: input.initiate_contact.parent_to_child_monitor_page_gpa,
                child_to_parent: input.initiate_contact.child_to_parent_monitor_page_gpa,
            })
        } else {
            MonitorPageRequest::None
        };

        // The client ID is only read when the guest advertises it; if it
        // advertises it, the message must actually contain it.
        let client_id = if FeatureFlags::from(feature_flags).client_id() {
            if includes_client_id {
                input.client_id
            } else {
                return Err(ChannelError::ParseError(
                    protocol::ParseError::MessageTooSmall(Some(
                        protocol::MessageType::INITIATE_CONTACT,
                    )),
                ));
            }
        } else {
            Guid::ZERO
        };

        let request = InitiateContactRequest {
            version_requested: input.initiate_contact.version_requested,
            target_message_vp,
            monitor_page,
            target_sint,
            target_vtl,
            feature_flags,
            interrupt_page,
            client_id,
            trusted: message.trusted,
        };
        self.initiate_contact(request);
        Ok(())
    }
2127
    /// Processes an InitiateContact request.
    ///
    /// Requests targeting another VTL are forwarded to the notifier. Requests
    /// for a SINT other than the default are rejected with an "unsupported"
    /// response sent to the requested VP/SINT. Otherwise:
    /// - any existing connection is torn down first; if that cannot finish
    ///   immediately, this request is queued as the next action and re-run
    ///   once the disconnect completes;
    /// - the version and feature flags are negotiated;
    /// - the server moves to `Connecting` and asks the notifier to apply the
    ///   new connection parameters.
    ///
    /// The success response is sent later by `complete_initiate_contact`.
    /// Failures are answered immediately.
    pub fn initiate_contact(&mut self, request: InitiateContactRequest) {
        // Requests for a different VTL belong to someone else.
        let vtl = self.inner.assigned_channels.vtl as u8;
        if request.target_vtl != vtl {
            self.notifier.forward_unhandled(request);
            return;
        }

        if request.target_sint != SINT {
            tracelimit::warn_ratelimited!(
                "unsupported multiclient request for VTL {} SINT {}, version {:#x}",
                request.target_vtl,
                request.target_sint,
                request.version_requested,
            );

            // Reply "unsupported" to where the guest asked for messages to go.
            self.send_version_response_with_target(
                None,
                MessageTarget::Custom(ConnectionTarget {
                    vp: request.target_message_vp,
                    sint: request.target_sint,
                }),
            );

            return;
        }

        // Tear down any existing connection; if that is still in progress,
        // this request is retried via `ConnectionAction::Reconnect`.
        if !self.request_disconnect(ConnectionAction::Reconnect {
            initiate_contact: request,
        }) {
            return;
        }

        let Some(version) = self.check_version_supported(&request) else {
            tracelimit::warn_ratelimited!(
                vtl,
                version = request.version_requested,
                client_id = ?request.client_id,
                "Guest requested unsupported version"
            );

            // Unsupported version: reply with an empty (unsupported) response.
            self.send_version_response(None);
            return;
        };

        tracelimit::info_ratelimited!(
            vtl,
            ?version,
            client_id = ?request.client_id,
            trusted = request.trusted,
            "Guest negotiated version"
        );

        let monitor_page = match request.monitor_page {
            MonitorPageRequest::Some(mp) => Some(mp),
            MonitorPageRequest::None => None,
            MonitorPageRequest::Invalid => {
                // Only one of the two monitor pages was supplied.
                self.send_version_response(Some((
                    version,
                    protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
                )));

                return;
            }
        };

        self.inner.state = ConnectionState::Connecting {
            info: ConnectionInfo {
                version,
                trusted: request.trusted,
                interrupt_page: request.interrupt_page,
                monitor_page,
                target_message_vp: request.target_message_vp,
                modifying: false,
                offers_sent: false,
                client_id: request.client_id,
                paused: false,
            },
            next_action: ConnectionAction::None,
        };

        // The connection completes in `complete_initiate_contact` once the
        // notifier has applied these parameters.
        if let Err(err) = self.notifier.modify_connection(ModifyConnectionRequest {
            version: Some(request.version_requested),
            monitor_page: monitor_page.into(),
            interrupt_page: request.interrupt_page.into(),
            target_message_vp: Some(request.target_message_vp),
            notify_relay: true,
        }) {
            tracelimit::error_ratelimited!(?err, "server failed to change state");
            self.inner.state = ConnectionState::Disconnected;
            self.send_version_response(Some((
                version,
                protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
            )));
        }
    }
2233
    /// Finishes a connection started by `initiate_contact` once the notifier
    /// has processed the connection change, with the relay's `response`.
    ///
    /// On success, the negotiated feature flags are narrowed to those the
    /// relay supports plus those implemented locally. The server then becomes
    /// `Connected`, sends a successful version response, and runs any action
    /// that was queued while connecting. On failure, the guest receives a
    /// failure (or unsupported) response and the server returns to
    /// `Disconnected`.
    ///
    /// # Panics
    ///
    /// Panics if the server is not in the `Connecting` state.
    pub(crate) fn complete_initiate_contact(&mut self, response: ModifyConnectionResponse) {
        let ConnectionState::Connecting {
            mut info,
            next_action,
        } = self.inner.state
        else {
            panic!("Invalid state for completing InitiateContact.");
        };

        // Features implemented by this server itself; kept regardless of what
        // the relay reports.
        const LOCAL_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
            .with_client_id(true)
            .with_confidential_channels(true);

        let relay_feature_flags = match response {
            // The relay accepted the connection.
            ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                feature_flags,
            ) => feature_flags,
            ModifyConnectionResponse::Supported(connection_state, feature_flags) => {
                // The relay failed the connection; pass its status to the guest.
                tracelimit::error_ratelimited!(
                    ?connection_state,
                    "initiate contact failed because relay request failed"
                );

                // The failure response still reports the narrowed feature flags.
                info.version.feature_flags &= feature_flags | LOCAL_FEATURE_FLAGS;

                self.send_version_response(Some((info.version, connection_state)));
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            ModifyConnectionResponse::Unsupported => {
                // The relay does not support the requested version.
                self.send_version_response(None);
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
        };

        // Only advertise features both the relay and this server support.
        info.version.feature_flags &= relay_feature_flags | LOCAL_FEATURE_FLAGS;
        self.inner.state = ConnectionState::Connected(info);

        self.send_version_response(Some((info.version, protocol::ConnectionState::SUCCESSFUL)));
        // Run whatever was requested while the connect was in flight.
        if !matches!(next_action, ConnectionAction::None) && self.request_disconnect(next_action) {
            self.do_next_action(next_action);
        }
    }
2289
2290 fn check_version_supported(&self, request: &InitiateContactRequest) -> Option<VersionInfo> {
2292 let version = SUPPORTED_VERSIONS
2293 .iter()
2294 .find(|v| request.version_requested == **v as u32)
2295 .copied()?;
2296
2297 if let Some(max_version) = self.inner.max_version {
2299 if version as u32 > max_version.version {
2300 return None;
2301 }
2302 }
2303
2304 let supported_flags = if version >= Version::Copper {
2305 let max_supported_flags =
2307 SUPPORTED_FEATURE_FLAGS.with_confidential_channels(request.trusted);
2308
2309 if let Some(max_version) = self.inner.max_version {
2311 max_supported_flags & max_version.feature_flags
2312 } else {
2313 max_supported_flags
2314 }
2315 } else {
2316 FeatureFlags::new()
2317 };
2318
2319 let feature_flags = supported_flags & request.feature_flags.into();
2320
2321 assert!(version >= Version::Copper || feature_flags == FeatureFlags::new());
2322 if feature_flags.into_bits() != request.feature_flags {
2323 tracelimit::warn_ratelimited!(
2324 supported = feature_flags.into_bits(),
2325 requested = request.feature_flags,
2326 "Guest requested unsupported feature flags."
2327 );
2328 }
2329
2330 Some(VersionInfo {
2331 version,
2332 feature_flags,
2333 })
2334 }
2335
2336 fn send_version_response(&mut self, data: Option<(VersionInfo, protocol::ConnectionState)>) {
2337 self.send_version_response_with_target(data, MessageTarget::Default);
2338 }
2339
2340 fn send_version_response_with_target(
2341 &mut self,
2342 data: Option<(VersionInfo, protocol::ConnectionState)>,
2343 target: MessageTarget,
2344 ) {
2345 let mut response2 = protocol::VersionResponse2::new_zeroed();
2346 let response = &mut response2.version_response;
2347 let mut send_response2 = false;
2348 if let Some((version, state)) = data {
2349 if state == protocol::ConnectionState::SUCCESSFUL || version.version >= Version::Win8 {
2352 response.version_supported = 1;
2353 response.connection_state = state;
2354 response.selected_version_or_connection_id =
2355 if version.version >= Version::Win10Rs3_1 {
2356 self.inner.child_connection_id
2357 } else {
2358 version.version as u32
2359 };
2360
2361 if version.version >= Version::Copper {
2362 response2.supported_features = version.feature_flags.into();
2363 send_response2 = true;
2364 }
2365 }
2366 }
2367
2368 if send_response2 {
2369 self.sender().send_message_with_target(&response2, target);
2370 } else {
2371 self.sender().send_message_with_target(response, target);
2372 }
2373 }
2374
    /// Starts tearing down the current connection, recording `new_action` to
    /// run once the teardown is done.
    ///
    /// Every channel is released from the guest's perspective, except
    /// reserved channels when `new_action` is not a VM reset. Returns `true`
    /// if the server is fully `Disconnected` now, in which case the caller
    /// must perform `new_action` itself. Otherwise the action runs later,
    /// from `complete_disconnect` or `complete_initiate_contact`.
    ///
    /// # Panics
    ///
    /// Panics if a reset is already in progress, or if the server is already
    /// disconnected without a reset but non-reserved channels or GPADLs remain.
    fn request_disconnect(&mut self, new_action: ConnectionAction) -> bool {
        assert!(!self.is_resetting());

        let gpadls = &mut self.inner.gpadls;
        // A VM reset also tears down reserved channels.
        let vm_reset = matches!(new_action, ConnectionAction::Reset);
        self.inner.channels.retain(|offer_id, channel| {
            (!vm_reset && channel.state.is_reserved())
                // Release the channel; `true` means it can be dropped from the list.
                || !Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    None,
                )
        });

        match &mut self.inner.state {
            ConnectionState::Disconnected => {
                if vm_reset {
                    // Reserved channels may still be outstanding; wait for them.
                    if !self.are_channels_reset(true) {
                        self.inner.state = ConnectionState::Disconnecting {
                            next_action: ConnectionAction::Reset,
                            modify_sent: false,
                        };
                    }
                } else {
                    assert!(self.are_channels_reset(false));
                }
            }

            // Disconnect now if nothing is outstanding; otherwise wait in
            // `Disconnecting` until `check_disconnected` sees everything released.
            ConnectionState::Connected { .. } => {
                if self.are_channels_reset(vm_reset) {
                    self.inner.state = ConnectionState::Disconnected;
                } else {
                    self.inner.state = ConnectionState::Disconnecting {
                        next_action: new_action,
                        modify_sent: false,
                    };
                }
            }

            // A transition is already in flight; the latest request wins.
            ConnectionState::Connecting { next_action, .. }
            | ConnectionState::Disconnecting { next_action, .. } => {
                *next_action = new_action;
            }
        }

        matches!(self.inner.state, ConnectionState::Disconnected)
    }
2436
2437 pub(crate) fn complete_disconnect(&mut self) {
2438 if let ConnectionState::Disconnecting {
2439 next_action,
2440 modify_sent,
2441 } = std::mem::replace(&mut self.inner.state, ConnectionState::Disconnected)
2442 {
2443 assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
2444 if !modify_sent {
2445 tracelimit::warn_ratelimited!("unexpected modify response");
2446 }
2447
2448 self.inner.state = ConnectionState::Disconnected;
2449 self.do_next_action(next_action);
2450 } else {
2451 unreachable!("not ready for disconnect");
2452 }
2453 }
2454
2455 fn do_next_action(&mut self, action: ConnectionAction) {
2456 match action {
2457 ConnectionAction::None => {}
2458 ConnectionAction::Reset => {
2459 self.complete_reset();
2460 }
2461 ConnectionAction::SendUnloadComplete => {
2462 self.complete_unload();
2463 }
2464 ConnectionAction::Reconnect { initiate_contact } => {
2465 self.initiate_contact(initiate_contact);
2466 }
2467 ConnectionAction::SendFailedVersionResponse => {
2468 self.send_version_response(None);
2471 }
2472 }
2473 }
2474
2475 fn handle_unload(&mut self) {
2477 tracing::debug!(
2478 vtl = self.inner.assigned_channels.vtl as u8,
2479 state = ?self.inner.state,
2480 "VmBus received unload request from guest",
2481 );
2482
2483 if self.request_disconnect(ConnectionAction::SendUnloadComplete) {
2484 self.complete_unload();
2485 }
2486 }
2487
2488 fn complete_unload(&mut self) {
2489 self.notifier.unload_complete();
2490 if let Some(version) = self.inner.delayed_max_version.take() {
2491 self.inner.set_compatibility_version(version, false);
2492 }
2493
2494 self.sender().send_message(&protocol::UnloadComplete {});
2495 tracelimit::info_ratelimited!("Vmbus disconnected");
2496 }
2497
    /// Handles the guest's RequestOffers message.
    ///
    /// Offers every non-reserved channel, sorted by interface ID, then
    /// `offer_order` (channels without one sort last), then instance ID, and
    /// finishes with AllOffersDelivered.
    ///
    /// # Errors
    ///
    /// `OffersAlreadySent` if offers were already requested on this connection.
    ///
    /// # Panics
    ///
    /// Panics if not connected, or if a non-reserved channel is not in the
    /// `ClientReleased` state or already has assigned info.
    fn handle_request_offers(&mut self) -> Result<(), ChannelError> {
        let ConnectionState::Connected(info) = &mut self.inner.state else {
            unreachable!(
                "in unexpected state {:?}, should be prevented by Message::parse()",
                self.inner.state
            );
        };

        if info.offers_sent {
            return Err(ChannelError::OffersAlreadySent);
        }

        info.offers_sent = true;

        // Reserved channels are not part of the regular offer list; the rest
        // are offered in a stable, sorted order.
        let mut sorted_channels: Vec<_> = self
            .inner
            .channels
            .iter_mut()
            .filter(|(_, channel)| !channel.state.is_reserved())
            .collect();

        sorted_channels.sort_unstable_by_key(|(_, channel)| {
            (
                channel.offer.interface_id,
                channel.offer.offer_order.unwrap_or(u32::MAX),
                channel.offer.instance_id,
            )
        });

        for (offer_id, channel) in sorted_channels {
            assert!(matches!(channel.state, ChannelState::ClientReleased));
            assert!(channel.info.is_none());

            channel.prepare_channel(
                offer_id,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
            );

            channel.state = ChannelState::Closed;
            self.inner
                .pending_messages
                .sender(self.notifier, info.paused)
                .send_offer(channel, info.version);
        }
        self.sender().send_message(&protocol::AllOffersDelivered {});

        Ok(())
    }
2550
2551 #[must_use]
2554 fn gpadl_updated(
2555 mut sender: MessageSender<'_, N>,
2556 offer_id: OfferId,
2557 channel: &Channel,
2558 gpadl_id: GpadlId,
2559 gpadl: &Gpadl,
2560 ) -> bool {
2561 if channel.state.is_revoked() {
2562 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
2563 sender.send_gpadl_created(channel_id, gpadl_id, protocol::STATUS_UNSUCCESSFUL);
2564 false
2565 } else {
2566 sender.notifier.notify(
2568 offer_id,
2569 Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
2570 );
2571 true
2572 }
2573 }
2574
2575 fn handle_gpadl_header(
2577 &mut self,
2578 input: &protocol::GpadlHeader,
2579 range: &[u8],
2580 ) -> Result<(), ChannelError> {
2581 let (offer_id, channel) = self
2583 .inner
2584 .channels
2585 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
2586
2587 if channel.state.is_reserved() {
2590 return Err(ChannelError::ChannelReserved);
2591 }
2592
2593 let mut gpadl = Gpadl::new(input.count, input.len as usize / 8);
2595 let done = gpadl.append(range)?;
2596
2597 let gpadl = match self.inner.gpadls.entry((input.gpadl_id, offer_id)) {
2599 Entry::Vacant(entry) => entry.insert(gpadl),
2600 Entry::Occupied(_) => return Err(ChannelError::DuplicateGpadlId),
2601 };
2602
2603 if !done
2605 && self
2606 .inner
2607 .incomplete_gpadls
2608 .insert(input.gpadl_id, offer_id)
2609 .is_some()
2610 {
2611 unreachable!("gpadl ID validated above");
2612 }
2613
2614 if done
2615 && !Self::gpadl_updated(
2616 self.inner
2617 .pending_messages
2618 .sender(self.notifier, self.inner.state.is_paused()),
2619 offer_id,
2620 channel,
2621 input.gpadl_id,
2622 gpadl,
2623 )
2624 {
2625 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2626 }
2627 Ok(())
2628 }
2629
2630 fn handle_gpadl_body(
2633 &mut self,
2634 input: &protocol::GpadlBody,
2635 range: &[u8],
2636 ) -> Result<(), ChannelError> {
2637 let &offer_id = self
2639 .inner
2640 .incomplete_gpadls
2641 .get(&input.gpadl_id)
2642 .ok_or(ChannelError::UnknownGpadlId)?;
2643 let gpadl = self
2644 .inner
2645 .gpadls
2646 .get_mut(&(input.gpadl_id, offer_id))
2647 .ok_or(ChannelError::UnknownGpadlId)?;
2648 let channel = &mut self.inner.channels[offer_id];
2649
2650 if gpadl.append(range)? {
2651 self.inner.incomplete_gpadls.remove(&input.gpadl_id);
2652 if !Self::gpadl_updated(
2653 self.inner
2654 .pending_messages
2655 .sender(self.notifier, self.inner.state.is_paused()),
2656 offer_id,
2657 channel,
2658 input.gpadl_id,
2659 gpadl,
2660 ) {
2661 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2662 }
2663 }
2664
2665 Ok(())
2666 }
2667
    /// Handles a GpadlTeardown request from the guest.
    ///
    /// Only a GPADL in the `Accepted` state can be torn down:
    /// - On a revoked channel, the GPADL is removed and GpadlTorndown is sent
    ///   immediately.
    /// - Otherwise, the GPADL moves to `TearingDown` and the notifier is asked to
    ///   tear it down. Completion is presumably reported later through
    ///   `gpadl_teardown_complete`.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the channel or GPADL is unknown;
    /// - the GPADL is not `Accepted`;
    /// - the channel's assigned ID does not match the request;
    /// - the channel is reserved.
    fn handle_gpadl_teardown(
        &mut self,
        input: &protocol::GpadlTeardown,
    ) -> Result<(), ChannelError> {
        tracing::debug!(
            channel_id = input.channel_id.0,
            gpadl_id = input.gpadl_id.0,
            "Received GPADL teardown request"
        );

        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        let gpadl = self
            .inner
            .gpadls
            .get_mut(&(input.gpadl_id, offer_id))
            .ok_or(ChannelError::UnknownGpadlId)?;

        match gpadl.state {
            // Creation has not finished, or a teardown is already under way.
            GpadlState::InProgress
            | GpadlState::Offered
            | GpadlState::OfferedTearingDown
            | GpadlState::TearingDown => {
                return Err(ChannelError::InvalidGpadlState);
            }
            GpadlState::Accepted => {
                if channel.info.as_ref().map(|info| info.channel_id) != Some(input.channel_id) {
                    return Err(ChannelError::WrongGpadlChannelId);
                }

                if channel.state.is_reserved() {
                    return Err(ChannelError::ChannelReserved);
                }

                if channel.state.is_revoked() {
                    tracing::trace!(
                        channel_id = input.channel_id.0,
                        gpadl_id = input.gpadl_id.0,
                        "Gpadl teardown for revoked channel"
                    );

                    // The notifier is not involved; drop the GPADL and ack now.
                    self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    self.sender().send_gpadl_torndown(input.gpadl_id);
                } else {
                    gpadl.state = GpadlState::TearingDown;
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id: input.gpadl_id,
                            post_restore: false,
                        },
                    );
                }
            }
        }
        Ok(())
    }
2732
    /// Moves a `Closed` channel to `Opening` and asks the notifier to open it.
    ///
    /// `reserved_state` is `Some` when called for OpenReservedChannel; its
    /// target is passed along in the open parameters.
    ///
    /// # Panics
    ///
    /// Panics if the channel is not `Closed`, has no assigned offer info, or
    /// there is no negotiated version.
    fn open_channel(
        &mut self,
        offer_id: OfferId,
        input: &OpenRequest,
        reserved_state: Option<ReservedState>,
    ) {
        let channel = &mut self.inner.channels[offer_id];
        assert!(matches!(channel.state, ChannelState::Closed));

        // The state is updated before the open parameters are built.
        channel.state = ChannelState::Opening {
            request: *input,
            reserved_state,
        };

        let info = channel.info.as_ref().expect("assigned");
        self.notifier.notify(
            offer_id,
            Action::Open(
                OpenParams::from_request(
                    info,
                    input,
                    channel.handled_monitor_info(),
                    reserved_state.map(|state| state.target),
                ),
                self.inner.state.get_version().expect("must be connected"),
            ),
        );
    }
2765
2766 fn handle_open_channel(&mut self, input: &protocol::OpenChannel2) -> Result<(), ChannelError> {
2768 let (offer_id, channel) = self
2769 .inner
2770 .channels
2771 .get_by_channel_id_mut(&self.inner.assigned_channels, input.open_channel.channel_id)?;
2772
2773 let guest_specified_interrupt_info = self
2774 .inner
2775 .state
2776 .check_feature_flags(|ff| ff.guest_specified_signal_parameters())
2777 .then_some(SignalInfo {
2778 event_flag: input.event_flag,
2779 connection_id: input.connection_id,
2780 });
2781
2782 let flags = if self
2783 .inner
2784 .state
2785 .check_feature_flags(|ff| ff.channel_interrupt_redirection())
2786 {
2787 input.flags
2788 } else {
2789 Default::default()
2790 };
2791
2792 let request = OpenRequest {
2793 open_id: input.open_channel.open_id,
2794 ring_buffer_gpadl_id: input.open_channel.ring_buffer_gpadl_id,
2795 target_vp: input.open_channel.target_vp,
2796 downstream_ring_buffer_page_offset: input
2797 .open_channel
2798 .downstream_ring_buffer_page_offset,
2799 user_data: input.open_channel.user_data,
2800 guest_specified_interrupt_info,
2801 flags,
2802 };
2803
2804 match channel.state {
2805 ChannelState::Closed => self.open_channel(offer_id, &request, None),
2806 ChannelState::Closing { params, .. } => {
2807 channel.state = ChannelState::ClosingReopen { params, request }
2811 }
2812 ChannelState::Revoked | ChannelState::Reoffered => {}
2813
2814 ChannelState::Open { .. }
2815 | ChannelState::Opening { .. }
2816 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelAlreadyOpen),
2817
2818 ChannelState::ClientReleased
2819 | ChannelState::ClosingClientRelease
2820 | ChannelState::OpeningClientRelease => unreachable!(),
2821 }
2822 Ok(())
2823 }
2824
    /// Handles a CloseChannel request from the guest.
    ///
    /// An open, non-reserved channel moves to `Closing` and the notifier is asked
    /// to close it. Requests for revoked or reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the channel ID is unknown;
    /// - `ChannelReserved` if the channel was opened as reserved (those are
    ///   handled by `handle_close_reserved_channel`);
    /// - `ChannelNotOpen` if the channel is not open.
    fn handle_close_channel(&mut self, input: &protocol::CloseChannel) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        match channel.state {
            ChannelState::Open {
                params,
                modify_state,
                reserved_state: None,
            } => {
                // An in-flight modify is not waited for; it is only logged.
                if modify_state.is_modifying() {
                    tracelimit::warn_ratelimited!(
                        ?modify_state,
                        "Client is closing the channel with a modify in progress"
                    )
                }

                channel.state = ChannelState::Closing {
                    params,
                    reserved_state: None,
                };
                self.notifier.notify(offer_id, Action::Close);
            }

            ChannelState::Open {
                reserved_state: Some(_),
                ..
            } => return Err(ChannelError::ChannelReserved),

            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }

        Ok(())
    }
2871
    /// Handles an OpenReservedChannel request from the guest.
    ///
    /// The request is opened like a normal open, with these differences:
    /// - `target_vp` is set to `VP_INDEX_DISABLE_INTERRUPT`;
    /// - `open_id`, user data, and flags are zero/default;
    /// - a `ReservedState` records the negotiated `version` and the
    ///   (VP, SINT) target from the message.
    ///
    /// Requests for revoked or reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the channel ID is unknown;
    /// - `ChannelAlreadyOpen` if the channel is open or opening;
    /// - `InvalidChannelState` while it is closing.
    fn handle_open_reserved_channel(
        &mut self,
        input: &protocol::OpenReservedChannel,
        version: VersionInfo,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        let target = ConnectionTarget {
            vp: input.target_vp,
            sint: input.target_sint as u8,
        };

        let reserved_state = Some(ReservedState { version, target });

        let request = OpenRequest {
            ring_buffer_gpadl_id: input.ring_buffer_gpadl,
            target_vp: protocol::VP_INDEX_DISABLE_INTERRUPT,
            downstream_ring_buffer_page_offset: input.downstream_page_offset,
            open_id: 0,
            user_data: UserDefinedData::new_zeroed(),
            guest_specified_interrupt_info: None,
            flags: Default::default(),
        };

        match channel.state {
            ChannelState::Closed => self.open_channel(offer_id, &request, reserved_state),
            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Open { .. } | ChannelState::Opening { .. } => {
                return Err(ChannelError::ChannelAlreadyOpen);
            }

            ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                return Err(ChannelError::InvalidChannelState);
            }

            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }
        Ok(())
    }
2920
    /// Handles a CloseReservedChannel request from the guest.
    ///
    /// For an open reserved channel, the reserved target (VP and SINT) is
    /// replaced with the one in the request. The channel then moves to `Closing`
    /// and the notifier is asked to close it. Requests for revoked or reoffered
    /// channels are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the channel ID is unknown;
    /// - `ChannelNotReserved` if the channel was opened normally;
    /// - `ChannelNotOpen` if it is not open.
    fn handle_close_reserved_channel(
        &mut self,
        input: &protocol::CloseReservedChannel,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        match channel.state {
            ChannelState::Open {
                params,
                reserved_state: Some(mut resvd),
                ..
            } => {
                // Presumably where the close completion should be targeted.
                resvd.target.vp = input.target_vp;
                resvd.target.sint = input.target_sint as u8;
                channel.state = ChannelState::Closing {
                    params,
                    reserved_state: Some(resvd),
                };
                self.notifier.notify(offer_id, Action::Close);
            }

            ChannelState::Open {
                reserved_state: None,
                ..
            } => return Err(ChannelError::ChannelNotReserved),

            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }

        Ok(())
    }
2966
    /// Handles the client releasing a channel.
    ///
    /// The channel's GPADLs are dropped or put into teardown. The channel then
    /// moves to the matching released state, and its assigned IDs are handed
    /// back through `release_channel`.
    ///
    /// Exception: a `Reoffered` channel is offered again when `version` is
    /// `Some`. In that case the function returns early and keeps the channel's
    /// assignment.
    ///
    /// Returns `true` only when the channel was `Revoked`, telling the caller
    /// to remove it.
    #[must_use]
    fn client_release_channel(
        mut sender: MessageSender<'_, N>,
        offer_id: OfferId,
        channel: &mut Channel,
        gpadls: &mut GpadlMap,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
        version: Option<VersionInfo>,
    ) -> bool {
        // Only this channel's GPADLs are affected.
        gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
            if gpadl_offer_id != offer_id {
                return true;
            }
            match gpadl.state {
                // Not yet fully received; simply drop it.
                GpadlState::InProgress => false,
                // Creation is pending with the notifier; `gpadl_create_complete`
                // will start the teardown once creation finishes.
                GpadlState::Offered => {
                    gpadl.state = GpadlState::OfferedTearingDown;
                    true
                }
                GpadlState::Accepted => {
                    if channel.state.is_revoked() {
                        false
                    } else {
                        gpadl.state = GpadlState::TearingDown;
                        sender.notifier.notify(
                            offer_id,
                            Action::TeardownGpadl {
                                gpadl_id,
                                post_restore: false,
                            },
                        );
                        true
                    }
                }
                // Teardown already under way.
                GpadlState::OfferedTearingDown | GpadlState::TearingDown => true,
            }
        });

        let remove = match &mut channel.state {
            ChannelState::Closed => {
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Reoffered => {
                // With a negotiated version, the channel is offered to the guest
                // again and keeps its assigned IDs (release is skipped).
                if let Some(version) = version {
                    channel.state = ChannelState::Closed;
                    channel.restore_state = RestoreState::New;
                    sender.send_offer(channel, version);
                    return false;
                }
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Revoked => {
                channel.state = ChannelState::ClientReleased;
                true
            }
            ChannelState::Opening { .. } => {
                channel.state = ChannelState::OpeningClientRelease;
                false
            }
            ChannelState::Open { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                sender.notifier.notify(offer_id, Action::Close);
                false
            }
            ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                false
            }

            ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease
            | ChannelState::ClientReleased => false,
        };

        assert!(channel.state.is_released());

        channel.release_channel(offer_id, assigned_channels, assigned_monitors);
        remove
    }
3055
    /// Handles a RelIdReleased message: the guest gives up a channel ID.
    ///
    /// Allowed only for channels that are closed, closing, revoked, or
    /// reoffered. The channel is released through `client_release_channel` and
    /// removed from the channel list when that returns `true`. Afterwards
    /// `check_disconnected` is called.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the channel ID is unknown;
    /// - `InvalidChannelState` if the channel is open, opening, or waiting to
    ///   reopen.
    fn handle_rel_id_released(
        &mut self,
        input: &protocol::RelIdReleased,
    ) -> Result<(), ChannelError> {
        let channel_id = input.channel_id;
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, channel_id)?;

        match channel.state {
            ChannelState::Closed
            | ChannelState::Revoked
            | ChannelState::Closing { .. }
            | ChannelState::Reoffered => {
                if Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    &mut self.inner.gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    self.inner.state.get_version(),
                ) {
                    self.inner.channels.remove(offer_id);
                }

                self.check_disconnected();
            }

            ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::InvalidChannelState),

            ChannelState::ClientReleased
            | ChannelState::OpeningClientRelease
            | ChannelState::ClosingClientRelease => unreachable!(),
        }
        Ok(())
    }
3099
3100 fn handle_tl_connect_request(&mut self, request: protocol::TlConnectRequest2) {
3103 let version = self
3104 .inner
3105 .state
3106 .get_version()
3107 .expect("must be connected")
3108 .version;
3109
3110 let hosted_silo_unaware = version < Version::Win10Rs5;
3111 self.notifier
3112 .notify_hvsock(&HvsockConnectRequest::from_message(
3113 request,
3114 hosted_silo_unaware,
3115 ));
3116 }
3117
3118 pub fn send_tl_connect_result(&mut self, result: HvsockConnectResult) {
3120 if !result.success && self.inner.state.check_version(Version::Win10Rs3_0) {
3124 self.sender().send_message(&protocol::TlConnectResult {
3128 service_id: result.service_id,
3129 endpoint_id: result.endpoint_id,
3130 status: protocol::STATUS_CONNECTION_REFUSED,
3131 })
3132 }
3133 }
3134
3135 fn handle_modify_channel(
3138 &mut self,
3139 request: &protocol::ModifyChannel,
3140 ) -> Result<(), ChannelError> {
3141 let result = self.modify_channel(request);
3142 if result.is_err() {
3143 self.send_modify_channel_response(request.channel_id, protocol::STATUS_UNSUCCESSFUL);
3144 }
3145
3146 result
3147 }
3148
    /// Applies a ModifyChannel request to an open, non-reserved channel.
    ///
    /// Behavior depends on whether a modify is already in progress:
    /// - None in progress: the notifier receives `Action::Modify` with
    ///   `request.target_vp`, the channel's open parameters take the new target
    ///   VP immediately, and the channel enters `Modifying`.
    /// - In progress, guest at `Version::Iron` or later: only a warning is
    ///   logged.
    /// - In progress, older guest: the request's target VP is stored as pending.
    ///   `modify_channel_complete` submits it once the current modify finishes.
    ///
    /// # Errors
    ///
    /// Returns an error for an unknown channel ID, or `InvalidChannelState` if
    /// the channel is not open or is reserved.
    fn modify_channel(&mut self, request: &protocol::ModifyChannel) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, request.channel_id)?;

        let (open_request, modify_state) = match &mut channel.state {
            ChannelState::Open {
                params,
                modify_state,
                reserved_state: None,
            } => (params, modify_state),
            _ => return Err(ChannelError::InvalidChannelState),
        };

        if let ModifyState::Modifying { pending_target_vp } = modify_state {
            if self.inner.state.check_version(Version::Iron) {
                // Iron+ clients receive a response per request and should wait for it.
                tracelimit::warn_ratelimited!(
                    "Client sent new ModifyChannel before receiving ModifyChannelResponse."
                );
            } else {
                // Older clients get no response; keep only the latest target.
                *pending_target_vp = Some(request.target_vp);
            }
        } else {
            self.notifier.notify(
                offer_id,
                Action::Modify {
                    target_vp: request.target_vp,
                },
            );

            open_request.target_vp = request.target_vp;
            *modify_state = ModifyState::Modifying {
                pending_target_vp: None,
            };
        }

        Ok(())
    }
3194
    /// Completes a channel modify previously started by `modify_channel`.
    ///
    /// Does nothing unless the channel is open, non-reserved, and `Modifying`.
    /// Otherwise it performs these steps:
    /// 1. The channel returns to `NotModifying`.
    /// 2. A ModifyChannelResponse carrying `status` is sent (only for
    ///    `Version::Iron` or later; see `send_modify_channel_response`).
    /// 3. Any pending target VP saved for an older guest is submitted as a new
    ///    ModifyChannel request. Failures of that request are only logged.
    pub fn modify_channel_complete(&mut self, offer_id: OfferId, status: i32) {
        let channel = &mut self.inner.channels[offer_id];

        if let ChannelState::Open {
            params,
            modify_state: ModifyState::Modifying { pending_target_vp },
            reserved_state: None,
        } = channel.state
        {
            channel.state = ChannelState::Open {
                params,
                modify_state: ModifyState::NotModifying,
                reserved_state: None,
            };

            let channel_id = channel.info.as_ref().expect("assigned").channel_id;
            self.send_modify_channel_response(channel_id, status);

            // Replay the most recent request received while this modify was running.
            if let Some(target_vp) = pending_target_vp {
                let request = protocol::ModifyChannel {
                    channel_id,
                    target_vp,
                };

                if let Err(error) = self.handle_modify_channel(&request) {
                    tracelimit::warn_ratelimited!(?error, "Pending ModifyChannel request failed.")
                }
            }
        }
    }
3233
3234 fn send_modify_channel_response(&mut self, channel_id: ChannelId, status: i32) {
3235 if self.inner.state.check_version(Version::Iron) {
3236 self.sender()
3237 .send_message(&protocol::ModifyChannelResponse { channel_id, status });
3238 }
3239 }
3240
3241 fn handle_modify_connection(&mut self, request: protocol::ModifyConnection) {
3242 if let Err(err) = self.modify_connection(request) {
3243 tracelimit::error_ratelimited!(?err, "modifying connection failed");
3244 self.complete_modify_connection(ModifyConnectionResponse::Supported(
3245 protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
3246 FeatureFlags::new(),
3247 ));
3248 }
3249 }
3250
3251 fn modify_connection(&mut self, request: protocol::ModifyConnection) -> anyhow::Result<()> {
3252 let ConnectionState::Connected(info) = &mut self.inner.state else {
3253 anyhow::bail!(
3254 "Invalid state for ModifyConnection request: {:?}",
3255 self.inner.state
3256 );
3257 };
3258
3259 if info.modifying {
3260 anyhow::bail!(
3261 "Duplicate ModifyConnection request, state: {:?}",
3262 self.inner.state
3263 );
3264 }
3265
3266 if (request.child_to_parent_monitor_page_gpa == 0)
3267 != (request.parent_to_child_monitor_page_gpa == 0)
3268 {
3269 anyhow::bail!("Guest must specify either both or no monitor pages, {request:?}");
3270 }
3271
3272 let monitor_page =
3273 (request.child_to_parent_monitor_page_gpa != 0).then_some(MonitorPageGpas {
3274 child_to_parent: request.child_to_parent_monitor_page_gpa,
3275 parent_to_child: request.parent_to_child_monitor_page_gpa,
3276 });
3277
3278 info.modifying = true;
3279 info.monitor_page = monitor_page;
3280 tracing::debug!("modifying connection parameters.");
3281 self.notifier.modify_connection(request.into())?;
3282
3283 Ok(())
3284 }
3285
    /// Handles the relay's response to a ModifyConnection request.
    ///
    /// The current connection state decides what the response completes:
    /// - `Connecting`: an initiate-contact request.
    /// - `Disconnecting`: a disconnect.
    /// - `Connected`: a guest ModifyConnection. The `modifying` flag is cleared
    ///   and the guest is sent a ModifyConnectionResponse.
    ///
    /// # Panics
    ///
    /// Panics in these cases:
    /// - `Connected`, but the response is not `Supported`;
    /// - `Connected`, but no modify is in progress;
    /// - the server is in any other state.
    pub fn complete_modify_connection(&mut self, response: ModifyConnectionResponse) {
        tracing::debug!(?response, "modifying connection parameters complete");

        match &mut self.inner.state {
            ConnectionState::Connecting { .. } => self.complete_initiate_contact(response),
            ConnectionState::Disconnecting { .. } => self.complete_disconnect(),
            ConnectionState::Connected(info) => {
                let ModifyConnectionResponse::Supported(connection_state, ..) = response else {
                    panic!(
                        "Relay should not return {:?} for a modify request with no version.",
                        response
                    );
                };

                if !info.modifying {
                    panic!(
                        "ModifyConnection response while not modifying, state: {:?}",
                        self.inner.state
                    );
                }

                info.modifying = false;
                self.sender()
                    .send_message(&protocol::ModifyConnectionResponse { connection_state });
            }
            _ => panic!(
                "Invalid state for ModifyConnection response: {:?}",
                self.inner.state
            ),
        }
    }
3320
3321 fn handle_pause(&mut self) {
3322 tracelimit::info_ratelimited!("pausing sending messages");
3323 self.sender().send_message(&protocol::PauseResponse {});
3324 let ConnectionState::Connected(info) = &mut self.inner.state else {
3325 unreachable!(
3326 "in unexpected state {:?}, should be prevented by Message::parse()",
3327 self.inner.state
3328 );
3329 };
3330 info.paused = true;
3331 }
3332
    /// Parses one synic message from the guest and dispatches it to its handler.
    ///
    /// The message is parsed against the currently negotiated version, if any.
    /// Untrusted messages are rejected when the connection is trusted. While
    /// the connection is paused, only these messages are accepted:
    /// - Resume
    /// - Unload
    /// - InitiateContact / InitiateContact2
    ///
    /// Accepting one of these clears the paused state.
    ///
    /// # Errors
    ///
    /// Returns one of the following:
    /// - a parse error;
    /// - `UntrustedMessage`;
    /// - `Paused`;
    /// - the error returned by the per-message handler.
    ///
    /// # Panics
    ///
    /// Panics in these cases:
    /// - called while the server is resetting;
    /// - the parsed message is one only the server sends. Such messages are
    ///   presumably rejected by `Message::parse`.
    pub fn handle_synic_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
        assert!(!self.is_resetting());

        let version = self.inner.state.get_version();
        let msg = Message::parse(&message.data, version)?;
        tracing::trace!(?msg, message.trusted, "received vmbus message");
        if self.inner.state.is_trusted() && !message.trusted {
            tracelimit::warn_ratelimited!(?msg, "Received untrusted message");
            return Err(ChannelError::UntrustedMessage);
        }

        match &mut self.inner.state {
            ConnectionState::Connected(info) if info.paused => {
                if !matches!(
                    msg,
                    Message::Resume(..)
                        | Message::Unload(..)
                        | Message::InitiateContact { .. }
                        | Message::InitiateContact2 { .. }
                ) {
                    tracelimit::warn_ratelimited!(?msg, "Received message while paused");
                    return Err(ChannelError::Paused);
                }
                tracelimit::info_ratelimited!("resuming sending messages");
                info.paused = false;
            }
            _ => {}
        }

        match msg {
            Message::InitiateContact2(input, ..) => {
                self.handle_initiate_contact(&input, &message, true)?
            }
            Message::InitiateContact(input, ..) => {
                self.handle_initiate_contact(&input.into(), &message, false)?
            }
            Message::Unload(..) => self.handle_unload(),
            Message::RequestOffers(..) => self.handle_request_offers()?,
            Message::GpadlHeader(input, range) => self.handle_gpadl_header(&input, range)?,
            Message::GpadlBody(input, range) => self.handle_gpadl_body(&input, range)?,
            Message::GpadlTeardown(input, ..) => self.handle_gpadl_teardown(&input)?,
            Message::OpenChannel(input, ..) => self.handle_open_channel(&input.into())?,
            Message::OpenChannel2(input, ..) => self.handle_open_channel(&input)?,
            Message::CloseChannel(input, ..) => self.handle_close_channel(&input)?,
            Message::RelIdReleased(input, ..) => self.handle_rel_id_released(&input)?,
            Message::TlConnectRequest(input, ..) => self.handle_tl_connect_request(input.into()),
            Message::TlConnectRequest2(input, ..) => self.handle_tl_connect_request(input),
            Message::ModifyChannel(input, ..) => self.handle_modify_channel(&input)?,
            Message::ModifyConnection(input, ..) => self.handle_modify_connection(input),
            Message::OpenReservedChannel(input, ..) => self.handle_open_reserved_channel(
                &input,
                version.expect("version validated by Message::parse"),
            )?,
            Message::CloseReservedChannel(input, ..) => {
                self.handle_close_reserved_channel(&input)?
            }
            Message::Pause(protocol::Pause, ..) => self.handle_pause(),
            // The paused flag was already cleared above; nothing else to do.
            Message::Resume(protocol::Resume, ..) => {}
            // Messages sent by the server, never expected from the guest.
            Message::OfferChannel(..)
            | Message::RescindChannelOffer(..)
            | Message::AllOffersDelivered(..)
            | Message::OpenResult(..)
            | Message::GpadlCreated(..)
            | Message::GpadlTorndown(..)
            | Message::VersionResponse(..)
            | Message::VersionResponse2(..)
            | Message::UnloadComplete(..)
            | Message::CloseReservedChannelResponse(..)
            | Message::TlConnectResult(..)
            | Message::ModifyChannelResponse(..)
            | Message::ModifyConnectionResponse(..)
            | Message::PauseResponse(..) => {
                unreachable!("Server received client message {:?}", msg);
            }
        }
        Ok(())
    }
3417
3418 fn get_gpadl(
3419 gpadls: &mut GpadlMap,
3420 offer_id: OfferId,
3421 gpadl_id: GpadlId,
3422 ) -> Option<&mut Gpadl> {
3423 let gpadl = gpadls.get_mut(&(gpadl_id, offer_id));
3424 if gpadl.is_none() {
3425 tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, "invalid gpadl ID for channel");
3426 }
3427 gpadl
3428 }
3429
3430 pub fn gpadl_create_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId, status: i32) {
3432 let gpadl = if let Some(gpadl) = Self::get_gpadl(&mut self.inner.gpadls, offer_id, gpadl_id)
3433 {
3434 gpadl
3435 } else {
3436 return;
3437 };
3438 let retain = match gpadl.state {
3439 GpadlState::InProgress | GpadlState::TearingDown | GpadlState::Accepted => {
3440 tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
3441 return;
3442 }
3443 GpadlState::Offered => {
3444 let channel_id = self.inner.channels[offer_id]
3445 .info
3446 .as_ref()
3447 .expect("assigned")
3448 .channel_id;
3449 self.inner
3450 .pending_messages
3451 .sender(self.notifier, self.inner.state.is_paused())
3452 .send_gpadl_created(channel_id, gpadl_id, status);
3453 if status >= 0 {
3454 gpadl.state = GpadlState::Accepted;
3455 true
3456 } else {
3457 false
3458 }
3459 }
3460 GpadlState::OfferedTearingDown => {
3461 if status >= 0 {
3462 self.notifier.notify(
3464 offer_id,
3465 Action::TeardownGpadl {
3466 gpadl_id,
3467 post_restore: false,
3468 },
3469 );
3470 gpadl.state = GpadlState::TearingDown;
3471 true
3472 } else {
3473 false
3474 }
3475 }
3476 };
3477 if !retain {
3478 self.inner
3479 .gpadls
3480 .remove(&(gpadl_id, offer_id))
3481 .expect("gpadl validated above");
3482
3483 self.check_disconnected();
3484 }
3485 }
3486
3487 pub fn gpadl_teardown_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
3489 tracing::debug!(
3490 offer_id = offer_id.0,
3491 gpadl_id = gpadl_id.0,
3492 "Gpadl teardown complete"
3493 );
3494
3495 let gpadl = if let Some(gpadl) = Self::get_gpadl(&mut self.inner.gpadls, offer_id, gpadl_id)
3496 {
3497 gpadl
3498 } else {
3499 return;
3500 };
3501 let channel = &mut self.inner.channels[offer_id];
3502 match gpadl.state {
3503 GpadlState::InProgress
3504 | GpadlState::Offered
3505 | GpadlState::OfferedTearingDown
3506 | GpadlState::Accepted => {
3507 tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
3508 }
3509 GpadlState::TearingDown => {
3510 if !channel.state.is_released() {
3511 self.sender().send_gpadl_torndown(gpadl_id);
3512 }
3513 self.inner
3514 .gpadls
3515 .remove(&(gpadl_id, offer_id))
3516 .expect("gpadl validated above");
3517
3518 self.check_disconnected();
3519 }
3520 }
3521 }
3522
3523 fn sender(&mut self) -> MessageSender<'_, N> {
3528 self.inner
3529 .pending_messages
3530 .sender(self.notifier, self.inner.state.is_paused())
3531 }
3532}
3533
/// Revokes a channel's offer.
///
/// Channels in an offered state (`Closed`, `Opening`, `Open`, `Closing`,
/// `ClosingReopen`) move to `Revoked` and have a RescindChannelOffer sent for
/// them. A `Reoffered` channel also moves to `Revoked`, but nothing is sent for
/// it. Client-released channels keep their state. GPADLs of the channel that
/// are mid-creation or mid-teardown are completed or dropped as commented below.
///
/// Returns `true` immediately if the channel was already `Revoked`. Otherwise
/// returns `!channel.state.is_released()`, evaluated after the state change.
fn revoke<N: Notifier>(
    mut sender: MessageSender<'_, N>,
    offer_id: OfferId,
    channel: &mut Channel,
    gpadls: &mut GpadlMap,
) -> bool {
    // `info` is `Some` only when the guest should be told about the revoke.
    let info = match channel.state {
        ChannelState::Closed
        | ChannelState::Open { .. }
        | ChannelState::Opening { .. }
        | ChannelState::Closing { .. }
        | ChannelState::ClosingReopen { .. } => {
            channel.state = ChannelState::Revoked;
            Some(channel.info.as_ref().expect("assigned"))
        }
        ChannelState::Reoffered => {
            channel.state = ChannelState::Revoked;
            None
        }
        ChannelState::ClientReleased
        | ChannelState::OpeningClientRelease
        | ChannelState::ClosingClientRelease => None,
        ChannelState::Revoked => return true,
    };
    let retain = !channel.state.is_released();

    // Only this channel's GPADLs are affected.
    gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
        if gpadl_offer_id != offer_id {
            return true;
        }

        match gpadl.state {
            GpadlState::InProgress => true,
            // Creation pending: fail it towards the guest and drop it.
            GpadlState::Offered => {
                if let Some(info) = info {
                    sender.send_gpadl_created(
                        info.channel_id,
                        gpadl_id,
                        protocol::STATUS_UNSUCCESSFUL,
                    );
                }
                false
            }
            GpadlState::OfferedTearingDown => false,
            GpadlState::Accepted => true,
            // Teardown pending: report it as done to the guest and drop it.
            GpadlState::TearingDown => {
                if info.is_some() {
                    sender.send_gpadl_torndown(gpadl_id);
                }
                false
            }
        }
    });
    if let Some(info) = info {
        sender.send_rescind(info);
    }
    // Any restore state other than `New` is collapsed to `Restored`.
    if channel.restore_state != RestoreState::New {
        channel.restore_state = RestoreState::Restored;
    }
    retain
}
3598
3599struct PendingMessages(VecDeque<OutgoingMessage>);
3600
3601impl PendingMessages {
3602 fn sender<'a, N: Notifier>(
3604 &'a mut self,
3605 notifier: &'a mut N,
3606 is_paused: bool,
3607 ) -> MessageSender<'a, N> {
3608 MessageSender {
3609 notifier,
3610 pending_messages: self,
3611 is_paused,
3612 }
3613 }
3614}
3615
/// Sends messages to the guest through a notifier. Messages to the default
/// target fall back to the pending message queue when they cannot be sent
/// right away.
struct MessageSender<'a, N> {
    /// Used to actually send messages.
    notifier: &'a mut N,
    /// Queue for default-target messages that could not be sent immediately.
    pending_messages: &'a mut PendingMessages,
    /// When set, default-target messages are always queued.
    is_paused: bool,
}
3623
3624impl<N: Notifier> MessageSender<'_, N> {
3625 fn send_message<
3627 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3628 >(
3629 &mut self,
3630 msg: &T,
3631 ) {
3632 let message = OutgoingMessage::new(msg);
3633
3634 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3635 if !self.pending_messages.0.is_empty()
3637 || self.is_paused
3638 || !self.notifier.send_message(&message, MessageTarget::Default)
3639 {
3640 tracing::trace!("message queued");
3641 self.pending_messages.0.push_back(message);
3643 }
3644 }
3645
3646 fn send_message_with_target<
3648 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3649 >(
3650 &mut self,
3651 msg: &T,
3652 target: MessageTarget,
3653 ) {
3654 if target == MessageTarget::Default {
3655 self.send_message(msg);
3656 } else {
3657 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3658 let message = OutgoingMessage::new(msg);
3661 if !self.notifier.send_message(&message, target) {
3662 tracelimit::warn_ratelimited!(?target, "failed to send message");
3663 }
3664 }
3665 }
3666
3667 fn send_offer(&mut self, channel: &mut Channel, version: VersionInfo) {
3669 let info = channel.info.as_ref().expect("assigned");
3670 let mut flags = channel.offer.flags;
3671 if !version.feature_flags.confidential_channels() {
3672 flags.set_confidential_ring_buffer(false);
3673 flags.set_confidential_external_memory(false);
3674 }
3675
3676 let msg = protocol::OfferChannel {
3677 interface_id: channel.offer.interface_id,
3678 instance_id: channel.offer.instance_id,
3679 rsvd: [0; 4],
3680 flags,
3681 mmio_megabytes: channel.offer.mmio_megabytes,
3682 user_defined: channel.offer.user_defined,
3683 subchannel_index: channel.offer.subchannel_index,
3684 mmio_megabytes_optional: channel.offer.mmio_megabytes_optional,
3685 channel_id: info.channel_id,
3686 monitor_id: info.monitor_id.unwrap_or(MonitorId::INVALID).0,
3687 monitor_allocated: info.monitor_id.is_some() as u8,
3688 is_dedicated: 1,
3691 connection_id: info.connection_id,
3692 };
3693 tracing::info!(
3694 channel_id = msg.channel_id.0,
3695 connection_id = msg.connection_id,
3696 key = %channel.offer.key(),
3697 "sending offer to guest"
3698 );
3699
3700 self.send_message(&msg);
3701 }
3702
3703 fn send_open_result(
3704 &mut self,
3705 channel_id: ChannelId,
3706 open_request: &OpenRequest,
3707 result: i32,
3708 target: MessageTarget,
3709 ) {
3710 self.send_message_with_target(
3711 &protocol::OpenResult {
3712 channel_id,
3713 open_id: open_request.open_id,
3714 status: result as u32,
3715 },
3716 target,
3717 );
3718 }
3719
3720 fn send_gpadl_created(&mut self, channel_id: ChannelId, gpadl_id: GpadlId, status: i32) {
3721 self.send_message(&protocol::GpadlCreated {
3722 channel_id,
3723 gpadl_id,
3724 status,
3725 });
3726 }
3727
3728 fn send_gpadl_torndown(&mut self, gpadl_id: GpadlId) {
3729 self.send_message(&protocol::GpadlTorndown { gpadl_id });
3730 }
3731
3732 fn send_rescind(&mut self, info: &OfferedInfo) {
3733 tracing::info!(
3734 channel_id = info.channel_id.0,
3735 "rescinding channel from guest"
3736 );
3737
3738 self.send_message(&protocol::RescindChannelOffer {
3739 channel_id: info.channel_id,
3740 });
3741 }
3742}
3743
3744#[cfg(test)]
3745mod tests {
3746 use crate::MESSAGE_CONNECTION_ID;
3747
3748 use super::*;
3749 use guid::Guid;
3750 use protocol::VmbusMessage;
3751 use std::collections::VecDeque;
3752 use std::sync::mpsc;
3753 use test_with_tracing::test;
3754 use vmbus_core::protocol::TargetInfo;
3755 use zerocopy::FromBytes;
3756
3757 fn in_msg<T: IntoBytes + Immutable + KnownLayout>(
3758 message_type: protocol::MessageType,
3759 t: T,
3760 ) -> SynicMessage {
3761 in_msg_ex(message_type, t, false, false)
3762 }
3763
3764 fn in_msg_ex<T: IntoBytes + Immutable + KnownLayout>(
3765 message_type: protocol::MessageType,
3766 t: T,
3767 multiclient: bool,
3768 trusted: bool,
3769 ) -> SynicMessage {
3770 let mut data = Vec::new();
3771 data.extend_from_slice(&message_type.0.to_ne_bytes());
3772 data.extend_from_slice(&0u32.to_ne_bytes());
3773 data.extend_from_slice(t.as_bytes());
3774 SynicMessage {
3775 data,
3776 multiclient,
3777 trusted,
3778 }
3779 }
3780
3781 #[test]
3782 fn test_version_negotiation_not_supported() {
3783 let (mut notifier, _recv) = TestNotifier::new();
3784 let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3785
3786 test_initiate_contact(&mut server, &mut notifier, 0xffffffff, 0, false, 0);
3787 }
3788
3789 #[test]
3790 fn test_version_negotiation_success() {
3791 let (mut notifier, _recv) = TestNotifier::new();
3792 let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3793
3794 test_initiate_contact(
3795 &mut server,
3796 &mut notifier,
3797 Version::Win10 as u32,
3798 0,
3799 true,
3800 0,
3801 );
3802 }
3803
    /// A multiclient `INITIATE_CONTACT` whose target info names SINT 3 is
    /// answered directly to that target (VP 0, SINT 3) with a `VersionResponse`
    /// whose `version_supported` is 0. No modify-connection request is issued
    /// and the server stays disconnected. A regular negotiation using the same
    /// target info afterwards still succeeds.
    #[test]
    fn test_version_negotiation_multiclient_sint() {
        let (mut notifier, _recv) = TestNotifier::new();
        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);

        // Target SINT 3 at VTL0, with no feature flags.
        let target_info = TargetInfo::new()
            .with_sint(3)
            .with_vtl(0)
            .with_feature_flags(FeatureFlags::new().into());

        // Send the request with multiclient = true, trusted = false.
        server
            .with_notifier(&mut notifier)
            .handle_synic_message(in_msg_ex(
                protocol::MessageType::INITIATE_CONTACT,
                protocol::InitiateContact {
                    version_requested: Version::Win10Rs3_1 as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: target_info.into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
                true,
                false,
            ))
            .unwrap();

        // The server answered without touching the connection state.
        assert!(notifier.modify_requests.is_empty());
        assert!(matches!(server.state, ConnectionState::Disconnected));
        notifier.check_message_with_target(
            OutgoingMessage::new(&protocol::VersionResponse {
                version_supported: 0,
                connection_state: protocol::ConnectionState::SUCCESSFUL,
                padding: 0,
                selected_version_or_connection_id: 0,
            }),
            MessageTarget::Custom(ConnectionTarget { vp: 0, sint: 3 }),
        );

        // A normal (non-multiclient) negotiation with the same target info works.
        test_initiate_contact(
            &mut server,
            &mut notifier,
            Version::Win10Rs3_1 as u32,
            target_info.into(),
            true,
            0,
        );
    }
3854
    /// A multiclient `INITIATE_CONTACT` whose target info names VTL 2 is not
    /// handled locally by this VTL0 server. It is handed to
    /// `Notifier::forward_unhandled`, no message is sent, and the server stays
    /// disconnected. A later regular negotiation succeeds and forwards nothing.
    #[test]
    fn test_version_negotiation_multiclient_vtl() {
        let (mut notifier, _recv) = TestNotifier::new();
        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);

        // Target VTL 2 (the server itself was created for VTL0).
        let target_info = TargetInfo::new()
            .with_sint(SINT)
            .with_vtl(2)
            .with_feature_flags(FeatureFlags::new().into());

        // Send the request with multiclient = true, trusted = false.
        server
            .with_notifier(&mut notifier)
            .handle_synic_message(in_msg_ex(
                protocol::MessageType::INITIATE_CONTACT,
                protocol::InitiateContact {
                    version_requested: Version::Win10Rs4 as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: target_info.into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
                true,
                false,
            ))
            .unwrap();

        // The request was forwarded rather than answered.
        let action = notifier.forward_request.take().unwrap();
        assert!(matches!(action, InitiateContactRequest { .. }));

        assert!(notifier.messages.is_empty());
        assert!(matches!(server.state, ConnectionState::Disconnected));

        test_initiate_contact(
            &mut server,
            &mut notifier,
            Version::Win10Rs4 as u32,
            target_info.into(),
            true,
            0,
        );

        // The regular negotiation did not forward anything.
        assert!(notifier.forward_request.is_none());
    }
3900
3901 #[test]
3902 fn test_version_negotiation_feature_flags() {
3903 let (mut notifier, _recv) = TestNotifier::new();
3904 let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3905
3906 let mut target_info = TargetInfo::new()
3908 .with_sint(SINT)
3909 .with_vtl(0)
3910 .with_feature_flags(FeatureFlags::new().into());
3911 test_initiate_contact(
3912 &mut server,
3913 &mut notifier,
3914 Version::Copper as u32,
3915 target_info.into(),
3916 true,
3917 0,
3918 );
3919
3920 target_info.set_feature_flags(
3922 FeatureFlags::new()
3923 .with_guest_specified_signal_parameters(true)
3924 .into(),
3925 );
3926 test_initiate_contact(
3927 &mut server,
3928 &mut notifier,
3929 Version::Copper as u32,
3930 target_info.into(),
3931 true,
3932 FeatureFlags::new()
3933 .with_guest_specified_signal_parameters(true)
3934 .into(),
3935 );
3936
3937 target_info.set_feature_flags(
3939 u32::from(FeatureFlags::new().with_guest_specified_signal_parameters(true))
3940 | 0xf0000000,
3941 );
3942 test_initiate_contact(
3943 &mut server,
3944 &mut notifier,
3945 Version::Copper as u32,
3946 target_info.into(),
3947 true,
3948 FeatureFlags::new()
3949 .with_guest_specified_signal_parameters(true)
3950 .into(),
3951 );
3952
3953 target_info.set_feature_flags(FeatureFlags::new().with_client_id(true).into());
3955 test_initiate_contact(
3956 &mut server,
3957 &mut notifier,
3958 Version::Copper as u32,
3959 target_info.into(),
3960 true,
3961 FeatureFlags::new().with_client_id(true).into(),
3962 );
3963 }
3964
3965 #[test]
3966 fn test_version_negotiation_interrupt_page() {
3967 let (mut notifier, _recv) = TestNotifier::new();
3968 let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3969 test_initiate_contact(
3970 &mut server,
3971 &mut notifier,
3972 Version::V1 as u32,
3973 1234,
3974 true,
3975 0,
3976 );
3977
3978 let (mut notifier, _recv) = TestNotifier::new();
3979 let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3980 test_initiate_contact(
3981 &mut server,
3982 &mut notifier,
3983 Version::Win7 as u32,
3984 1234,
3985 true,
3986 0,
3987 );
3988
3989 let (mut notifier, _recv) = TestNotifier::new();
3990 let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3991 test_initiate_contact(
3992 &mut server,
3993 &mut notifier,
3994 Version::Win8 as u32,
3995 1234,
3996 true,
3997 0,
3998 );
3999 }
4000
    /// Runs one `INITIATE_CONTACT` exchange against `server` and checks the outcome.
    ///
    /// Sends an `InitiateContact2` requesting `version`, with `target_info` in
    /// the interrupt page / target info field and `target_message_vp` set to 1.
    ///
    /// When `expect_supported` is true, the helper:
    /// - expects exactly one `ModifyConnectionRequest`. Its interrupt page is
    ///   set only below Win8, and its target message VP is 0 below Win8.1 and
    ///   1 otherwise.
    /// - completes that request with `SUPPORTED_FEATURE_FLAGS`.
    /// - expects a successful version response.
    ///
    /// For Copper and later (when supported), the response is a
    /// `VersionResponse2` whose `supported_features` must equal
    /// `expected_features`. In every other case `expected_features` must be 0.
    /// Finally, the helper checks the server's connection state and the
    /// notifier's recorded target VP and interrupt page.
    fn test_initiate_contact(
        server: &mut Server,
        notifier: &mut TestNotifier,
        version: u32,
        target_info: u64,
        expect_supported: bool,
        expected_features: u32,
    ) {
        server
            .with_notifier(notifier)
            .handle_synic_message(in_msg(
                protocol::MessageType::INITIATE_CONTACT,
                protocol::InitiateContact2 {
                    initiate_contact: protocol::InitiateContact {
                        version_requested: version,
                        target_message_vp: 1,
                        interrupt_page_or_target_info: target_info,
                        parent_to_child_monitor_page_gpa: 0,
                        child_to_parent_monitor_page_gpa: 0,
                    },
                    client_id: guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6"),
                },
            ))
            .unwrap();

        let selected_version_or_connection_id = if expect_supported {
            let request = notifier.next_action();
            // Below Win8 the field is treated as an interrupt page; from Win8 on it is reset.
            let interrupt_page = if version < Version::Win8 as u32 {
                Update::Set(target_info)
            } else {
                Update::Reset
            };

            // Below Win8.1 the requested target VP (1) is not used; VP 0 is expected.
            let target_message_vp = if version < Version::Win8_1 as u32 {
                Some(0)
            } else {
                Some(1)
            };

            assert_eq!(
                request,
                ModifyConnectionRequest {
                    version: Some(version),
                    monitor_page: Update::Reset,
                    interrupt_page,
                    target_message_vp,
                    ..Default::default()
                }
            );

            server.with_notifier(notifier).complete_initiate_contact(
                ModifyConnectionResponse::Supported(
                    protocol::ConnectionState::SUCCESSFUL,
                    SUPPORTED_FEATURE_FLAGS,
                ),
            );

            // From Win10Rs3_1 on the test expects 1 in `selected_version_or_connection_id`;
            // earlier versions expect the requested version there.
            if version >= Version::Win10Rs3_1 as u32 {
                1
            } else {
                version
            }
        } else {
            0
        };

        let version_response = protocol::VersionResponse {
            version_supported: if expect_supported { 1 } else { 0 },
            connection_state: protocol::ConnectionState::SUCCESSFUL,
            padding: 0,
            selected_version_or_connection_id,
        };

        // Copper and later reply with the extended response carrying feature flags.
        if version >= Version::Copper as u32 && expect_supported {
            notifier.check_message(OutgoingMessage::new(&protocol::VersionResponse2 {
                version_response,
                supported_features: expected_features,
            }));
        } else {
            notifier.check_message(OutgoingMessage::new(&version_response));
            assert_eq!(expected_features, 0);
        }

        assert!(notifier.messages.is_empty());
        if expect_supported {
            assert!(matches!(server.state, ConnectionState::Connected { .. }));
            if version < Version::Win8_1 as u32 {
                assert_eq!(Some(0), notifier.target_message_vp);
            } else {
                assert_eq!(Some(1), notifier.target_message_vp);
            }
        } else {
            assert!(matches!(server.state, ConnectionState::Disconnected));
            assert!(notifier.target_message_vp.is_none());
        }

        if version < Version::Win8 as u32 {
            assert_eq!(notifier.interrupt_page, Some(target_info));
        } else {
            assert!(notifier.interrupt_page.is_none());
        }
    }
4103
    /// Test double for `Notifier` that records every interaction so tests can
    /// assert on it.
    struct TestNotifier {
        /// Channel actions passed to `notify` are forwarded through this sender.
        send: mpsc::Sender<(OfferId, Action)>,
        /// Requests passed to `modify_connection`, oldest first.
        modify_requests: VecDeque<ModifyConnectionRequest>,
        /// Messages accepted by `send_message`, with their targets, oldest first.
        messages: VecDeque<(OutgoingMessage, MessageTarget)>,
        /// Requests passed to `notify_hvsock`.
        hvsock_requests: Vec<HvsockConnectRequest>,
        /// The request passed to `forward_unhandled`, if any.
        forward_request: Option<InitiateContactRequest>,
        /// Interrupt page as last applied by `modify_connection`.
        interrupt_page: Option<u64>,
        /// Set by `reset_complete`; read and cleared by `is_reset`.
        reset: bool,
        /// Monitor page GPAs as last applied by `modify_connection`; cleared by `reset_complete`.
        monitor_page: Option<MonitorPageGpas>,
        /// Target message VP as last set by `modify_connection`; cleared by `reset_complete`.
        target_message_vp: Option<u32>,
        /// When true, `send_message` returns false instead of recording the message.
        pend_messages: bool,
    }
4116
4117 impl TestNotifier {
4118 fn new() -> (Self, mpsc::Receiver<(OfferId, Action)>) {
4119 let (send, recv) = mpsc::channel();
4120 (
4121 Self {
4122 send,
4123 modify_requests: VecDeque::new(),
4124 messages: VecDeque::new(),
4125 hvsock_requests: Vec::new(),
4126 forward_request: None,
4127 interrupt_page: None,
4128 reset: false,
4129 monitor_page: None,
4130 target_message_vp: None,
4131 pend_messages: false,
4132 },
4133 recv,
4134 )
4135 }
4136
4137 fn check_message(&mut self, message: OutgoingMessage) {
4138 self.check_message_with_target(message, MessageTarget::Default);
4139 }
4140
4141 fn check_message_with_target(&mut self, message: OutgoingMessage, target: MessageTarget) {
4142 assert_eq!(self.messages.pop_front().unwrap(), (message, target));
4143 assert!(self.messages.is_empty());
4144 }
4145
4146 fn get_message<T: VmbusMessage + FromBytes + Immutable + KnownLayout>(&mut self) -> T {
4147 let (message, _) = self.messages.pop_front().unwrap();
4148 let (header, data) = protocol::MessageHeader::read_from_prefix(message.data()).unwrap();
4149
4150 assert_eq!(header.message_type(), T::MESSAGE_TYPE);
4151 T::read_from_prefix(data).unwrap().0 }
4153
4154 fn check_messages(&mut self, messages: &[OutgoingMessage]) {
4155 let messages: Vec<_> = messages
4156 .iter()
4157 .map(|m| (m.clone(), MessageTarget::Default))
4158 .collect();
4159 assert_eq!(self.messages, messages.as_slice());
4160 self.messages.clear();
4161 }
4162
4163 fn is_reset(&mut self) -> bool {
4164 std::mem::replace(&mut self.reset, false)
4165 }
4166
4167 fn check_reset(&mut self) {
4168 assert!(self.is_reset());
4169 assert!(self.monitor_page.is_none());
4170 assert!(self.target_message_vp.is_none());
4171 }
4172
4173 fn next_action(&mut self) -> ModifyConnectionRequest {
4174 self.modify_requests.pop_front().unwrap()
4175 }
4176 }
4177
4178 impl Notifier for TestNotifier {
4179 fn notify(&mut self, offer_id: OfferId, action: Action) {
4180 tracing::debug!(?offer_id, ?action, "notify");
4181 self.send.send((offer_id, action)).unwrap()
4182 }
4183
4184 fn forward_unhandled(&mut self, request: InitiateContactRequest) {
4185 assert!(self.forward_request.is_none());
4186 self.forward_request = Some(request);
4187 }
4188
4189 fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()> {
4190 match request.monitor_page {
4191 Update::Unchanged => (),
4192 Update::Reset => self.monitor_page = None,
4193 Update::Set(value) => self.monitor_page = Some(value),
4194 }
4195
4196 if let Some(vp) = request.target_message_vp {
4197 self.target_message_vp = Some(vp);
4198 }
4199
4200 match request.interrupt_page {
4201 Update::Unchanged => (),
4202 Update::Reset => self.interrupt_page = None,
4203 Update::Set(value) => self.interrupt_page = Some(value),
4204 }
4205
4206 self.modify_requests.push_back(request);
4207 Ok(())
4208 }
4209
4210 fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
4211 if self.pend_messages {
4212 return false;
4213 }
4214
4215 self.messages.push_back((message.clone(), target));
4216 true
4217 }
4218
4219 fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
4220 tracing::debug!(?request, "notify_hvsock");
4221 self.hvsock_requests.push(*request);
4224 }
4225
4226 fn reset_complete(&mut self) {
4227 self.monitor_page = None;
4228 self.target_message_vp = None;
4229 self.reset = true;
4230 }
4231
4232 fn unload_complete(&mut self) {}
4233 }
4234
4235 #[test]
4236 fn test_channel_lifetime() {
4237 test_channel_lifetime_helper(Version::Win10Rs5, FeatureFlags::new());
4238 }
4239
4240 #[test]
4241 fn test_channel_lifetime_iron() {
4242 test_channel_lifetime_helper(Version::Iron, FeatureFlags::new());
4243 }
4244
4245 #[test]
4246 fn test_channel_lifetime_copper() {
4247 test_channel_lifetime_helper(Version::Copper, FeatureFlags::new());
4248 }
4249
4250 #[test]
4251 fn test_channel_lifetime_copper_guest_signal() {
4252 test_channel_lifetime_helper(
4253 Version::Copper,
4254 FeatureFlags::new().with_guest_specified_signal_parameters(true),
4255 );
4256 }
4257
4258 #[test]
4259 fn test_channel_lifetime_copper_open_flags() {
4260 test_channel_lifetime_helper(
4261 Version::Copper,
4262 FeatureFlags::new().with_channel_interrupt_redirection(true),
4263 );
4264 }
4265
    /// Drives one channel through its lifetime at `version`: offer, connect,
    /// request offers, open, modify the target VP, revoke, and release.
    ///
    /// `feature_flags` is placed in the guest's target info only for Copper
    /// and later. When guest-specified signal parameters or interrupt
    /// redirection are requested there, the open is sent as `OpenChannel2`.
    /// The test then checks that the server honors only the parts enabled by
    /// those flags.
    fn test_channel_lifetime_helper(version: Version, feature_flags: FeatureFlags) {
        let (mut notifier, recv) = TestNotifier::new();
        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
        let interface_id = Guid::new_random();
        let instance_id = Guid::new_random();
        // Offer the channel before the guest connects.
        let offer_id = server
            .with_notifier(&mut notifier)
            .offer_channel(OfferParamsInternal {
                interface_name: "test".to_owned(),
                instance_id,
                interface_id,
                ..Default::default()
            })
            .unwrap();

        // Feature flags are only sent for Copper and later.
        let mut target_info = TargetInfo::new()
            .with_sint(SINT)
            .with_vtl(2)
            .with_feature_flags(FeatureFlags::new().into());
        if version >= Version::Copper {
            target_info.set_feature_flags(feature_flags.into());
        }

        server
            .with_notifier(&mut notifier)
            .handle_synic_message(in_msg(
                protocol::MessageType::INITIATE_CONTACT,
                protocol::InitiateContact {
                    version_requested: version as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: target_info.into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
            ))
            .unwrap();

        let request = notifier.next_action();
        assert_eq!(
            request,
            ModifyConnectionRequest {
                version: Some(version as u32),
                monitor_page: Update::Reset,
                interrupt_page: Update::Reset,
                target_message_vp: Some(0),
                ..Default::default()
            }
        );

        server
            .with_notifier(&mut notifier)
            .complete_initiate_contact(ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                SUPPORTED_FEATURE_FLAGS,
            ));

        let version_response = protocol::VersionResponse {
            version_supported: 1,
            selected_version_or_connection_id: 1,
            ..FromZeros::new_zeroed()
        };

        // Copper and later reply with VersionResponse2, echoing the requested flags.
        if version >= Version::Copper {
            notifier.check_message(OutgoingMessage::new(&protocol::VersionResponse2 {
                version_response,
                supported_features: feature_flags.into(),
            }));
        } else {
            notifier.check_message(OutgoingMessage::new(&version_response));
        }

        server
            .with_notifier(&mut notifier)
            .handle_synic_message(in_msg(protocol::MessageType::REQUEST_OFFERS, ()))
            .unwrap();

        // The existing offer is delivered, followed by AllOffersDelivered. The
        // monitor ID 0xff is presumably MonitorId::INVALID, since no MNF was
        // requested (TODO confirm).
        let channel_id = ChannelId(1);
        notifier.check_messages(&[
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id,
                instance_id,
                channel_id,
                connection_id: 0x2001,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::AllOffersDelivered {}),
        ]);

        let open_channel = protocol::OpenChannel {
            channel_id,
            open_id: 1,
            ring_buffer_gpadl_id: GpadlId(1),
            target_vp: 3,
            downstream_ring_buffer_page_offset: 2,
            user_data: UserDefinedData::new_zeroed(),
        };

        // Defaults expected in the Open action unless the negotiated flags allow
        // the guest to override them.
        let mut event_flag = 1;
        let mut connection_id = 0x2001;
        let mut expected_flags = protocol::OpenChannelFlags::new();
        if version >= Version::Copper
            && (feature_flags.guest_specified_signal_parameters()
                || feature_flags.channel_interrupt_redirection())
        {
            // The redirect-interrupt open flag is honored only when interrupt
            // redirection was negotiated.
            if feature_flags.channel_interrupt_redirection() {
                expected_flags.set_redirect_interrupt(true);
            }

            // The guest's event flag / connection ID are honored only when
            // guest-specified signal parameters were negotiated.
            if feature_flags.guest_specified_signal_parameters() {
                event_flag = 2;
                connection_id = 0x2002;
            }

            // Always send both overrides and extra flag bits (0xabc); the
            // assertions below check what the server kept.
            server
                .with_notifier(&mut notifier)
                .handle_synic_message(in_msg(
                    protocol::MessageType::OPEN_CHANNEL,
                    protocol::OpenChannel2 {
                        open_channel,
                        event_flag: 2,
                        connection_id: 0x2002,
                        flags: (u16::from(
                            protocol::OpenChannelFlags::new().with_redirect_interrupt(true),
                        ) | 0xabc)
                        .into(),
                    },
                ))
                .unwrap();
        } else {
            server
                .with_notifier(&mut notifier)
                .handle_synic_message(in_msg(protocol::MessageType::OPEN_CHANNEL, open_channel))
                .unwrap();
        }

        let (id, action) = recv.recv().unwrap();
        assert_eq!(id, offer_id);
        let Action::Open(op, ..) = action else {
            panic!("unexpected action: {:?}", action);
        };
        assert_eq!(op.open_data.ring_gpadl_id, GpadlId(1));
        assert_eq!(op.open_data.ring_offset, 2);
        assert_eq!(op.open_data.target_vp, 3);
        assert_eq!(op.open_data.event_flag, event_flag);
        assert_eq!(op.open_data.connection_id, connection_id);
        assert_eq!(op.connection_id, connection_id);
        assert_eq!(op.event_flag, event_flag);
        assert_eq!(op.monitor_info, None);
        assert_eq!(op.flags, expected_flags);

        server
            .with_notifier(&mut notifier)
            .open_complete(offer_id, 0);

        notifier.check_message(OutgoingMessage::new(&protocol::OpenResult {
            channel_id,
            open_id: 1,
            status: 0,
        }));

        server
            .with_notifier(&mut notifier)
            .handle_synic_message(in_msg(
                protocol::MessageType::MODIFY_CHANNEL,
                protocol::ModifyChannel {
                    channel_id,
                    target_vp: 4,
                },
            ))
            .unwrap();

        let (id, action) = recv.recv().unwrap();
        assert_eq!(id, offer_id);
        assert!(matches!(action, Action::Modify { target_vp: 4 }));

        server
            .with_notifier(&mut notifier)
            .modify_channel_complete(id, 0);

        // Only Iron and later receive a ModifyChannelResponse.
        if version >= Version::Iron {
            notifier.check_message(OutgoingMessage::new(&protocol::ModifyChannelResponse {
                channel_id,
                status: 0,
            }));
        }

        assert!(notifier.messages.is_empty());

        server.with_notifier(&mut notifier).revoke_channel(offer_id);

        server
            .with_notifier(&mut notifier)
            .handle_synic_message(in_msg(
                protocol::MessageType::REL_ID_RELEASED,
                protocol::RelIdReleased { channel_id },
            ))
            .unwrap();
    }
4466
4467 #[test]
4468 fn test_hvsock() {
4469 test_hvsock_helper(Version::Win10, false);
4470 }
4471
4472 #[test]
4473 fn test_hvsock_rs3() {
4474 test_hvsock_helper(Version::Win10Rs3_0, false);
4475 }
4476
4477 #[test]
4478 fn test_hvsock_rs5() {
4479 test_hvsock_helper(Version::Win10Rs5, false);
4480 test_hvsock_helper(Version::Win10Rs5, true);
4481 }
4482
    /// Connects at `version`, sends a `TL_CONNECT_REQUEST`, and checks both
    /// the forwarded hvsocket request and the refused connect result.
    ///
    /// For Win10Rs5 and later the request uses `TlConnectRequest2`, unless
    /// `force_small_message` is set, in which case the plain `TlConnectRequest`
    /// is sent. A `TlConnectResult` reply is expected only for Win10Rs3_0 and
    /// later.
    fn test_hvsock_helper(version: Version, force_small_message: bool) {
        let (mut notifier, _recv) = TestNotifier::new();
        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);

        server
            .with_notifier(&mut notifier)
            .handle_synic_message(in_msg(
                protocol::MessageType::INITIATE_CONTACT,
                protocol::InitiateContact {
                    version_requested: version as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: 0,
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
            ))
            .unwrap();

        let request = notifier.next_action();
        assert_eq!(
            request,
            ModifyConnectionRequest {
                version: Some(version as u32),
                monitor_page: Update::Reset,
                interrupt_page: Update::Reset,
                target_message_vp: Some(0),
                ..Default::default()
            }
        );

        server
            .with_notifier(&mut notifier)
            .complete_initiate_contact(ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                SUPPORTED_FEATURE_FLAGS,
            ));

        // Discard the version response sent when the connection completed.
        notifier.messages.pop_front();

        let service_id = Guid::new_random();
        let endpoint_id = Guid::new_random();
        let request_msg = if version >= Version::Win10Rs5 && !force_small_message {
            in_msg(
                protocol::MessageType::TL_CONNECT_REQUEST,
                protocol::TlConnectRequest2 {
                    base: protocol::TlConnectRequest {
                        service_id,
                        endpoint_id,
                    },
                    silo_id: Guid::ZERO,
                },
            )
        } else {
            in_msg(
                protocol::MessageType::TL_CONNECT_REQUEST,
                protocol::TlConnectRequest {
                    service_id,
                    endpoint_id,
                },
            )
        };

        server
            .with_notifier(&mut notifier)
            .handle_synic_message(request_msg)
            .unwrap();

        // Exactly one hvsocket request reached the notifier, with the sent IDs.
        let request = notifier.hvsock_requests.pop().unwrap();
        assert_eq!(request.service_id, service_id);
        assert_eq!(request.endpoint_id, endpoint_id);
        assert!(notifier.hvsock_requests.is_empty());

        // Refuse the connection (`false` result).
        server
            .with_notifier(&mut notifier)
            .send_tl_connect_result(HvsockConnectResult::from_request(&request, false));

        if version >= Version::Win10Rs3_0 {
            notifier.check_message(OutgoingMessage::new(&protocol::TlConnectResult {
                service_id: request.service_id,
                endpoint_id: request.endpoint_id,
                status: protocol::STATUS_CONNECTION_REFUSED,
            }));
        }

        assert!(notifier.messages.is_empty());
    }
4571
    /// Bundles a server with a `TestNotifier`, plus helpers (in `impl TestEnv`)
    /// for driving common protocol flows.
    struct TestEnv {
        server: Server,
        notifier: TestNotifier,
        /// Version and feature flags requested by `start_connect`. For Copper
        /// and later, `complete_connect` narrows the flags to what the server
        /// reported.
        version: Option<VersionInfo>,
        /// Kept alive so that `TestNotifier::notify`'s `send(...).unwrap()` does not fail.
        _recv: mpsc::Receiver<(OfferId, Action)>,
    }
4578
4579 impl TestEnv {
4580 fn new() -> Self {
4581 let (notifier, _recv) = TestNotifier::new();
4582 let server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
4583 Self {
4584 server,
4585 notifier,
4586 version: None,
4587 _recv,
4588 }
4589 }
4590
4591 fn c(&mut self) -> ServerWithNotifier<'_, TestNotifier> {
4592 self.server.with_notifier(&mut self.notifier)
4593 }
4594
4595 fn complete_reset(&mut self) {
4599 let _ = self.next_action();
4600 self.c()
4601 .complete_modify_connection(ModifyConnectionResponse::Supported(
4602 protocol::ConnectionState::SUCCESSFUL,
4603 SUPPORTED_FEATURE_FLAGS,
4604 ));
4605 }
4606
4607 fn offer(&mut self, id: u32) -> OfferId {
4608 self.offer_inner(id, id, MnfUsage::Disabled, None, OfferFlags::new())
4609 }
4610
4611 fn offer_with_mnf(&mut self, id: u32) -> OfferId {
4612 self.offer_inner(
4613 id,
4614 id,
4615 MnfUsage::Enabled {
4616 latency: Duration::from_micros(100),
4617 },
4618 None,
4619 OfferFlags::new(),
4620 )
4621 }
4622
4623 fn offer_with_preset_mnf(&mut self, id: u32, monitor_id: u8) -> OfferId {
4624 self.offer_inner(
4625 id,
4626 id,
4627 MnfUsage::Relayed { monitor_id },
4628 None,
4629 OfferFlags::new(),
4630 )
4631 }
4632
4633 fn offer_with_order(
4634 &mut self,
4635 interface_id: u32,
4636 instance_id: u32,
4637 order: Option<u32>,
4638 ) -> OfferId {
4639 self.offer_inner(
4640 interface_id,
4641 instance_id,
4642 MnfUsage::Disabled,
4643 order,
4644 OfferFlags::new(),
4645 )
4646 }
4647
4648 fn offer_with_flags(&mut self, id: u32, flags: OfferFlags) -> OfferId {
4649 self.offer_inner(id, id, MnfUsage::Disabled, None, flags)
4650 }
4651
4652 fn offer_inner(
4653 &mut self,
4654 interface_id: u32,
4655 instance_id: u32,
4656 use_mnf: MnfUsage,
4657 offer_order: Option<u32>,
4658 flags: OfferFlags,
4659 ) -> OfferId {
4660 self.c()
4661 .offer_channel(OfferParamsInternal {
4662 instance_id: Guid {
4663 data1: instance_id,
4664 ..Guid::ZERO
4665 },
4666 interface_id: Guid {
4667 data1: interface_id,
4668 ..Guid::ZERO
4669 },
4670 use_mnf,
4671 offer_order,
4672 flags,
4673 ..Default::default()
4674 })
4675 .unwrap()
4676 }
4677
4678 fn open(&mut self, id: u32) {
4679 self.c()
4680 .handle_open_channel(&protocol::OpenChannel2 {
4681 open_channel: protocol::OpenChannel {
4682 channel_id: ChannelId(id),
4683 ..FromZeros::new_zeroed()
4684 },
4685 ..FromZeros::new_zeroed()
4686 })
4687 .unwrap()
4688 }
4689
4690 fn close(&mut self, id: u32) -> Result<(), ChannelError> {
4691 self.c().handle_close_channel(&protocol::CloseChannel {
4692 channel_id: ChannelId(id),
4693 })
4694 }
4695
4696 fn open_reserved(&mut self, id: u32, target_vp: u32, target_sint: u32) {
4697 let version = self.server.state.get_version().expect("vmbus connected");
4698
4699 self.c()
4700 .handle_open_reserved_channel(
4701 &protocol::OpenReservedChannel {
4702 channel_id: ChannelId(id),
4703 target_vp,
4704 target_sint,
4705 ring_buffer_gpadl: GpadlId(id),
4706 ..FromZeros::new_zeroed()
4707 },
4708 version,
4709 )
4710 .unwrap()
4711 }
4712
4713 fn close_reserved(&mut self, id: u32, target_vp: u32, target_sint: u32) {
4714 self.c()
4715 .handle_close_reserved_channel(&protocol::CloseReservedChannel {
4716 channel_id: ChannelId(id),
4717 target_vp,
4718 target_sint,
4719 })
4720 .unwrap();
4721 }
4722
4723 fn gpadl(&mut self, channel_id: u32, gpadl_id: u32) {
4724 self.c()
4725 .handle_gpadl_header(
4726 &protocol::GpadlHeader {
4727 channel_id: ChannelId(channel_id),
4728 gpadl_id: GpadlId(gpadl_id),
4729 count: 1,
4730 len: 16,
4731 },
4732 [1u64, 0u64].as_bytes(),
4733 )
4734 .unwrap();
4735 }
4736
4737 fn teardown_gpadl(&mut self, channel_id: u32, gpadl_id: u32) {
4738 self.c()
4739 .handle_gpadl_teardown(&protocol::GpadlTeardown {
4740 channel_id: ChannelId(channel_id),
4741 gpadl_id: GpadlId(gpadl_id),
4742 })
4743 .unwrap();
4744 }
4745
4746 fn release(&mut self, id: u32) {
4747 self.c()
4748 .handle_rel_id_released(&protocol::RelIdReleased {
4749 channel_id: ChannelId(id),
4750 })
4751 .unwrap();
4752 }
4753
4754 fn connect(&mut self, version: Version, feature_flags: FeatureFlags) {
4755 self.start_connect(version, feature_flags, false);
4756 self.complete_connect();
4757 }
4758
4759 fn connect_trusted(&mut self, version: Version, feature_flags: FeatureFlags) {
4760 self.start_connect(version, feature_flags, true);
4761 self.complete_connect();
4762 }
4763
4764 fn start_connect(&mut self, version: Version, feature_flags: FeatureFlags, trusted: bool) {
4765 self.version = Some(VersionInfo {
4766 version,
4767 feature_flags,
4768 });
4769
4770 let result = self.c().handle_synic_message(in_msg_ex(
4771 protocol::MessageType::INITIATE_CONTACT,
4772 protocol::InitiateContact2 {
4773 initiate_contact: protocol::InitiateContact {
4774 version_requested: version as u32,
4775 interrupt_page_or_target_info: TargetInfo::new()
4776 .with_sint(SINT)
4777 .with_vtl(0)
4778 .with_feature_flags(feature_flags.into())
4779 .into(),
4780 child_to_parent_monitor_page_gpa: 0x123f000,
4781 parent_to_child_monitor_page_gpa: 0x321f000,
4782 ..FromZeros::new_zeroed()
4783 },
4784 client_id: Guid::ZERO,
4785 },
4786 false,
4787 trusted,
4788 ));
4789 assert!(result.is_ok());
4790
4791 let request = self.notifier.next_action();
4792 assert_eq!(
4793 request,
4794 ModifyConnectionRequest {
4795 version: Some(version as u32),
4796 monitor_page: Update::Set(MonitorPageGpas {
4797 child_to_parent: 0x123f000,
4798 parent_to_child: 0x321f000,
4799 }),
4800 interrupt_page: Update::Reset,
4801 target_message_vp: Some(0),
4802 ..Default::default()
4803 }
4804 );
4805 }
4806
4807 fn complete_connect(&mut self) {
4808 self.c()
4809 .complete_initiate_contact(ModifyConnectionResponse::Supported(
4810 protocol::ConnectionState::SUCCESSFUL,
4811 SUPPORTED_FEATURE_FLAGS,
4812 ));
4813
4814 let version = self.version.unwrap();
4815 if version.version >= Version::Copper {
4816 let response = self.notifier.get_message::<protocol::VersionResponse2>();
4817 assert_eq!(response.version_response.version_supported, 1);
4818 self.version = Some(VersionInfo {
4819 version: version.version,
4820 feature_flags: version.feature_flags & response.supported_features.into(),
4821 })
4822 } else {
4823 let response = self.notifier.get_message::<protocol::VersionResponse>();
4824 assert_eq!(response.version_supported, 1);
4825 }
4826 }
4827
4828 fn send_message(&mut self, message: SynicMessage) {
4829 self.try_send_message(message).unwrap();
4830 }
4831
4832 fn try_send_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
4833 self.c().handle_synic_message(message)
4834 }
4835
4836 fn next_action(&mut self) -> ModifyConnectionRequest {
4837 self.notifier.next_action()
4838 }
4839 }
4840
4841 #[test]
4843 fn test_hot_add() {
4844 let mut env = TestEnv::new();
4845 let offer_id1 = env.offer(1);
4846 let result = env.c().handle_initiate_contact(
4847 &protocol::InitiateContact2 {
4848 initiate_contact: protocol::InitiateContact {
4849 version_requested: Version::Win10 as u32,
4850 ..FromZeros::new_zeroed()
4851 },
4852 ..FromZeros::new_zeroed()
4853 },
4854 &SynicMessage::default(),
4855 true,
4856 );
4857 assert!(result.is_ok());
4858 let offer_id2 = env.offer(2);
4859 env.c()
4860 .complete_initiate_contact(ModifyConnectionResponse::Supported(
4861 protocol::ConnectionState::SUCCESSFUL,
4862 SUPPORTED_FEATURE_FLAGS,
4863 ));
4864 let offer_id3 = env.offer(3);
4865 env.c().handle_request_offers().unwrap();
4866 let offer_id4 = env.offer(4);
4867 env.open(1);
4868 env.open(2);
4869 env.open(3);
4870 env.open(4);
4871 env.c().open_complete(offer_id1, 0);
4872 env.c().open_complete(offer_id2, 0);
4873 env.c().open_complete(offer_id3, 0);
4874 env.c().open_complete(offer_id4, 0);
4875 env.c().reset();
4876 env.c().close_complete(offer_id1);
4877 env.c().close_complete(offer_id2);
4878 env.c().close_complete(offer_id3);
4879 env.c().close_complete(offer_id4);
4880 env.complete_reset();
4881 assert!(env.notifier.is_reset());
4882 }
4883
    /// Saves and restores server state when channels are offered but no guest
    /// has connected.
    #[test]
    fn test_save_restore_with_no_connection() {
        let mut env = TestEnv::new();

        let offer_id1 = env.offer(1);
        let _offer_id2 = env.offer(2);

        let state = env.server.save();
        // With no connection, the reset completes immediately.
        env.c().reset();
        assert!(env.notifier.is_reset());
        env.c().restore(state).unwrap();
        // NOTE(review): the `false` argument presumably marks the channel as not
        // open at save time; confirm against `restore_channel`.
        env.c().restore_channel(offer_id1, false).unwrap();
    }
4897
    /// Saves state while connected, with channels in many in-flight states
    /// (pending/complete opens, pending/complete GPADLs, reserved channels, a
    /// re-offered channel). The saved state is then restored twice, with offers
    /// revoked and re-offered around the first restore.
    #[test]
    fn test_save_restore_with_connection() {
        let mut env = TestEnv::new();

        // Channels 1, 3 and 5 are offered with MNF; the rest are plain offers.
        let offer_id1 = env.offer_with_mnf(1);
        let offer_id2 = env.offer(2);
        let offer_id3 = env.offer_with_mnf(3);
        let offer_id4 = env.offer(4);
        let offer_id5 = env.offer_with_mnf(5);
        let offer_id6 = env.offer(6);
        let offer_id7 = env.offer(7);
        let offer_id8 = env.offer(8);
        let offer_id9 = env.offer(9);
        let offer_id10 = env.offer(10);

        let expected_monitor = MonitorPageGpas {
            child_to_parent: 0x123f000,
            parent_to_child: 0x321f000,
        };

        env.connect(Version::Win10, FeatureFlags::new());
        assert_eq!(env.notifier.monitor_page, Some(expected_monitor));

        // After offers are delivered the assigned-monitor bitmap is 0b111
        // (presumably one bit per MNF offer).
        env.c().handle_request_offers().unwrap();
        assert_eq!(env.server.assigned_monitors.bitmap(), 7);

        // Open 1, 2, 3 and 5, but complete only 1, 2 and 5. Channel 3's open
        // is therefore still pending when state is saved.
        env.open(1);
        env.open(2);
        env.open(3);
        env.open(5);

        env.c().open_complete(offer_id1, 0);
        env.c().open_complete(offer_id2, 0);
        env.c().open_complete(offer_id5, 0);

        // On channels 1-3, create one GPADL that completes (x0) and one that
        // is left pending (x1).
        env.gpadl(1, 10);
        env.c().gpadl_create_complete(offer_id1, GpadlId(10), 0);
        env.gpadl(1, 11);
        env.gpadl(2, 20);
        env.c().gpadl_create_complete(offer_id2, GpadlId(20), 0);
        env.gpadl(2, 21);
        env.gpadl(3, 30);
        env.c().gpadl_create_complete(offer_id3, GpadlId(30), 0);
        env.gpadl(3, 31);

        // Reserved channels: 7's open is left pending, 8 and 9 are opened, and
        // a close of 9 is started but not completed.
        env.open_reserved(7, 1, SINT.into());
        env.open_reserved(8, 2, SINT.into());
        env.open_reserved(9, 3, SINT.into());
        env.c().open_complete(offer_id8, 0);
        env.c().open_complete(offer_id9, 0);
        env.close_reserved(9, 3, SINT.into());

        // Revoke channel 10 and offer it again; the new offer ID shadows the old one.
        env.c().revoke_channel(offer_id10);
        let offer_id10 = env.offer(10);

        let state = env.server.save();

        env.c().reset();

        // Resolve every operation that was outstanding at reset: channels that
        // were opened get a close completion, and pending opens and GPADL
        // creations are failed with -1.
        env.c().close_complete(offer_id1);
        env.c().close_complete(offer_id2);
        env.c().open_complete(offer_id3, -1);
        env.c().close_complete(offer_id5);
        env.c().open_complete(offer_id7, -1);
        env.c().close_complete(offer_id8);
        env.c().close_complete(offer_id9);

        env.c().gpadl_teardown_complete(offer_id1, GpadlId(10));
        env.c().gpadl_create_complete(offer_id1, GpadlId(11), -1);
        env.c().gpadl_teardown_complete(offer_id2, GpadlId(20));
        env.c().gpadl_create_complete(offer_id2, GpadlId(21), -1);
        env.c().gpadl_teardown_complete(offer_id3, GpadlId(30));
        env.c().gpadl_create_complete(offer_id3, GpadlId(31), -1);

        env.complete_reset();
        env.notifier.check_reset();

        // Revoke 5 and 6 before restoring, so the saved state refers to offers
        // that are gone at restore time.
        env.c().revoke_channel(offer_id5);
        env.c().revoke_channel(offer_id6);

        // Clone so the same saved state can be restored again further below.
        env.c().restore(state.clone()).unwrap();

        // Revoke 1 and 4 after restore but before they are claimed, then claim
        // the others (5 through a fresh MNF offer). The bool is `true` exactly
        // for channels that were open at save time.
        // NOTE(review): meaning of the bool inferred from this setup; confirm.
        env.c().revoke_channel(offer_id1);
        env.c().revoke_channel(offer_id4);
        env.c().restore_channel(offer_id3, false).unwrap();
        let offer_id5 = env.offer_with_mnf(5);
        env.c().restore_channel(offer_id5, true).unwrap();
        env.c().restore_channel(offer_id7, false).unwrap();
        env.c().restore_channel(offer_id8, true).unwrap();
        env.c().restore_channel(offer_id9, true).unwrap();
        env.c().restore_channel(offer_id10, false).unwrap();
        // Channel 10 was revoked and re-offered before the save.
        assert!(matches!(
            env.server.channels[offer_id10].state,
            ChannelState::Reoffered
        ));

        // Offer 2 was never passed to `restore_channel`.
        env.c().revoke_unclaimed_channels();

        // Connection parameters survive the restore.
        assert_eq!(env.notifier.monitor_page, Some(expected_monitor));
        assert_eq!(env.notifier.target_message_vp, Some(0));

        // Bitmap is now 0b110. NOTE(review): presumably bit 0 belonged to the
        // revoked MNF channel 1; confirm.
        assert_eq!(env.server.assigned_monitors.bitmap(), 6);
        // Release 1 and 4 (revoked above) and 2 (unclaimed); presumably these
        // are guest releases of revoked channels.
        env.release(1);
        env.release(2);
        env.release(4);

        // Finish operations that were in flight on the restored channels.
        env.c().open_complete(offer_id7, 0);
        env.close_reserved(8, 2, SINT.into());
        env.c().close_complete(offer_id8);
        env.c().close_complete(offer_id9);

        // Reset a second time and resolve what the test expects to still be
        // outstanding: channel 3's open and GPADLs, and closes of 5 and 7.
        env.c().reset();

        env.c().open_complete(offer_id3, -1);
        env.c().gpadl_teardown_complete(offer_id3, GpadlId(30));
        env.c().gpadl_create_complete(offer_id3, GpadlId(31), -1);
        env.c().close_complete(offer_id5);
        env.c().close_complete(offer_id7);

        env.complete_reset();
        env.notifier.check_reset();

        // The original saved state can be restored again.
        env.c().restore(state).unwrap();
        env.c().restore_channel(offer_id3, false).unwrap();
        assert_eq!(env.notifier.monitor_page, Some(expected_monitor));
        assert_eq!(env.notifier.target_message_vp, Some(0));
    }
5029
    /// Saves state while a guest connection is still being established. It then
    /// checks that restore re-issues the connect request with the saved
    /// version and monitor pages.
    #[test]
    fn test_save_restore_connecting() {
        let mut env = TestEnv::new();

        let offer_id1 = env.offer_with_mnf(1);
        let _offer_id2 = env.offer(2);

        // Start the connection but leave it incomplete, so the save happens mid-connect.
        env.start_connect(Version::Win10, FeatureFlags::new(), false);
        assert_eq!(
            env.notifier.monitor_page,
            Some(MonitorPageGpas {
                child_to_parent: 0x123f000,
                parent_to_child: 0x321f000
            })
        );

        let state = env.server.save();

        // Reset, completing the in-flight connect before checking that the
        // reset finished.
        env.c().reset();
        env.complete_connect();
        env.notifier.check_reset();

        env.c().restore(state).unwrap();
        env.c().restore_channel(offer_id1, false).unwrap();
        assert_eq!(
            env.notifier.monitor_page,
            Some(MonitorPageGpas {
                child_to_parent: 0x123f000,
                parent_to_child: 0x321f000
            })
        );

        // Restore re-issues the connect request with the saved version and
        // monitor pages, an interrupt page reset, and target VP 0.
        let request = env.next_action();
        assert_eq!(
            request,
            ModifyConnectionRequest {
                version: Some(Version::Win10 as u32),
                monitor_page: Update::Set(MonitorPageGpas {
                    child_to_parent: 0x123f000,
                    parent_to_child: 0x321f000,
                }),
                interrupt_page: Update::Reset,
                target_message_vp: Some(0),
                ..Default::default()
            }
        );

        assert_eq!(Some(0), env.notifier.target_message_vp);

        env.complete_connect();
    }
5084
    /// Saves state while a guest-initiated MODIFY_CONNECTION is still
    /// outstanding. It then checks that restore re-issues the modify request and
    /// that completing it answers the guest.
    #[test]
    fn test_save_restore_modifying() {
        let mut env = TestEnv::new();
        env.connect(
            Version::Copper,
            FeatureFlags::new().with_modify_connection(true),
        );

        let expected = MonitorPageGpas {
            parent_to_child: 0x123f000,
            child_to_parent: 0x321f000,
        };

        env.send_message(in_msg(
            protocol::MessageType::MODIFY_CONNECTION,
            protocol::ModifyConnection {
                parent_to_child_monitor_page_gpa: expected.parent_to_child,
                child_to_parent_monitor_page_gpa: expected.child_to_parent,
            },
        ));

        // Take the resulting modify request but do not complete it, so the
        // modify is still outstanding at save time.
        env.next_action();

        assert_eq!(env.notifier.monitor_page, Some(expected));

        let state = env.server.save();
        env.c().reset();
        env.notifier.check_reset();

        env.c().restore(state).unwrap();

        // After restore, a modify request with the saved monitor pages is
        // issued again, together with an interrupt page reset and target VP 0.
        let request = env.next_action();
        assert_eq!(
            request,
            ModifyConnectionRequest {
                monitor_page: Update::Set(MonitorPageGpas {
                    parent_to_child: 0x123f000,
                    child_to_parent: 0x321f000,
                }),
                interrupt_page: Update::Reset,
                target_message_vp: Some(0),
                ..Default::default()
            }
        );

        assert_eq!(env.notifier.monitor_page, Some(expected));

        // Completing the re-issued request sends the guest a successful
        // ModifyConnectionResponse.
        env.c()
            .complete_modify_connection(ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                SUPPORTED_FEATURE_FLAGS,
            ));

        env.notifier
            .check_message(OutgoingMessage::new(&protocol::ModifyConnectionResponse {
                connection_state: protocol::ConnectionState::SUCCESSFUL,
            }));
    }
5146
    /// Saves state after the guest unloads while a reserved channel is open.
    /// It restores into a fresh environment and checks that the reserved
    /// channel's GPADL is still tracked.
    #[test]
    fn test_save_restore_disconnected_reserved() {
        let mut env = TestEnv::new();

        let offer_id1 = env.offer(1);
        let _offer_id2 = env.offer(2);
        let _offer_id3 = env.offer(3);

        env.connect(Version::Copper, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        // Give channel 1 a GPADL, open it as a reserved channel (VP 0, SINT 3),
        // then unload.
        env.gpadl(1, 1);
        env.c().gpadl_create_complete(offer_id1, GpadlId(1), 0);
        env.open_reserved(1, 0, 3);
        env.c().open_complete(offer_id1, protocol::STATUS_SUCCESS);
        env.c().handle_unload();

        // Save, then restore into a brand-new environment with the same offers.
        let state = env.server.save();
        let mut env = TestEnv::new();
        let offer_id1 = env.offer(1);
        let offer_id2 = env.offer(2);
        let offer_id3 = env.offer(3);

        env.c().restore(state).unwrap();

        // Only the reserved channel 1 is restored with `true`.
        // NOTE(review): presumably the flag means "was open"; confirm.
        env.c().restore_channel(offer_id1, true).unwrap();
        env.c().restore_channel(offer_id2, false).unwrap();
        env.c().restore_channel(offer_id3, false).unwrap();

        assert!(env.server.gpadls.contains_key(&(GpadlId(1), offer_id1)));
    }
5180
    /// Checks that messages the notifier does not accept are kept as pending
    /// server messages, survive save/restore, and are flushed in the order
    /// they were produced.
    #[test]
    fn test_pending_messages() {
        let mut env = TestEnv::new();

        let offer_id1 = env.offer(1);
        let offer_id2 = env.offer(2);
        let offer_id3 = env.offer(3);

        env.connect(Version::Copper, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        // With `pend_messages` set, the test notifier presumably refuses
        // outgoing messages (none are recorded below).
        env.notifier.messages.clear();
        env.notifier.pend_messages = true;
        env.open_reserved(2, 4, SINT.into());
        env.c().open_complete(offer_id2, protocol::STATUS_SUCCESS);

        // The reserved channel's open result is neither recorded nor kept in
        // the server's pending queue.
        assert!(env.notifier.messages.is_empty());
        assert!(!env.server.has_pending_messages());

        // Produces GPADL_CREATED for channel 1 (first pending message).
        env.gpadl(1, 10);
        env.c()
            .gpadl_create_complete(offer_id1, GpadlId(10), protocol::STATUS_SUCCESS);

        // Produces OPEN_CHANNEL_RESULT for channel 3 (second pending message).
        env.notifier.pend_messages = true;
        env.open(3);
        env.c().open_complete(offer_id3, protocol::STATUS_SUCCESS);

        assert!(env.notifier.messages.is_empty());
        assert!(env.server.has_pending_messages());
        env.notifier.pend_messages = false;

        let state = env.server.save();

        // Restore into a fresh environment; the pending messages come along.
        let mut env = TestEnv::new();

        let offer_id1 = env.offer(1);
        let offer_id2 = env.offer(2);
        let offer_id3 = env.offer(3);

        env.c().restore(state).unwrap();
        env.c().restore_channel(offer_id1, false).unwrap();
        env.c().restore_channel(offer_id2, true).unwrap();
        env.c().restore_channel(offer_id3, true).unwrap();

        assert!(env.server.has_pending_messages());
        // Flush with a sender that accepts everything, capturing each message.
        let mut pending_messages = Vec::new();
        let r = env.server.poll_flush_pending_messages(|msg| {
            pending_messages.push(msg.clone());
            Poll::Ready(())
        });
        assert!(r.is_ready());
        // Both messages are flushed, in the order they were generated.
        assert_eq!(pending_messages.len(), 2);
        assert_eq!(
            protocol::MessageHeader::read_from_prefix(pending_messages[0].data())
                .unwrap()
                .0
                .message_type(),
            protocol::MessageType::GPADL_CREATED
        );

        assert_eq!(
            protocol::MessageHeader::read_from_prefix(pending_messages[1].data())
                .unwrap()
                .0
                .message_type(),
            protocol::MessageType::OPEN_CHANNEL_RESULT
        );

        assert!(!env.server.has_pending_messages());
    }
5256
5257 #[test]
5258 fn test_modify_connection() {
5259 let mut env = TestEnv::new();
5260 env.connect(
5261 Version::Copper,
5262 FeatureFlags::new().with_modify_connection(true),
5263 );
5264
5265 env.send_message(in_msg(
5266 protocol::MessageType::MODIFY_CONNECTION,
5267 protocol::ModifyConnection {
5268 parent_to_child_monitor_page_gpa: 5,
5269 child_to_parent_monitor_page_gpa: 6,
5270 },
5271 ));
5272
5273 assert_eq!(
5274 env.notifier.monitor_page,
5275 Some(MonitorPageGpas {
5276 parent_to_child: 5,
5277 child_to_parent: 6
5278 })
5279 );
5280
5281 let request = env.next_action();
5282 assert_eq!(
5283 request,
5284 ModifyConnectionRequest {
5285 monitor_page: Update::Set(MonitorPageGpas {
5286 child_to_parent: 6,
5287 parent_to_child: 5,
5288 }),
5289 ..Default::default()
5290 }
5291 );
5292
5293 env.c()
5294 .complete_modify_connection(ModifyConnectionResponse::Supported(
5295 protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
5296 SUPPORTED_FEATURE_FLAGS,
5297 ));
5298
5299 env.notifier
5300 .check_message(OutgoingMessage::new(&protocol::ModifyConnectionResponse {
5301 connection_state: protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
5302 }));
5303 }
5304
5305 #[test]
5306 fn test_modify_connection_unsupported() {
5307 let mut env = TestEnv::new();
5308 env.connect(Version::Copper, FeatureFlags::new());
5309
5310 let err = env
5311 .try_send_message(in_msg(
5312 protocol::MessageType::MODIFY_CONNECTION,
5313 protocol::ModifyConnection {
5314 parent_to_child_monitor_page_gpa: 5,
5315 child_to_parent_monitor_page_gpa: 6,
5316 },
5317 ))
5318 .unwrap_err();
5319
5320 assert!(matches!(
5321 err,
5322 ChannelError::ParseError(protocol::ParseError::InvalidMessageType(
5323 protocol::MessageType::MODIFY_CONNECTION
5324 ))
5325 ));
5326 }
5327
    /// Exercises reserved channels across an unload and a reconnect. Open
    /// results and close responses go to the reserved target, a regular close
    /// is rejected, and a final reset cleans up what remains.
    #[test]
    fn test_reserved_channels() {
        let mut env = TestEnv::new();

        let offer_id1 = env.offer(1);
        let offer_id2 = env.offer(2);
        let offer_id3 = env.offer(3);

        env.connect(Version::Win10, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        env.gpadl(1, 10);
        env.c().gpadl_create_complete(offer_id1, GpadlId(10), 0);

        env.notifier.messages.clear();

        // The open result is sent to `MessageTarget::ReservedChannel` with the
        // VP/SINT passed to `open_reserved`.
        env.open_reserved(1, 1, SINT.into());
        env.c().open_complete(offer_id1, 0);
        env.notifier.check_message_with_target(
            OutgoingMessage::new(&protocol::OpenResult {
                channel_id: ChannelId(1),
                ..FromZeros::new_zeroed()
            }),
            MessageTarget::ReservedChannel(offer_id1, ConnectionTarget { vp: 1, sint: SINT }),
        );
        env.open_reserved(2, 2, SINT.into());
        env.c().open_complete(offer_id2, 0);
        env.open_reserved(3, 3, SINT.into());
        env.c().open_complete(offer_id3, 0);

        // A reserved channel cannot be closed with a regular close request.
        assert!(matches!(env.close(2), Err(ChannelError::ChannelReserved)));

        env.c().handle_unload();

        // After the unload, channel 2 is closed through a reserved close.
        env.close_reserved(2, 2, SINT.into());
        env.c().close_complete(offer_id2);

        // Reconnect and reopen channel 2 as reserved with a new GPADL.
        env.notifier.messages.clear();
        env.connect(Version::Copper, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        env.gpadl(2, 10);
        env.c().gpadl_create_complete(offer_id2, GpadlId(10), 0);

        env.open_reserved(2, 3, SINT.into());
        env.c().open_complete(offer_id2, 0);

        env.notifier.messages.clear();

        // Channel 1, opened as reserved before the unload, is closed on the new
        // connection. The response goes to the VP named in the close request.
        env.close_reserved(1, 4, SINT.into());
        env.c().close_complete(offer_id1);
        env.notifier.check_message_with_target(
            OutgoingMessage::new(&protocol::CloseReservedChannelResponse {
                channel_id: ChannelId(1),
            }),
            MessageTarget::ReservedChannel(offer_id1, ConnectionTarget { vp: 4, sint: SINT }),
        );
        env.teardown_gpadl(1, 10);
        env.c().gpadl_teardown_complete(offer_id1, GpadlId(10));

        // Reset, completing the closes of 2 and 3 and channel 2's GPADL teardown.
        env.c().reset();
        env.c().close_complete(offer_id2);
        env.c().gpadl_teardown_complete(offer_id2, GpadlId(10));
        env.c().close_complete(offer_id3);

        env.complete_reset();
        assert!(env.notifier.is_reset());
    }
5407
    /// Resets the server while disconnected after a reserved channel was used.
    /// First the reserved channel is still open at the reset; then, after a
    /// reconnect, the reserved channel is closed before the reset.
    #[test]
    fn test_disconnected_reset() {
        let mut env = TestEnv::new();

        let offer_id1 = env.offer(1);

        env.connect(Version::Win10, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        env.gpadl(1, 10);
        env.c().gpadl_create_complete(offer_id1, GpadlId(10), 0);
        env.open_reserved(1, 1, SINT.into());
        env.c().open_complete(offer_id1, 0);

        env.c().handle_unload();

        // Reset with reserved channel 1 still open. The test completes the
        // channel close and GPADL teardown before finishing the reset.
        env.c().reset();
        env.c().close_complete(offer_id1);
        env.c().gpadl_teardown_complete(offer_id1, GpadlId(10));

        env.complete_reset();
        assert!(env.notifier.is_reset());

        // Second round: new offer, reconnect, reserved open, then unload.
        let offer_id2 = env.offer(2);

        env.notifier.messages.clear();
        env.connect(Version::Win10, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        env.gpadl(2, 20);
        env.c().gpadl_create_complete(offer_id2, GpadlId(20), 0);
        env.open_reserved(2, 2, SINT.into());
        env.c().open_complete(offer_id2, 0);

        env.c().handle_unload();

        // Close the reserved channel while disconnected.
        env.close_reserved(2, 2, SINT.into());
        env.c().close_complete(offer_id2);
        env.c().gpadl_teardown_complete(offer_id2, GpadlId(20));

        // With the reserved channel already closed, the test expects the reset
        // to finish without a `complete_reset` call.
        env.c().reset();
        assert!(env.notifier.is_reset());
    }
5453
    /// Checks the monitor IDs reported in offers for three cases: no MNF, a
    /// dynamically allocated monitor, and a preset monitor ID.
    #[test]
    fn test_mnf_channel() {
        let mut env = TestEnv::new();

        let _offer_id1 = env.offer(1);
        let _offer_id2 = env.offer_with_mnf(2);
        // Channel 3 is offered with preset monitor ID 5.
        let _offer_id3 = env.offer_with_preset_mnf(3, 5);

        env.connect(Version::Copper, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        // Only monitor 0 (allocated to channel 2) is set in the bitmap; the
        // preset ID 5 is not.
        assert_eq!(env.server.assigned_monitors.bitmap(), 1);

        // Channel 1 reports no monitor (0xff). Channel 2 reports allocated
        // monitor 0, and channel 3 reports its preset monitor 5.
        env.notifier.check_messages(&[
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 1,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 1,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(1),
                connection_id: 0x2001,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 2,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 2,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(2),
                connection_id: 0x2002,
                is_dedicated: 1,
                monitor_id: 0,
                monitor_allocated: 1,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 3,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 3,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(3),
                connection_id: 0x2003,
                is_dedicated: 1,
                monitor_id: 5,
                monitor_allocated: 1,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::AllOffersDelivered {}),
        ])
    }
5521
    /// Checks the order in which channel IDs (and connection IDs) are assigned.
    ///
    /// The expected messages are sorted by interface ID. Within interface 5,
    /// offers with an explicit order come first (ascending), followed by the
    /// rest in instance-ID order.
    /// NOTE(review): assumes `offer()` sets no explicit order; confirm against
    /// the helper.
    #[test]
    fn test_channel_id_order() {
        let mut env = TestEnv::new();

        let _offer_id1 = env.offer(3);
        let _offer_id2 = env.offer(10);
        let _offer_id3 = env.offer(5);
        let _offer_id4 = env.offer(17);
        // Three more offers with interface 5: (instance, order) = (6, 2), (8, 1), (1, none).
        let _offer_id5 = env.offer_with_order(5, 6, Some(2));
        let _offer_id6 = env.offer_with_order(5, 8, Some(1));
        let _offer_id7 = env.offer_with_order(5, 1, None);

        env.connect(Version::Win10, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        // Channel IDs 1..=7 and connection IDs 0x2001..=0x2007 follow the
        // sorted order.
        env.notifier.check_messages(&[
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 3,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 3,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(1),
                connection_id: 0x2001,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 5,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 8,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(2),
                connection_id: 0x2002,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 5,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 6,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(3),
                connection_id: 0x2003,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 5,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 1,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(4),
                connection_id: 0x2004,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 5,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 5,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(5),
                connection_id: 0x2005,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 10,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 10,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(6),
                connection_id: 0x2006,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 17,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 17,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(7),
                connection_id: 0x2007,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::AllOffersDelivered {}),
        ])
    }
5646
5647 #[test]
5648 fn test_confidential_connection() {
5649 let mut env = TestEnv::new();
5650 env.connect_trusted(
5651 Version::Copper,
5652 FeatureFlags::new().with_confidential_channels(true),
5653 );
5654
5655 assert_eq!(
5656 env.version.unwrap(),
5657 VersionInfo {
5658 version: Version::Copper,
5659 feature_flags: FeatureFlags::new().with_confidential_channels(true)
5660 }
5661 );
5662
5663 env.offer(1); env.offer_with_flags(2, OfferFlags::new().with_confidential_ring_buffer(true));
5665 env.offer_with_flags(
5666 3,
5667 OfferFlags::new()
5668 .with_confidential_ring_buffer(true)
5669 .with_confidential_external_memory(true),
5670 );
5671
5672 let error = env
5674 .try_send_message(in_msg(
5675 protocol::MessageType::REQUEST_OFFERS,
5676 protocol::RequestOffers {},
5677 ))
5678 .unwrap_err();
5679
5680 assert!(matches!(error, ChannelError::UntrustedMessage));
5681 assert!(env.notifier.messages.is_empty());
5682
5683 env.send_message(in_msg_ex(
5685 protocol::MessageType::REQUEST_OFFERS,
5686 protocol::RequestOffers {},
5687 false,
5688 true,
5689 ));
5690
5691 let offer = env.notifier.get_message::<protocol::OfferChannel>();
5692 assert_eq!(offer.channel_id, ChannelId(1));
5693 assert_eq!(offer.flags, OfferFlags::new());
5694
5695 let offer = env.notifier.get_message::<protocol::OfferChannel>();
5696 assert_eq!(offer.channel_id, ChannelId(2));
5697 assert_eq!(
5698 offer.flags,
5699 OfferFlags::new().with_confidential_ring_buffer(true)
5700 );
5701
5702 let offer = env.notifier.get_message::<protocol::OfferChannel>();
5703 assert_eq!(offer.channel_id, ChannelId(3));
5704 assert_eq!(
5705 offer.flags,
5706 OfferFlags::new()
5707 .with_confidential_ring_buffer(true)
5708 .with_confidential_external_memory(true)
5709 );
5710
5711 env.notifier
5712 .check_message(OutgoingMessage::new(&protocol::AllOffersDelivered {}));
5713 }
5714
5715 #[test]
5716 fn test_confidential_channels_unsupported() {
5717 let mut env = TestEnv::new();
5718
5719 env.connect_trusted(Version::Copper, FeatureFlags::new());
5722
5723 assert_eq!(
5724 env.version.unwrap(),
5725 VersionInfo {
5726 version: Version::Copper,
5727 feature_flags: FeatureFlags::new()
5728 }
5729 );
5730
5731 env.offer_with_flags(1, OfferFlags::new().with_enumerate_device_interface(true)); env.offer_with_flags(
5733 2,
5734 OfferFlags::new()
5735 .with_named_pipe_mode(true)
5736 .with_confidential_ring_buffer(true)
5737 .with_confidential_external_memory(true),
5738 );
5739
5740 env.send_message(in_msg_ex(
5741 protocol::MessageType::REQUEST_OFFERS,
5742 protocol::RequestOffers {},
5743 false,
5744 true,
5745 ));
5746
5747 let offer = env.notifier.get_message::<protocol::OfferChannel>();
5748 assert_eq!(offer.channel_id, ChannelId(1));
5749 assert_eq!(
5750 offer.flags,
5751 OfferFlags::new().with_enumerate_device_interface(true)
5752 );
5753
5754 let offer = env.notifier.get_message::<protocol::OfferChannel>();
5756 assert_eq!(offer.channel_id, ChannelId(2));
5757 assert_eq!(offer.flags, OfferFlags::new().with_named_pipe_mode(true));
5758
5759 env.notifier
5760 .check_message(OutgoingMessage::new(&protocol::AllOffersDelivered {}));
5761 }
5762
5763 #[test]
5764 fn test_confidential_channels_untrusted() {
5765 let mut env = TestEnv::new();
5766
5767 env.connect(
5768 Version::Copper,
5769 FeatureFlags::new().with_confidential_channels(true),
5770 );
5771
5772 assert_eq!(
5775 env.version.unwrap(),
5776 VersionInfo {
5777 version: Version::Copper,
5778 feature_flags: FeatureFlags::new()
5779 }
5780 );
5781 }
5782}