1pub mod saved_state;
5#[cfg(test)]
6mod tests;
7
8use crate::Guid;
9use crate::SynicMessage;
10use crate::monitor::AssignedMonitors;
11use crate::protocol::Version;
12use hvdef::Vtl;
13use inspect::Inspect;
14pub use saved_state::RestoreError;
15pub use saved_state::SavedState;
16pub use saved_state::SavedStateData;
17use slab::Slab;
18use std::cmp::min;
19use std::collections::VecDeque;
20use std::collections::hash_map::Entry;
21use std::collections::hash_map::HashMap;
22use std::fmt::Display;
23use std::ops::Index;
24use std::ops::IndexMut;
25use std::task::Poll;
26use std::task::ready;
27use std::time::Duration;
28use thiserror::Error;
29use vmbus_channel::bus::ChannelType;
30use vmbus_channel::bus::GpadlRequest;
31use vmbus_channel::bus::OfferKey;
32use vmbus_channel::bus::OfferParams;
33use vmbus_channel::bus::OpenData;
34use vmbus_channel::bus::RestoredGpadl;
35use vmbus_core::HvsockConnectRequest;
36use vmbus_core::HvsockConnectResult;
37use vmbus_core::MaxVersionInfo;
38use vmbus_core::OutgoingMessage;
39use vmbus_core::VMBUS_SINT;
40use vmbus_core::VersionInfo;
41use vmbus_core::protocol;
42use vmbus_core::protocol::ChannelId;
43use vmbus_core::protocol::ConnectionId;
44use vmbus_core::protocol::FeatureFlags;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::Message;
47use vmbus_core::protocol::OfferFlags;
48use vmbus_core::protocol::UserDefinedData;
49use vmbus_ring::gparange;
50use vmcore::monitor::MonitorId;
51use vmcore::synic::MonitorInfo;
52use vmcore::synic::MonitorPageGpas;
53use zerocopy::FromZeros;
54use zerocopy::Immutable;
55use zerocopy::IntoBytes;
56use zerocopy::KnownLayout;
57
/// Errors returned while processing a channel-related request from the guest.
///
/// The `#[error]` message on each variant is the user-visible description.
#[derive(Debug, Error)]
pub enum ChannelError {
    /// No offer is assigned to the supplied channel ID.
    #[error("unknown channel ID")]
    UnknownChannelId,
    #[error("unknown GPADL ID")]
    UnknownGpadlId,
    /// Wraps a message parse failure.
    #[error("parse error")]
    ParseError(#[from] protocol::ParseError),
    /// The GPA range list of a GPADL failed validation.
    #[error("invalid gpa range")]
    InvalidGpaRange(#[source] gparange::Error),
    #[error("duplicate GPADL ID")]
    DuplicateGpadlId,
    /// More GPADL data arrived after the GPADL stopped accepting data.
    #[error("GPADL is already complete")]
    GpadlAlreadyComplete,
    #[error("GPADL channel ID mismatch")]
    WrongGpadlChannelId,
    #[error("trying to open an open channel")]
    ChannelAlreadyOpen,
    #[error("trying to close a closed channel")]
    ChannelNotOpen,
    #[error("invalid GPADL state for operation")]
    InvalidGpadlState,
    #[error("invalid channel state for operation")]
    InvalidChannelState,
    /// The channel ID refers to a channel in one of the client-released states.
    #[error("channel ID has already been released")]
    ChannelReleased,
    #[error("channel offers have already been sent")]
    OffersAlreadySent,
    #[error("invalid operation on reserved channel")]
    ChannelReserved,
    #[error("invalid operation on non-reserved channel")]
    ChannelNotReserved,
    #[error("received untrusted message for trusted connection")]
    UntrustedMessage,
    #[error("received a non-resuming message while paused")]
    Paused,
    #[error("invalid target VP")]
    InvalidTargetVp,
    #[error("interrupts are disabled for this channel")]
    InterruptsDisabled,
}
100
/// Errors returned when offering (or re-offering) a channel.
#[derive(Debug, Error)]
pub enum OfferError {
    /// The channel ID is outside the assignable range.
    #[error("the channel ID {} is not valid for this operation", (.0).0)]
    InvalidChannelId(ChannelId),
    /// Another offer already holds this channel ID.
    #[error("the channel ID {} is already in use", (.0).0)]
    ChannelIdInUse(ChannelId),
    #[error("offer {0} already exists")]
    AlreadyExists(OfferKey),
    #[error("specified resources do not match those of the existing saved or revoked offer")]
    IncompatibleResources,
    #[error("too many channels have been offered")]
    TooManyChannels,
    /// The monitor ID recorded in saved state does not match the one now in use.
    #[error("mismatched monitor ID from saved state; expected {0:?}, actual {1:?}")]
    MismatchedMonitorId(Option<MonitorId>, MonitorId),
}
116
/// Identifies a channel offer. The wrapped value is the key of the offer's entry in the
/// `ChannelList` slab.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OfferId(usize);
120
/// Maps a GPADL ID to the offer it belongs to. Judging by the name, this tracks GPADLs whose
/// construction has not finished (NOTE(review): confirm against the GPADL message handlers).
type IncompleteGpadlMap = HashMap<GpadlId, OfferId>;

/// All known GPADLs, keyed by `(GPADL ID, owning offer)`. The key type allows the same GPADL ID
/// to exist under different offers.
type GpadlMap = HashMap<(GpadlId, OfferId), Gpadl>;
124
/// VMBus server state: the guest connection, the channel offers, and their GPADLs.
pub struct Server {
    /// Current state of the guest connection.
    state: ConnectionState,
    /// All channel offers, indexed by [`OfferId`].
    channels: ChannelList,
    /// Channel-ID-to-offer assignments.
    assigned_channels: AssignedChannels,
    /// Monitor IDs assigned to channels that use MNF.
    assigned_monitors: AssignedMonitors,
    /// GPADLs keyed by GPADL ID and owning offer.
    gpadls: GpadlMap,
    incomplete_gpadls: IncompleteGpadlMap,
    // NOTE(review): usage is outside this view; presumably a connection ID used for child
    // messages — confirm.
    child_connection_id: u32,
    /// Version limit applied to connections, if any.
    max_version: Option<MaxVersionInfo>,
    /// Version limit stored by `set_compatibility_version` when called with `delay`.
    delayed_max_version: Option<MaxVersionInfo>,
    pending_messages: PendingMessages,
    /// Set via `set_require_server_allocated_mnf`.
    require_server_allocated_mnf: bool,
    use_absolute_channel_order: bool,
}
145
/// A [`Server`] borrowed together with a notifier.
///
/// The server's channel invariants are checked (in debug builds) when this is created by
/// `Server::with_notifier` and again when it is dropped.
pub struct ServerWithNotifier<'a, T> {
    inner: &'a mut Server,
    notifier: &'a mut T,
}
150
151impl<T> Drop for ServerWithNotifier<'_, T> {
152 fn drop(&mut self) {
153 self.inner.validate();
154 }
155}
156
157impl<T: Notifier> Inspect for ServerWithNotifier<'_, T> {
158 fn inspect(&self, req: inspect::Request<'_>) {
159 let mut resp = req.respond();
160 let (state, info, next_action) = match &self.inner.state {
161 ConnectionState::Disconnected => ("disconnected", None, None),
162 ConnectionState::Connecting { info, .. } => ("connecting", Some(info), None),
163 ConnectionState::Connected(info) => (
164 if info.offers_sent {
165 "connected"
166 } else {
167 "negotiated"
168 },
169 Some(info),
170 None,
171 ),
172 ConnectionState::Disconnecting { next_action, .. } => {
173 ("disconnecting", None, Some(next_action))
174 }
175 };
176
177 resp.field("connection_info", info);
178 let next_action = next_action.map(|a| match a {
179 ConnectionAction::None => "disconnect",
180 ConnectionAction::Reset => "reset",
181 ConnectionAction::SendUnloadComplete => "unload",
182 ConnectionAction::Reconnect { .. } => "reconnect",
183 ConnectionAction::SendFailedVersionResponse => "send_version_response",
184 });
185 resp.field("state", state)
186 .field("next_action", next_action)
187 .field(
188 "assigned_monitors_bitmap",
189 format_args!("{:x}", self.inner.assigned_monitors.bitmap()),
190 )
191 .child("channels", |req| {
192 let mut resp = req.respond();
193 self.inner
194 .channels
195 .inspect(self.notifier, self.inner.get_version(), &mut resp);
196 for ((gpadl_id, offer_id), gpadl) in &self.inner.gpadls {
197 let channel = &self.inner.channels[*offer_id];
198 resp.field(
199 &channel_inspect_path(
200 &channel.offer,
201 format_args!("/gpadls/{}", gpadl_id.0),
202 ),
203 gpadl,
204 );
205 }
206 });
207 }
208}
209
/// Monitor page GPAs together with who allocated them.
#[derive(Debug, Copy, Clone, Inspect)]
struct MonitorPageGpaInfo {
    gpas: MonitorPageGpas,
    /// True when created by `from_server_gpas`, false when created by `from_guest_gpas`.
    server_allocated: bool,
}
216
217impl MonitorPageGpaInfo {
218 fn from_guest_gpas(gpas: MonitorPageGpas) -> Self {
220 Self {
221 gpas,
222 server_allocated: false,
223 }
224 }
225
226 fn from_server_gpas(gpas: MonitorPageGpas) -> Self {
228 Self {
229 gpas,
230 server_allocated: true,
231 }
232 }
233}
234
/// Parameters of a guest connection.
#[derive(Debug, Copy, Clone, Inspect)]
struct ConnectionInfo {
    /// Negotiated protocol version and feature flags.
    version: VersionInfo,
    /// Reported by `ConnectionState::is_trusted`.
    trusted: bool,
    /// Distinguishes "connected" (true) from "negotiated" (false) in inspect output.
    offers_sent: bool,
    interrupt_page: Option<u64>,
    monitor_page: Option<MonitorPageGpaInfo>,
    target_message_vp: u32,
    modifying: bool,
    client_id: Guid,
    /// Reported by `ConnectionState::is_paused`.
    paused: bool,
}
249
/// State of the guest connection.
#[derive(Debug)]
enum ConnectionState {
    Disconnected,
    /// A disconnect is in progress; `next_action` runs afterwards.
    Disconnecting {
        next_action: ConnectionAction,
        // NOTE(review): presumably records whether a modify-connection request was issued
        // during the disconnect — confirm.
        modify_sent: bool,
    },
    Connecting {
        info: ConnectionInfo,
        next_action: ConnectionAction,
    },
    Connected(ConnectionInfo),
}
264
265impl ConnectionState {
266 fn check_version(&self, min_version: Version) -> bool {
268 matches!(self, ConnectionState::Connected(info) if info.version.version >= min_version)
269 }
270
271 fn check_feature_flags(&self, flags: impl Fn(FeatureFlags) -> bool) -> bool {
274 matches!(self, ConnectionState::Connected(info) if flags(info.version.feature_flags))
275 }
276
277 fn get_version(&self) -> Option<VersionInfo> {
278 if let ConnectionState::Connected(info) = self {
279 Some(info.version)
280 } else {
281 None
282 }
283 }
284
285 fn get_connected_info(&self) -> Option<&ConnectionInfo> {
287 if let ConnectionState::Connected(info) = self {
288 Some(info)
289 } else {
290 None
291 }
292 }
293
294 fn is_trusted(&self) -> bool {
295 match self {
296 ConnectionState::Connected(info) => info.trusted,
297 ConnectionState::Connecting { info, .. } => info.trusted,
298 _ => false,
299 }
300 }
301
302 fn is_paused(&self) -> bool {
303 if let ConnectionState::Connected(info) = self {
304 info.paused
305 } else {
306 false
307 }
308 }
309}
310
/// Action queued to run after a disconnect.
///
/// Shown in inspect output as "disconnect", "reset", "unload", "reconnect" and
/// "send_version_response", respectively.
#[derive(Debug, Copy, Clone)]
enum ConnectionAction {
    None,
    Reset,
    SendUnloadComplete,
    Reconnect {
        initiate_contact: InitiateContactRequest,
    },
    SendFailedVersionResponse,
}
321
/// Monitor page information carried by an initiate-contact request.
// NOTE(review): `Invalid` presumably marks a malformed monitor page specification — confirm
// against the message parser.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum MonitorPageRequest {
    None,
    Some(MonitorPageGpas),
    Invalid,
}
328
/// A guest request to establish a connection.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub struct InitiateContactRequest {
    pub version_requested: u32,
    pub target_message_vp: u32,
    pub monitor_page: MonitorPageRequest,
    pub target_sint: u8,
    pub target_vtl: u8,
    /// Raw feature flag bits requested by the guest.
    pub feature_flags: u32,
    pub interrupt_page: Option<u64>,
    pub client_id: Guid,
    pub trusted: bool,
}
341
/// A guest request to open a channel.
#[derive(Debug, Copy, Clone)]
pub struct OpenRequest {
    pub open_id: u32,
    pub ring_buffer_gpadl_id: GpadlId,
    pub target_vp: Option<u32>,
    pub downstream_ring_buffer_page_offset: u32,
    pub user_data: UserDefinedData,
    /// When set, overrides the default event flag and connection ID used for signaling
    /// (see `OpenParams::from_request`).
    pub guest_specified_interrupt_info: Option<SignalInfo>,
    pub flags: protocol::OpenChannelFlags,
}
352
/// A three-way change to a setting: leave it alone, clear it, or set a new value.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Update<T: std::fmt::Debug + Copy + Clone> {
    Unchanged,
    Reset,
    Set(T),
}

impl<T: std::fmt::Debug + Copy + Clone> From<Option<T>> for Update<T> {
    /// `Some(v)` becomes `Set(v)`; `None` becomes `Reset` (never `Unchanged`).
    fn from(value: Option<T>) -> Self {
        value.map_or(Self::Reset, Self::Set)
    }
}
368
/// Requested changes to connection-level parameters.
///
/// `Default` leaves everything unchanged and sets `notify_relay` to true.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ModifyConnectionRequest {
    pub version: Option<VersionInfo>,
    pub monitor_page: Update<MonitorPageGpas>,
    pub interrupt_page: Update<u64>,
    pub target_message_vp: Option<u32>,
    pub notify_relay: bool,
}
377
378impl Default for ModifyConnectionRequest {
380 fn default() -> Self {
381 Self {
382 version: None,
383 monitor_page: Update::Unchanged,
384 interrupt_page: Update::Unchanged,
385 target_message_vp: None,
386 notify_relay: true,
387 }
388 }
389}
390
391impl From<protocol::ModifyConnection> for ModifyConnectionRequest {
392 fn from(value: protocol::ModifyConnection) -> Self {
393 let monitor_page = if value.parent_to_child_monitor_page_gpa != 0 {
394 Update::Set(MonitorPageGpas {
395 parent_to_child: value.parent_to_child_monitor_page_gpa,
396 child_to_parent: value.child_to_parent_monitor_page_gpa,
397 })
398 } else {
399 Update::Reset
400 };
401
402 Self {
403 monitor_page,
404 ..Default::default()
405 }
406 }
407}
408
/// Outcome of a connection modification.
// NOTE(review): the exact meaning of each variant is determined by callers outside this view;
// `Supported` carries a connection state, feature flags, and optional monitor page GPAs.
#[derive(Debug, Copy, Clone)]
pub enum ModifyConnectionResponse {
    Supported(
        protocol::ConnectionState,
        FeatureFlags,
        Option<MonitorPageGpas>,
    ),
    Unsupported,
    Modified(protocol::ConnectionState),
}
428
/// Whether an open channel has a modification in flight.
#[derive(Debug, Copy, Clone)]
pub enum ModifyState {
    NotModifying,
    Modifying { pending_target_vp: Option<u32> },
}

impl ModifyState {
    /// Returns true for any `Modifying` state, regardless of the pending target VP.
    pub fn is_modifying(&self) -> bool {
        match self {
            ModifyState::Modifying { .. } => true,
            ModifyState::NotModifying => false,
        }
    }
}
440
/// Guest-specified signaling parameters for a channel.
#[derive(Debug, Copy, Clone)]
pub struct SignalInfo {
    pub event_flag: u16,
    pub connection_id: u32,
}
446
/// Where a channel stands relative to a saved-state restore.
// NOTE(review): variant semantics are defined by the restore code outside this view.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum RestoreState {
    New,
    Restoring,
    Unmatched,
    Restored,
}
461
/// Lifecycle state of a single channel.
///
/// `is_released`, `is_revoked`, and `is_reserved` classify these states.
#[derive(Debug, Clone)]
enum ChannelState {
    /// A released state (see `is_released`).
    ClientReleased,

    Closed,

    /// `request` is the open request being processed.
    Opening {
        request: OpenRequest,
        reserved_state: Option<ReservedState>,
    },

    /// `params` is the request the channel was opened with.
    Open {
        params: OpenRequest,
        modify_state: ModifyState,
        reserved_state: Option<ReservedState>,
    },

    Closing {
        params: OpenRequest,
        reserved_state: Option<ReservedState>,
    },

    /// Closing, with a further open request (`request`) held alongside the current `params`.
    ClosingReopen {
        params: OpenRequest,
        request: OpenRequest,
    },

    /// A revoked state (see `is_revoked`).
    Revoked,

    /// A revoked state (see `is_revoked`).
    Reoffered,

    /// A released state (see `is_released`).
    ClosingClientRelease,

    /// A released state (see `is_released`).
    OpeningClientRelease,
}
515
516impl ChannelState {
517 fn is_released(&self) -> bool {
520 match self {
521 ChannelState::Closed
522 | ChannelState::Opening { .. }
523 | ChannelState::Open { .. }
524 | ChannelState::Closing { .. }
525 | ChannelState::ClosingReopen { .. }
526 | ChannelState::Revoked
527 | ChannelState::Reoffered => false,
528
529 ChannelState::ClientReleased
530 | ChannelState::ClosingClientRelease
531 | ChannelState::OpeningClientRelease => true,
532 }
533 }
534
535 fn is_revoked(&self) -> bool {
537 match self {
538 ChannelState::Revoked | ChannelState::Reoffered => true,
539
540 ChannelState::ClientReleased
541 | ChannelState::Closed
542 | ChannelState::Opening { .. }
543 | ChannelState::Open { .. }
544 | ChannelState::Closing { .. }
545 | ChannelState::ClosingReopen { .. }
546 | ChannelState::ClosingClientRelease
547 | ChannelState::OpeningClientRelease => false,
548 }
549 }
550
551 fn is_reserved(&self) -> bool {
552 match self {
553 ChannelState::Open {
555 reserved_state: Some(_),
556 ..
557 }
558 | ChannelState::Opening {
559 reserved_state: Some(_),
560 ..
561 }
562 | ChannelState::Closing {
563 reserved_state: Some(_),
564 ..
565 } => true,
566
567 ChannelState::Opening { .. }
568 | ChannelState::Open { .. }
569 | ChannelState::Closing { .. }
570 | ChannelState::ClientReleased
571 | ChannelState::Closed
572 | ChannelState::ClosingReopen { .. }
573 | ChannelState::Revoked
574 | ChannelState::Reoffered
575 | ChannelState::ClosingClientRelease
576 | ChannelState::OpeningClientRelease => false,
577 }
578 }
579}
580
581impl Display for ChannelState {
582 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583 let state = match self {
584 Self::ClientReleased => "ClientReleased",
585 Self::Closed => "Closed",
586 Self::Opening { .. } => "Opening",
587 Self::Open { .. } => "Open",
588 Self::Closing { .. } => "Closing",
589 Self::ClosingReopen { .. } => "ClosingReopen",
590 Self::Revoked => "Revoked",
591 Self::Reoffered => "Reoffered",
592 Self::ClosingClientRelease => "ClosingClientRelease",
593 Self::OpeningClientRelease => "OpeningClientRelease",
594 };
595 write!(f, "{}", state)
596 }
597}
598
/// How a channel uses monitored notifications (MNF).
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub enum MnfUsage {
    /// The default.
    #[default]
    Disabled,
    /// The server assigns a monitor ID itself (see `Channel::prepare_channel`).
    Enabled { latency: Duration },
    /// The given monitor ID is used as-is and is not released by this server.
    Relayed { monitor_id: u8 },
}
611
612impl MnfUsage {
613 pub fn is_enabled(&self) -> bool {
614 matches!(self, Self::Enabled { .. })
615 }
616
617 pub fn is_relayed(&self) -> bool {
618 matches!(self, Self::Relayed { .. })
619 }
620
621 pub fn enabled_and_then<T>(&self, f: impl FnOnce(Duration) -> Option<T>) -> Option<T> {
622 if let Self::Enabled { latency } = self {
623 f(*latency)
624 } else {
625 None
626 }
627 }
628}
629
630impl From<Option<Duration>> for MnfUsage {
631 fn from(value: Option<Duration>) -> Self {
632 match value {
633 None => Self::Disabled,
634 Some(latency) => Self::Enabled { latency },
635 }
636 }
637}
638
/// The server's internal description of a channel offer.
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub struct OfferParamsInternal {
    pub interface_name: String,
    /// Together with `interface_id` and `subchannel_index`, forms the offer's `OfferKey`.
    pub instance_id: Guid,
    pub interface_id: Guid,
    pub mmio_megabytes: u16,
    pub mmio_megabytes_optional: u16,
    pub subchannel_index: u16,
    pub use_mnf: MnfUsage,
    pub offer_order: Option<u64>,
    pub flags: OfferFlags,
    pub user_defined: UserDefinedData,
}
653
654impl OfferParamsInternal {
655 pub fn key(&self) -> OfferKey {
657 OfferKey {
658 interface_id: self.interface_id,
659 instance_id: self.instance_id,
660 subchannel_index: self.subchannel_index,
661 }
662 }
663}
664
665impl From<OfferParams> for OfferParamsInternal {
666 fn from(value: OfferParams) -> Self {
667 let mut user_defined = UserDefinedData::new_zeroed();
668
669 let mut flags = OfferFlags::new()
672 .with_confidential_ring_buffer(true)
673 .with_confidential_external_memory(value.allow_confidential_external_memory);
674
675 match value.channel_type {
676 ChannelType::Device { pipe_packets } => {
677 if pipe_packets {
678 flags.set_named_pipe_mode(true);
679 user_defined.as_pipe_params_mut().pipe_type = protocol::PipeType::MESSAGE;
680 }
681 }
682 ChannelType::Interface {
683 user_defined: interface_user_defined,
684 } => {
685 flags.set_enumerate_device_interface(true);
686 user_defined = interface_user_defined;
687 }
688 ChannelType::Pipe { message_mode } => {
689 flags.set_enumerate_device_interface(true);
690 flags.set_named_pipe_mode(true);
691 user_defined.as_pipe_params_mut().pipe_type = if message_mode {
692 protocol::PipeType::MESSAGE
693 } else {
694 protocol::PipeType::BYTE
695 };
696 }
697 ChannelType::HvSocket {
698 is_connect,
699 is_for_container,
700 silo_id,
701 } => {
702 flags.set_enumerate_device_interface(true);
703 flags.set_tlnpi_provider(true);
704 flags.set_named_pipe_mode(true);
705 *user_defined.as_hvsock_params_mut() = protocol::HvsockUserDefinedParameters::new(
706 is_connect,
707 is_for_container,
708 silo_id,
709 );
710 }
711 };
712
713 Self {
714 interface_name: value.interface_name,
715 instance_id: value.instance_id,
716 interface_id: value.interface_id,
717 mmio_megabytes: value.mmio_megabytes,
718 mmio_megabytes_optional: value.mmio_megabytes_optional,
719 subchannel_index: value.subchannel_index,
720 use_mnf: value.mnf_interrupt_latency.into(),
721 offer_order: value.offer_order,
722 user_defined,
723 flags,
724 }
725 }
726}
727
/// A VP and SINT pair that messages can be directed to.
#[derive(Debug, Copy, Clone, Inspect, PartialEq, Eq)]
pub struct ConnectionTarget {
    pub vp: u32,
    pub sint: u8,
}
733
/// Destination for an outgoing message passed to `Notifier::send_message`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MessageTarget {
    Default,
    /// Built by `for_offer` from a channel's reserved state.
    ReservedChannel(OfferId, ConnectionTarget),
    Custom(ConnectionTarget),
}
740
741impl MessageTarget {
742 pub fn for_offer(offer_id: OfferId, reserved_state: &Option<ReservedState>) -> Self {
743 if let Some(state) = reserved_state {
744 Self::ReservedChannel(offer_id, state.target)
745 } else {
746 Self::Default
747 }
748 }
749}
750
/// Extra state carried by a reserved channel.
#[derive(Debug, Copy, Clone)]
pub struct ReservedState {
    version: VersionInfo,
    /// Where messages for this channel are sent (see `MessageTarget::for_offer`).
    target: ConnectionTarget,
}
756
/// A channel offer and its current state.
#[derive(Debug)]
struct Channel {
    /// Channel/connection/monitor IDs. `Server::validate` requires this to be `Some` exactly
    /// when the channel is not released.
    info: Option<OfferedInfo>,
    offer: OfferParamsInternal,
    state: ChannelState,
    restore_state: RestoreState,
}
765
/// IDs assigned to a channel by `Channel::prepare_channel`.
#[derive(Debug, Copy, Clone)]
struct OfferedInfo {
    channel_id: ChannelId,
    /// Derived from `channel_id`, the VTL, and `VMBUS_SINT`.
    connection_id: u32,
    monitor_id: Option<MonitorId>,
}
772
773impl Channel {
774 fn inspect_state(&self, resp: &mut inspect::Response<'_>) {
775 let mut target_vp = None;
776 let mut event_flag = None;
777 let mut connection_id = None;
778 let mut reserved_target = None;
779 let state = match &self.state {
780 ChannelState::ClientReleased => "client_released",
781 ChannelState::Closed => "closed",
782 ChannelState::Opening { reserved_state, .. } => {
783 reserved_target = reserved_state.map(|state| state.target);
784 "opening"
785 }
786 ChannelState::Open {
787 params,
788 reserved_state,
789 ..
790 } => {
791 target_vp = Some(params.target_vp);
792 if let Some(id) = params.guest_specified_interrupt_info {
793 event_flag = Some(id.event_flag);
794 connection_id = Some(id.connection_id);
795 }
796 reserved_target = reserved_state.map(|state| state.target);
797 "open"
798 }
799 ChannelState::Closing { reserved_state, .. } => {
800 reserved_target = reserved_state.map(|state| state.target);
801 "closing"
802 }
803 ChannelState::ClosingReopen { .. } => "closing_reopen",
804 ChannelState::Revoked => "revoked",
805 ChannelState::Reoffered => "reoffered",
806 ChannelState::ClosingClientRelease => "closing_client_release",
807 ChannelState::OpeningClientRelease => "opening_client_release",
808 };
809 let restore_state = match self.restore_state {
810 RestoreState::New => "new",
811 RestoreState::Restoring => "restoring",
812 RestoreState::Restored => "restored",
813 RestoreState::Unmatched => "unmatched",
814 };
815 if let Some(info) = &self.info {
816 resp.field("channel_id", info.channel_id.0)
817 .field("offered_connection_id", info.connection_id)
818 .field("monitor_id", info.monitor_id.map(|id| id.0));
819 }
820 resp.field("state", state)
821 .field("restore_state", restore_state)
822 .field("interface_name", self.offer.interface_name.clone())
823 .display("instance_id", &self.offer.instance_id)
824 .display("interface_id", &self.offer.interface_id)
825 .field("mmio_megabytes", self.offer.mmio_megabytes)
826 .field("target_vp", target_vp)
827 .field("guest_specified_event_flag", event_flag)
828 .field("guest_specified_connection_id", connection_id)
829 .field("reserved_connection_target", reserved_target)
830 .binary("offer_flags", self.offer.flags.into_bits());
831 }
832
833 fn handled_monitor_info(&self) -> Option<MonitorInfo> {
843 self.offer.use_mnf.enabled_and_then(|latency| {
844 if self.state.is_reserved() {
845 None
846 } else {
847 self.info.and_then(|info| {
848 info.monitor_id.map(|monitor_id| MonitorInfo {
849 monitor_id,
850 latency,
851 })
852 })
853 }
854 })
855 }
856
857 fn prepare_channel(
860 &mut self,
861 offer_id: OfferId,
862 assigned_channels: &mut AssignedChannels,
863 assigned_monitors: &mut AssignedMonitors,
864 ) {
865 assert!(self.info.is_none());
866
867 let entry = assigned_channels
869 .allocate()
870 .expect("there are enough channel IDs for everything in ChannelList");
871
872 let channel_id = entry.id();
873 entry.insert(offer_id);
874 let connection_id = ConnectionId::new(channel_id.0, assigned_channels.vtl, VMBUS_SINT);
875
876 let monitor_id = match self.offer.use_mnf {
881 MnfUsage::Enabled { .. } => {
882 let monitor_id = assigned_monitors.assign_monitor();
883 if monitor_id.is_none() {
884 tracelimit::warn_ratelimited!("Out of monitor IDs.");
885 }
886
887 monitor_id
888 }
889 MnfUsage::Relayed { monitor_id } => Some(MonitorId(monitor_id)),
890 MnfUsage::Disabled => None,
891 };
892
893 self.info = Some(OfferedInfo {
894 channel_id,
895 connection_id: connection_id.0,
896 monitor_id,
897 });
898 }
899
900 fn release_channel(
902 &mut self,
903 offer_id: OfferId,
904 assigned_channels: &mut AssignedChannels,
905 assigned_monitors: &mut AssignedMonitors,
906 ) {
907 if let Some(info) = self.info.take() {
908 assigned_channels.free(info.channel_id, offer_id);
909
910 if let Some(monitor_id) = info.monitor_id {
912 if self.offer.use_mnf.is_enabled() {
913 assigned_monitors.release_monitor(monitor_id);
914 }
915 }
916 }
917 }
918}
919
/// Table of channel ID assignments. Slot `i` holds the offer using channel ID `i + 1`.
#[derive(Debug)]
struct AssignedChannels {
    /// One entry per channel ID, `MAX_CHANNELS` long.
    assignments: Vec<Option<OfferId>>,
    /// Used when computing connection IDs in `Channel::prepare_channel`.
    vtl: Vtl,
    /// Slots below this index are never handed out by `allocate`.
    reserved_offset: usize,
    /// Number of occupied slots below `reserved_offset`.
    count_in_reserved_range: usize,
}
928
929impl AssignedChannels {
930 fn new(vtl: Vtl, channel_id_offset: u16) -> Self {
931 Self {
932 assignments: vec![None; MAX_CHANNELS],
933 vtl,
934 reserved_offset: channel_id_offset as usize,
935 count_in_reserved_range: 0,
936 }
937 }
938
939 fn allowable_channel_count(&self) -> usize {
940 MAX_CHANNELS - self.reserved_offset + self.count_in_reserved_range
941 }
942
943 fn get(&self, channel_id: ChannelId) -> Option<OfferId> {
944 self.assignments
945 .get(Self::index(channel_id))
946 .copied()
947 .flatten()
948 }
949
950 fn set(&mut self, channel_id: ChannelId) -> Result<AssignmentEntry<'_>, OfferError> {
951 let index = Self::index(channel_id);
952 if self
953 .assignments
954 .get(index)
955 .ok_or(OfferError::InvalidChannelId(channel_id))?
956 .is_some()
957 {
958 return Err(OfferError::ChannelIdInUse(channel_id));
959 }
960 Ok(AssignmentEntry { list: self, index })
961 }
962
963 fn allocate(&mut self) -> Option<AssignmentEntry<'_>> {
964 let index = self.reserved_offset
965 + self.assignments[self.reserved_offset..]
966 .iter()
967 .position(|x| x.is_none())?;
968 Some(AssignmentEntry { list: self, index })
969 }
970
971 fn free(&mut self, channel_id: ChannelId, offer_id: OfferId) {
972 let index = Self::index(channel_id);
973 let slot = &mut self.assignments[index];
974 assert_eq!(slot.take(), Some(offer_id));
975 if index < self.reserved_offset {
976 self.count_in_reserved_range -= 1;
977 }
978 }
979
980 fn index(channel_id: ChannelId) -> usize {
981 channel_id.0.wrapping_sub(1) as usize
982 }
983}
984
/// A vacant slot in [`AssignedChannels`], returned by `set`/`allocate` and filled by `insert`.
struct AssignmentEntry<'a> {
    list: &'a mut AssignedChannels,
    /// Slot index; the channel ID is `index + 1`.
    index: usize,
}
989
990impl AssignmentEntry<'_> {
991 pub fn id(&self) -> ChannelId {
992 ChannelId(self.index as u32 + 1)
993 }
994
995 pub fn insert(self, offer_id: OfferId) {
996 assert!(
997 self.list.assignments[self.index]
998 .replace(offer_id)
999 .is_none()
1000 );
1001
1002 if self.index < self.list.reserved_offset {
1003 self.list.count_in_reserved_range += 1;
1004 }
1005 }
1006}
1007
/// All channel offers. Each offer's slab key is its [`OfferId`].
struct ChannelList {
    channels: Slab<Channel>,
}
1011
1012fn channel_inspect_path(offer: &OfferParamsInternal, suffix: std::fmt::Arguments<'_>) -> String {
1013 if offer.subchannel_index == 0 {
1014 format!("{}{}", offer.instance_id, suffix)
1015 } else {
1016 format!(
1017 "{}/subchannels/{}{}",
1018 offer.instance_id, offer.subchannel_index, suffix
1019 )
1020 }
1021}
1022
1023impl ChannelList {
1024 fn inspect(
1025 &self,
1026 notifier: &impl Notifier,
1027 version: Option<VersionInfo>,
1028 resp: &mut inspect::Response<'_>,
1029 ) {
1030 for (offer_id, channel) in self.iter() {
1031 resp.child(
1032 &channel_inspect_path(&channel.offer, format_args!("")),
1033 |req| {
1034 let mut resp = req.respond();
1035 channel.inspect_state(&mut resp);
1036
1037 resp.merge(inspect::adhoc(|req| {
1041 if !matches!(channel.state, ChannelState::Revoked) {
1042 notifier.inspect(version, offer_id, req);
1043 }
1044 }));
1045 },
1046 );
1047 }
1048 }
1049}
1050
/// Number of channel ID slots. `AssignedChannels` covers channel IDs `1..=MAX_CHANNELS`.
pub const MAX_CHANNELS: usize = 2047;
1054
1055impl ChannelList {
1056 fn new() -> Self {
1057 Self {
1058 channels: Slab::new(),
1059 }
1060 }
1061
1062 fn len(&self) -> usize {
1064 self.channels.len()
1065 }
1066
1067 fn offer(&mut self, new_channel: Channel) -> OfferId {
1069 OfferId(self.channels.insert(new_channel))
1070 }
1071
1072 fn remove(&mut self, offer_id: OfferId) {
1074 let channel = self.channels.remove(offer_id.0);
1075 assert!(channel.info.is_none());
1076 }
1077
1078 fn get_by_channel_id_mut(
1080 &mut self,
1081 assigned_channels: &AssignedChannels,
1082 channel_id: ChannelId,
1083 ) -> Result<(OfferId, &mut Channel), ChannelError> {
1084 let offer_id = assigned_channels
1085 .get(channel_id)
1086 .ok_or(ChannelError::UnknownChannelId)?;
1087 let channel = &mut self[offer_id];
1088 if channel.state.is_released() {
1089 return Err(ChannelError::ChannelReleased);
1090 }
1091 assert_eq!(
1092 channel.info.as_ref().map(|info| info.channel_id),
1093 Some(channel_id)
1094 );
1095 Ok((offer_id, channel))
1096 }
1097
1098 fn get_by_channel_id(
1100 &self,
1101 assigned_channels: &AssignedChannels,
1102 channel_id: ChannelId,
1103 ) -> Result<(OfferId, &Channel), ChannelError> {
1104 let offer_id = assigned_channels
1105 .get(channel_id)
1106 .ok_or(ChannelError::UnknownChannelId)?;
1107 let channel = &self[offer_id];
1108 if channel.state.is_released() {
1109 return Err(ChannelError::ChannelReleased);
1110 }
1111 assert_eq!(
1112 channel.info.as_ref().map(|info| info.channel_id),
1113 Some(channel_id)
1114 );
1115 Ok((offer_id, channel))
1116 }
1117
1118 fn get_by_key_mut(&mut self, key: &OfferKey) -> Option<(OfferId, &mut Channel)> {
1121 for (offer_id, channel) in self.iter_mut() {
1122 if channel.offer.instance_id == key.instance_id
1123 && channel.offer.interface_id == key.interface_id
1124 && channel.offer.subchannel_index == key.subchannel_index
1125 {
1126 return Some((offer_id, channel));
1127 }
1128 }
1129 None
1130 }
1131
1132 fn iter(&self) -> impl Iterator<Item = (OfferId, &Channel)> {
1134 self.channels
1135 .iter()
1136 .map(|(id, channel)| (OfferId(id), channel))
1137 }
1138
1139 fn iter_mut(&mut self) -> impl Iterator<Item = (OfferId, &mut Channel)> {
1141 self.channels
1142 .iter_mut()
1143 .map(|(id, channel)| (OfferId(id), channel))
1144 }
1145
1146 fn retain<F>(&mut self, mut f: F)
1148 where
1149 F: FnMut(OfferId, &mut Channel) -> bool,
1150 {
1151 self.channels.retain(|id, channel| {
1152 let retain = f(OfferId(id), channel);
1153 if !retain {
1154 assert!(channel.info.is_none());
1155 }
1156 retain
1157 })
1158 }
1159}
1160
1161impl Index<OfferId> for ChannelList {
1162 type Output = Channel;
1163
1164 fn index(&self, offer_id: OfferId) -> &Self::Output {
1165 &self.channels[offer_id.0]
1166 }
1167}
1168
1169impl IndexMut<OfferId> for ChannelList {
1170 fn index_mut(&mut self, offer_id: OfferId) -> &mut Self::Output {
1171 &mut self.channels[offer_id.0]
1172 }
1173}
1174
/// A GPADL being assembled from, or assembled out of, guest messages.
#[derive(Debug, Inspect)]
struct Gpadl {
    /// Range count passed to `gparange::validate_gpa_ranges` once the buffer is full.
    count: u16,
    /// GPA range data; its capacity is the expected total length in `u64`s.
    #[inspect(skip)]
    buf: Vec<u64>,
    state: GpadlState,
}
1183
/// Progress of a GPADL.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Inspect)]
enum GpadlState {
    /// Still accepting data through `Gpadl::append`.
    InProgress,
    /// All data received and validated (set by `Gpadl::append`).
    Offered,
    // NOTE(review): the remaining states are driven by code outside this view.
    OfferedTearingDown,
    Accepted,
    TearingDown,
}
1198
1199impl Gpadl {
1200 fn new(count: u16, len: usize) -> Self {
1203 Self {
1204 state: GpadlState::InProgress,
1205 count,
1206 buf: Vec::with_capacity(len),
1207 }
1208 }
1209
1210 fn append(&mut self, data: &[u8]) -> Result<bool, ChannelError> {
1212 if self.state == GpadlState::InProgress {
1213 let buf = &mut self.buf;
1214 let len = min(data.len() & !7, (buf.capacity() - buf.len()) * 8);
1219 let data = &data[..len];
1220 let start = buf.len();
1221 buf.resize(buf.len() + data.len() / 8, 0);
1222 buf[start..].as_mut_bytes().copy_from_slice(data);
1223 Ok(if buf.len() == buf.capacity() {
1224 gparange::validate_gpa_ranges(self.count as usize, buf)
1225 .map_err(ChannelError::InvalidGpaRange)?;
1226 self.state = GpadlState::Offered;
1227 true
1228 } else {
1229 false
1230 })
1231 } else {
1232 Err(ChannelError::GpadlAlreadyComplete)
1233 }
1234 }
1235}
1236
/// Parameters describing a channel open, built by `OpenParams::from_request`.
#[derive(Debug, Copy, Clone)]
pub struct OpenParams {
    pub open_data: OpenData,
    /// Guest-specified connection ID, or the offered connection ID.
    pub connection_id: u32,
    /// Guest-specified event flag, or the channel ID truncated to `u16`.
    pub event_flag: u16,
    pub monitor_info: Option<MonitorInfo>,
    /// The request's flags with the `unused` field zeroed.
    pub flags: protocol::OpenChannelFlags,
    pub reserved_target: Option<ConnectionTarget>,
    pub channel_id: ChannelId,
}
1248
1249impl OpenParams {
1250 fn from_request(
1251 info: &OfferedInfo,
1252 request: &OpenRequest,
1253 monitor_info: Option<MonitorInfo>,
1254 reserved_target: Option<ConnectionTarget>,
1255 ) -> Self {
1256 let (event_flag, connection_id) = if let Some(id) = request.guest_specified_interrupt_info {
1259 (id.event_flag, id.connection_id)
1260 } else {
1261 (info.channel_id.0 as u16, info.connection_id)
1262 };
1263
1264 Self {
1265 open_data: OpenData {
1266 target_vp: request.target_vp,
1267 ring_offset: request.downstream_ring_buffer_page_offset,
1268 ring_gpadl_id: request.ring_buffer_gpadl_id,
1269 user_data: request.user_data,
1270 event_flag,
1271 connection_id,
1272 },
1273 connection_id,
1274 event_flag,
1275 monitor_info,
1276 flags: request.flags.with_unused(0),
1277 reserved_target,
1278 channel_id: info.channel_id,
1279 }
1280 }
1281}
1282
/// An action on a channel reported through `Notifier::notify`.
#[derive(Debug)]
pub enum Action {
    Open(OpenParams, VersionInfo),
    Close,
    /// Presumably the GPADL ID, its range count, and its range buffer (matching the fields
    /// of `Gpadl`) — confirm.
    Gpadl(GpadlId, u16, Vec<u64>),
    TeardownGpadl {
        gpadl_id: GpadlId,
        post_restore: bool,
    },
    Modify {
        target_vp: u32,
    },
}
1297
/// Protocol versions supported by this server.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::Win7,
    Version::Win8,
    Version::Win8_1,
    Version::Win10,
    Version::Win10Rs3_0,
    Version::Win10Rs3_1,
    Version::Win10Rs4,
    Version::Win10Rs5,
    Version::Iron,
    Version::Copper,
];
1312
/// Feature flags supported by this server.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true)
    .with_server_specified_monitor_pages(true);
1322
/// Callbacks the server uses to report channel actions and to send messages.
pub trait Notifier: Send {
    /// Reports `action` for the channel identified by `offer_id`.
    fn notify(&mut self, offer_id: OfferId, action: Action);

    /// Receives an initiate-contact request that the server passes on instead of handling.
    fn forward_unhandled(&mut self, request: InitiateContactRequest);

    /// Applies the connection-level changes described by `request`.
    ///
    /// # Errors
    ///
    /// Returns an error if the modification could not be applied.
    fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()>;

    /// Adds implementor-specific inspection data for a channel. The default implementation
    /// ignores its arguments.
    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
        let _ = (version, offer_id, req);
    }

    /// Sends `message` to `target`. The result must not be ignored; presumably `false` means
    /// the message could not be sent right now (NOTE(review): confirm against implementors).
    #[must_use]
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool;

    /// Delivers an hvsocket connect request.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest);

    /// Signals that a reset has completed.
    fn reset_complete(&mut self);

    /// Signals that an unload has completed.
    fn unload_complete(&mut self);
}
1357
impl Server {
    /// Creates a server with no channels and no guest connection.
    ///
    /// `vtl` and `channel_id_offset` configure channel ID assignment. `child_connection_id` is
    /// reported in the version response to guests that negotiate `Version::Win10Rs3_1` or
    /// later. `use_absolute_channel_order` is stored for use elsewhere.
    pub fn new(
        vtl: Vtl,
        child_connection_id: u32,
        channel_id_offset: u16,
        use_absolute_channel_order: bool,
    ) -> Self {
        Server {
            state: ConnectionState::Disconnected,
            channels: ChannelList::new(),
            assigned_channels: AssignedChannels::new(vtl, channel_id_offset),
            assigned_monitors: AssignedMonitors::new(),
            gpadls: Default::default(),
            incomplete_gpadls: Default::default(),
            child_connection_id,
            max_version: None,
            delayed_max_version: None,
            pending_messages: PendingMessages(VecDeque::new()),
            require_server_allocated_mnf: false,
            use_absolute_channel_order,
        }
    }

    /// Pairs the server with `notifier` for operations that must signal the owner.
    ///
    /// Channel invariants are checked first (only in debug builds; see `validate`).
    pub fn with_notifier<'a, T: Notifier>(
        &'a mut self,
        notifier: &'a mut T,
    ) -> ServerWithNotifier<'a, T> {
        self.validate();
        ServerWithNotifier {
            inner: self,
            notifier,
        }
    }

    /// When `require` is true, monitor pages supplied by the guest in `InitiateContact` are
    /// ignored. Monitor pages are then used only if the relay supplies server-allocated ones.
    pub fn set_require_server_allocated_mnf(&mut self, require: bool) {
        self.require_server_allocated_mnf = require;
    }

    /// Debug-build consistency check: a channel has `info` assigned exactly when its state is
    /// not released. Without `debug_assertions` the loop is compiled out.
    fn validate(&self) {
        #[cfg(debug_assertions)]
        for (_, channel) in self.channels.iter() {
            let should_have_info = !channel.state.is_released();
            if channel.info.is_some() != should_have_info {
                panic!("channel invariant violation: {channel:?}");
            }
        }
    }

    /// Caps the protocol version and feature flags guests may negotiate.
    ///
    /// With `delay` set, the cap is only stored in `delayed_max_version` and does not take
    /// effect in this call. Where it is applied later is outside this section.
    pub fn set_compatibility_version(&mut self, version: MaxVersionInfo, delay: bool) {
        if delay {
            self.delayed_max_version = Some(version)
        } else {
            tracing::info!(?version, "Limiting VmBus connections to version");
            self.max_version = Some(version);
        }
    }

    /// Returns the GPADLs belonging to `offer_id`, for handing back to the device on restore.
    ///
    /// Accepted GPADLs are reported with `accepted: true`. GPADLs that were offered but not yet
    /// accepted (including those whose teardown was requested first) are reported with
    /// `accepted: false`. GPADLs still being built or being torn down are omitted.
    pub fn channel_gpadls(&self, offer_id: OfferId) -> Vec<RestoredGpadl> {
        self.gpadls
            .iter()
            .filter_map(|(&(gpadl_id, gpadl_offer_id), gpadl)| {
                if offer_id != gpadl_offer_id {
                    return None;
                }
                let accepted = match gpadl.state {
                    GpadlState::Offered | GpadlState::OfferedTearingDown => false,
                    GpadlState::Accepted => true,
                    GpadlState::InProgress | GpadlState::TearingDown => return None,
                };
                Some(RestoredGpadl {
                    request: GpadlRequest {
                        id: gpadl_id,
                        count: gpadl.count,
                        buf: gpadl.buf.clone(),
                    },
                    accepted,
                })
            })
            .collect()
    }

    /// Returns the version negotiated on the current connection, if the connection state
    /// reports one.
    pub fn get_version(&self) -> Option<VersionInfo> {
        self.state.get_version()
    }

    /// Returns the open parameters of a saved channel that the device is restoring as open.
    ///
    /// # Errors
    /// - [`RestoreError::MissingChannel`] if the channel was not in the saved state, has no
    ///   assigned info, or was released/reoffered.
    /// - [`RestoreError::AlreadyRestored`] if `restore_channel` already handled it.
    /// - [`RestoreError::MismatchedOpenState`] if the saved channel was closed.
    pub fn get_restore_open_params(&self, offer_id: OfferId) -> Result<OpenParams, RestoreError> {
        let channel = &self.channels[offer_id];

        match channel.restore_state {
            // The channel is not part of the saved state, so it cannot have been open.
            RestoreState::New => {
                return Err(RestoreError::MissingChannel(channel.offer.key()));
            }
            RestoreState::Restoring => {}
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        let (request, reserved_state) = match channel.state {
            ChannelState::Closed => {
                return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
            }
            // A close is in flight; the reserved target is not carried over.
            ChannelState::Closing { params, .. } | ChannelState::ClosingReopen { params, .. } => {
                (params, None)
            }
            ChannelState::Opening {
                request,
                reserved_state,
            } => (request, reserved_state),
            ChannelState::Open {
                params,
                reserved_state,
                ..
            } => (params, reserved_state),
            ChannelState::ClientReleased | ChannelState::Reoffered => {
                return Err(RestoreError::MissingChannel(channel.offer.key()));
            }
            ChannelState::Revoked
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        };

        Ok(OpenParams::from_request(
            &info,
            &request,
            channel.handled_monitor_info(),
            reserved_state.map(|state| state.target),
        ))
    }

    /// Returns `true` when outgoing messages are queued and the connection is not paused.
    pub fn has_pending_messages(&self) -> bool {
        !self.pending_messages.0.is_empty() && !self.state.is_paused()
    }

    /// Sends queued messages in order through `send`.
    ///
    /// A message is removed from the queue only after `send` returns `Ready`. If `send` returns
    /// `Pending`, this returns `Pending` and the message stays queued. Nothing is sent while the
    /// connection is paused.
    pub fn poll_flush_pending_messages(
        &mut self,
        mut send: impl FnMut(&OutgoingMessage) -> Poll<()>,
    ) -> Poll<()> {
        if !self.state.is_paused() {
            while let Some(message) = self.pending_messages.0.front() {
                ready!(send(message));
                self.pending_messages.0.pop_front();
            }
        }

        Poll::Ready(())
    }
}
1522
1523impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
    /// Restores a channel from the saved state after the device has re-offered it.
    ///
    /// `open` states whether the device reports the channel as open. The saved channel state is
    /// reconciled with that:
    /// - open: a pending open is completed successfully towards the guest, and a pending close
    ///   is re-sent to the device as [`Action::Close`];
    /// - closed: a pending close is treated as finished, and a pending open (or a reopen queued
    ///   behind a close) is re-sent to the device as [`Action::Open`].
    ///
    /// A channel that was not in the saved state (`RestoreState::New`) is accepted only if it
    /// is closed. The channel's saved monitor ID, if any, is claimed again.
    ///
    /// # Errors
    /// [`RestoreError::MissingChannel`], [`RestoreError::AlreadyRestored`],
    /// [`RestoreError::MismatchedOpenState`], or [`RestoreError::DuplicateMonitorId`].
    pub fn restore_channel(&mut self, offer_id: OfferId, open: bool) -> Result<(), RestoreError> {
        let channel = &mut self.inner.channels[offer_id];

        match channel.restore_state {
            // Not part of the saved state: only acceptable if the device reports it closed.
            RestoreState::New => {
                if open {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                } else {
                    return Ok(());
                }
            }
            RestoreState::Restoring => {}
            // `offer_channel` moves matched channels to `Restoring`; presumably callers cannot
            // obtain the offer ID of a channel that was never matched.
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        // Claim the saved monitor ID again; fails if another restored channel already holds it.
        if let Some(monitor_info) = channel.handled_monitor_info() {
            if !self
                .inner
                .assigned_monitors
                .claim_monitor(monitor_info.monitor_id)
            {
                return Err(RestoreError::DuplicateMonitorId(monitor_info.monitor_id.0));
            }
        }

        if open {
            match channel.state {
                ChannelState::Closed => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                // The device reports the channel open, so the close must be issued again.
                ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                    self.notifier.notify(offer_id, Action::Close);
                }
                // The device finished opening: report success to the guest.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused())
                        .send_open_result(
                            info.channel_id,
                            &request,
                            protocol::STATUS_SUCCESS,
                            MessageTarget::for_offer(offer_id, &reserved_state),
                        );
                    channel.state = ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    };
                }
                ChannelState::Open { .. } => {}
                ChannelState::ClientReleased | ChannelState::Reoffered => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            };
        } else {
            match channel.state {
                ChannelState::Closed => {}
                // Waiting to be offered again; there is no open state to restore.
                ChannelState::Reoffered => {}
                // The device reports it closed, so the close has completed.
                ChannelState::Closing { .. } => {
                    channel.state = ChannelState::Closed;
                }
                // The close has completed; issue the reopen that was queued behind it.
                ChannelState::ClosingReopen { request, .. } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                None,
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                    channel.state = ChannelState::Opening {
                        request,
                        reserved_state: None,
                    };
                }
                // The device has not opened it yet; re-send the open request.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                reserved_state.map(|state| state.target),
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                }
                ChannelState::Open { .. } => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                ChannelState::ClientReleased => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            }
        }

        channel.restore_state = RestoreState::Restored;
        Ok(())
    }
1662
    /// Completes a restore by handling every channel and GPADL that was not explicitly restored.
    ///
    /// - Channels restored through `restore_channel` are left alone.
    /// - New channels (not in the saved state) that are still `ClientReleased` are offered now,
    ///   if the guest is connected and has already received offers.
    /// - Saved channels that the device re-offered but did not restore are revoked and marked
    ///   `Reoffered`.
    /// - Saved channels the device never re-offered are revoked.
    ///
    /// Then GPADL work still owed to the device is re-sent: offered GPADLs as [`Action::Gpadl`],
    /// and GPADLs being torn down as [`Action::TeardownGpadl`] with `post_restore: true`.
    /// Finally, a pending disconnect is completed if that is now possible.
    pub fn revoke_unclaimed_channels(&mut self) {
        for (offer_id, channel) in self.inner.channels.iter_mut() {
            match channel.restore_state {
                RestoreState::Restored => {
                    // Already handled by `restore_channel`.
                }
                RestoreState::New => {
                    if let ConnectionState::Connected(info) = &self.inner.state {
                        if info.offers_sent && matches!(channel.state, ChannelState::ClientReleased)
                        {
                            channel.prepare_channel(
                                offer_id,
                                &mut self.inner.assigned_channels,
                                &mut self.inner.assigned_monitors,
                            );
                            channel.state = ChannelState::Closed;
                            self.inner
                                .pending_messages
                                .sender(self.notifier, self.inner.state.is_paused())
                                .send_offer(channel, info);
                        }
                    }
                }
                RestoreState::Restoring => {
                    // Re-offered by the device but not restored: revoke the guest's copy and
                    // mark it `Reoffered` so it is offered again (presumably once the guest
                    // releases the revoked channel).
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                    channel.state = ChannelState::Reoffered;
                }
                RestoreState::Unmatched => {
                    // In the saved state but never re-offered by any device: just revoke it.
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                }
            }
        }

        // Re-send GPADL operations the device had not acknowledged when the state was saved.
        for (&(gpadl_id, offer_id), gpadl) in self.inner.gpadls.iter_mut() {
            match gpadl.state {
                GpadlState::InProgress | GpadlState::Accepted => {}
                GpadlState::Offered => {
                    self.notifier.notify(
                        offer_id,
                        Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
                    );
                }
                GpadlState::TearingDown => {
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: true,
                        },
                    );
                }
                GpadlState::OfferedTearingDown => unreachable!(),
            }
        }

        self.check_disconnected();
    }
1748
1749 pub fn reset(&mut self) {
1754 assert!(!self.is_resetting());
1755 if self.request_disconnect(ConnectionAction::Reset) {
1756 self.complete_reset();
1757 }
1758 }
1759
1760 fn complete_reset(&mut self) {
1761 for (_, channel) in self.inner.channels.iter_mut() {
1763 channel.restore_state = RestoreState::New;
1764 }
1765 self.inner.pending_messages.0.clear();
1766 self.notifier.reset_complete();
1767 }
1768
1769 pub fn offer_channel(&mut self, offer: OfferParamsInternal) -> Result<OfferId, OfferError> {
1771 if let Some((offer_id, channel)) = self.inner.channels.get_by_key_mut(&offer.key()) {
1773 if channel.restore_state != RestoreState::Unmatched
1777 && !matches!(channel.state, ChannelState::Revoked)
1778 {
1779 return Err(OfferError::AlreadyExists(offer.key()));
1780 }
1781
1782 let info = channel.info.expect("assigned");
1783 if channel.restore_state == RestoreState::Unmatched {
1784 tracing::debug!(
1785 offer_id = offer_id.0,
1786 key = %channel.offer.key(),
1787 "matched channel"
1788 );
1789
1790 assert!(!matches!(channel.state, ChannelState::Revoked));
1791 channel.restore_state = RestoreState::Restoring;
1795
1796 if let MnfUsage::Relayed { monitor_id } = offer.use_mnf {
1799 if info.monitor_id != Some(MonitorId(monitor_id)) {
1800 return Err(OfferError::MismatchedMonitorId(
1801 info.monitor_id,
1802 MonitorId(monitor_id),
1803 ));
1804 }
1805 }
1806 } else {
1807 channel.state = ChannelState::Reoffered;
1811 tracing::info!(?offer_id, key = %channel.offer.key(), "channel marked for reoffer");
1812 }
1813
1814 channel.offer = offer;
1815 return Ok(offer_id);
1816 }
1817
1818 let mut connected_info = None;
1819 let state = match &self.inner.state {
1820 ConnectionState::Connected(info) => {
1821 if info.offers_sent {
1822 connected_info = Some(info);
1823 ChannelState::Closed
1824 } else {
1825 ChannelState::ClientReleased
1826 }
1827 }
1828 ConnectionState::Connecting { .. }
1829 | ConnectionState::Disconnecting { .. }
1830 | ConnectionState::Disconnected => ChannelState::ClientReleased,
1831 };
1832
1833 if self.inner.channels.len() >= self.inner.assigned_channels.allowable_channel_count() {
1835 return Err(OfferError::TooManyChannels);
1836 }
1837
1838 let key = offer.key();
1839 let confidential_ring_buffer = offer.flags.confidential_ring_buffer();
1840 let confidential_external_memory = offer.flags.confidential_external_memory();
1841 let channel = Channel {
1842 info: None,
1843 offer,
1844 state,
1845 restore_state: RestoreState::New,
1846 };
1847
1848 let offer_id = self.inner.channels.offer(channel);
1849 if let Some(info) = connected_info {
1850 let channel = &mut self.inner.channels[offer_id];
1851 channel.prepare_channel(
1852 offer_id,
1853 &mut self.inner.assigned_channels,
1854 &mut self.inner.assigned_monitors,
1855 );
1856
1857 self.inner
1858 .pending_messages
1859 .sender(self.notifier, self.inner.state.is_paused())
1860 .send_offer(channel, info);
1861 }
1862
1863 tracing::info!(?offer_id, %key, confidential_ring_buffer, confidential_external_memory, "new channel");
1864 Ok(offer_id)
1865 }
1866
1867 pub fn revoke_channel(&mut self, offer_id: OfferId) {
1869 let channel = &mut self.inner.channels[offer_id];
1870 let retain = revoke(
1871 self.inner
1872 .pending_messages
1873 .sender(self.notifier, self.inner.state.is_paused()),
1874 offer_id,
1875 channel,
1876 &mut self.inner.gpadls,
1877 );
1878 if !retain {
1879 self.inner.channels.remove(offer_id);
1880 }
1881
1882 self.check_disconnected();
1883 }
1884
    /// Completes an [`Action::Open`] previously sent to the device.
    ///
    /// `result` is the device's status; non-negative means success. For a normal open, the
    /// result is forwarded to the guest and the channel becomes `Open` on success or `Closed` on
    /// failure. If the guest released the channel while the open was pending, a successful open
    /// is undone with [`Action::Close`], and a failed one simply finishes the release. Calls in
    /// any other state are logged and ignored.
    pub fn open_complete(&mut self, offer_id: OfferId, result: i32) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::debug!(offer_id = offer_id.0, key = %channel.offer.key(), result, "open complete");

        match channel.state {
            ChannelState::Opening {
                request,
                reserved_state,
            } => {
                let channel_id = channel.info.expect("assigned").channel_id;
                if result >= 0 {
                    tracelimit::info_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "opened channel"
                    );
                } else {
                    tracelimit::error_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "failed to open channel"
                    );
                }

                // Report the device's result to the guest (or to the reserved channel's target).
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_open_result(
                        channel_id,
                        &request,
                        result,
                        MessageTarget::for_offer(offer_id, &reserved_state),
                    );
                channel.state = if result >= 0 {
                    ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    }
                } else {
                    ChannelState::Closed
                };
            }
            // The guest released the channel while the open was in flight.
            ChannelState::OpeningClientRelease => {
                tracing::info!(
                    offer_id = offer_id.0,
                    key = %channel.offer.key(),
                    result,
                    "opened channel (client released)"
                );

                if result >= 0 {
                    // The device did open it, so it must be closed before the release finishes.
                    channel.state = ChannelState::ClosingClientRelease;
                    self.notifier.notify(offer_id, Action::Close);
                } else {
                    channel.state = ChannelState::ClientReleased;
                    self.check_disconnected();
                }
            }

            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::ClosingClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid open complete")
            }
        }
    }
1963
1964 fn are_channels_reset(&self, include_reserved: bool) -> bool {
1967 self.inner.gpadls.keys().all(|(_, offer_id)| {
1968 !include_reserved && self.inner.channels[*offer_id].state.is_reserved()
1969 }) && self.inner.channels.iter().all(|(_, channel)| {
1970 matches!(channel.state, ChannelState::ClientReleased)
1971 || (!include_reserved && channel.state.is_reserved())
1972 })
1973 }
1974
1975 fn check_disconnected(&mut self) {
1979 match self.inner.state {
1980 ConnectionState::Disconnecting {
1981 next_action,
1982 modify_sent: false,
1983 } => {
1984 if self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)) {
1985 self.notify_disconnect(next_action);
1986 }
1987 }
1988 ConnectionState::Disconnecting {
1989 modify_sent: true, ..
1990 }
1991 | ConnectionState::Disconnected
1992 | ConnectionState::Connected { .. }
1993 | ConnectionState::Connecting { .. } => (),
1994 }
1995 }
1996
1997 fn notify_disconnect(&mut self, next_action: ConnectionAction) {
1999 debug_assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
2001 self.inner.state = ConnectionState::Disconnecting {
2002 next_action,
2003 modify_sent: true,
2004 };
2005
2006 self.notifier
2008 .modify_connection(ModifyConnectionRequest {
2009 monitor_page: Update::Reset,
2010 interrupt_page: Update::Reset,
2011 ..Default::default()
2012 })
2013 .expect("resetting state should not fail");
2014 }
2015
2016 fn is_resetting(&self) -> bool {
2019 matches!(
2020 &self.inner.state,
2021 ConnectionState::Connecting {
2022 next_action: ConnectionAction::Reset,
2023 ..
2024 } | ConnectionState::Disconnecting {
2025 next_action: ConnectionAction::Reset,
2026 ..
2027 }
2028 )
2029 }
2030
    /// Completes an [`Action::Close`] previously sent to the device.
    ///
    /// - Reserved channel closing: while connected, the reserved channel's target is sent a
    ///   `CloseReservedChannelResponse`. Otherwise the channel is released on the guest's behalf
    ///   and removed if `client_release_channel` says it can be dropped.
    /// - Closing after the guest released the channel: finishes the release and checks whether
    ///   a pending disconnect can now complete.
    /// - Closing as part of a reopen: the channel is opened again with the queued request.
    ///
    /// Calls in any other state are logged and ignored.
    pub fn close_complete(&mut self, offer_id: OfferId) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::info!(offer_id = offer_id.0, key = %channel.offer.key(), "closed channel");
        match channel.state {
            ChannelState::Closing {
                reserved_state: Some(reserved_state),
                ..
            } => {
                channel.state = ChannelState::Closed;
                if matches!(self.inner.state, ConnectionState::Connected { .. }) {
                    let channel_id = channel.info.expect("assigned").channel_id;
                    self.send_close_reserved_channel_response(
                        channel_id,
                        offer_id,
                        reserved_state.target,
                    );
                } else {
                    // No connection to answer on, so release the channel on the guest's behalf.
                    // NOTE(review): `check_disconnected` is not called here; presumably no
                    // pending disconnect is waiting on this channel. Confirm.
                    if Self::client_release_channel(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                        &mut self.inner.incomplete_gpadls,
                        &mut self.inner.assigned_channels,
                        &mut self.inner.assigned_monitors,
                        None,
                    ) {
                        self.inner.channels.remove(offer_id);
                    }
                }
            }
            ChannelState::Closing { .. } => {
                channel.state = ChannelState::Closed;
            }
            ChannelState::ClosingClientRelease => {
                channel.state = ChannelState::ClientReleased;
                self.check_disconnected();
            }
            // The guest asked to reopen while the close was in flight; open it now.
            ChannelState::ClosingReopen { request, .. } => {
                channel.state = ChannelState::Closed;
                self.open_channel(offer_id, &request, None);
            }

            ChannelState::Closed
            | ChannelState::ClientReleased
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::OpeningClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid close complete")
            }
        }
    }
2090
2091 fn send_close_reserved_channel_response(
2092 &mut self,
2093 channel_id: ChannelId,
2094 offer_id: OfferId,
2095 target: ConnectionTarget,
2096 ) {
2097 self.sender().send_message_with_target(
2098 &protocol::CloseReservedChannelResponse { channel_id },
2099 MessageTarget::ReservedChannel(offer_id, target),
2100 );
2101 }
2102
    /// Converts a guest `InitiateContact` message into an [`InitiateContactRequest`] and
    /// processes it with `initiate_contact`.
    ///
    /// Fields are honored only for protocol versions that define them:
    /// - target SINT: multiclient and Win10 RS3.1 or later;
    /// - target VTL: multiclient and Win10 RS4 or later;
    /// - feature flags: Copper or later;
    /// - target message VP: Win8.1 or later;
    /// - interrupt page: before Win8 only.
    ///
    /// The two monitor page GPAs must be both zero or both non-zero. `includes_client_id`
    /// indicates whether the received message was long enough to carry `client_id`.
    ///
    /// # Errors
    /// [`ChannelError::ParseError`] if the guest requested the client ID feature but the message
    /// did not include a client ID.
    fn handle_initiate_contact(
        &mut self,
        input: &protocol::InitiateContact2,
        message: &SynicMessage,
        includes_client_id: bool,
    ) -> Result<(), ChannelError> {
        // Since Win8 this field carries target info rather than the interrupt page.
        let target_info =
            protocol::TargetInfo::from(input.initiate_contact.interrupt_page_or_target_info);

        let target_sint = if message.multiclient
            && input.initiate_contact.version_requested >= Version::Win10Rs3_1 as u32
        {
            target_info.sint()
        } else {
            VMBUS_SINT
        };

        let target_vtl = if message.multiclient
            && input.initiate_contact.version_requested >= Version::Win10Rs4 as u32
        {
            target_info.vtl()
        } else {
            0
        };

        let feature_flags = if input.initiate_contact.version_requested >= Version::Copper as u32 {
            target_info.feature_flags()
        } else {
            0
        };

        let target_message_vp =
            if input.initiate_contact.version_requested >= Version::Win8_1 as u32 {
                input.initiate_contact.target_message_vp
            } else {
                0
            };

        // Before Win8 the shared field is the interrupt page; zero means none.
        let interrupt_page = (input.initiate_contact.version_requested < Version::Win8 as u32
            && input.initiate_contact.interrupt_page_or_target_info != 0)
            .then_some(input.initiate_contact.interrupt_page_or_target_info);

        // Exactly one monitor page being zero is invalid.
        let monitor_page = if (input.initiate_contact.parent_to_child_monitor_page_gpa == 0)
            != (input.initiate_contact.child_to_parent_monitor_page_gpa == 0)
        {
            MonitorPageRequest::Invalid
        } else if input.initiate_contact.parent_to_child_monitor_page_gpa != 0 {
            MonitorPageRequest::Some(MonitorPageGpas {
                parent_to_child: input.initiate_contact.parent_to_child_monitor_page_gpa,
                child_to_parent: input.initiate_contact.child_to_parent_monitor_page_gpa,
            })
        } else {
            MonitorPageRequest::None
        };

        // A guest that asks for the client ID feature must also send the client ID.
        let client_id = if FeatureFlags::from(feature_flags).client_id() {
            if includes_client_id {
                input.client_id
            } else {
                return Err(ChannelError::ParseError(
                    protocol::ParseError::MessageTooSmall(Some(
                        protocol::MessageType::INITIATE_CONTACT,
                    )),
                ));
            }
        } else {
            Guid::ZERO
        };

        let request = InitiateContactRequest {
            version_requested: input.initiate_contact.version_requested,
            target_message_vp,
            monitor_page,
            target_sint,
            target_vtl,
            feature_flags,
            interrupt_page,
            client_id,
            trusted: message.trusted,
        };
        self.initiate_contact(request);
        Ok(())
    }
2202
    /// Handles a guest `InitiateContact` request.
    ///
    /// Requests targeting another VTL are forwarded to the notifier. Requests targeting a SINT
    /// other than `VMBUS_SINT` get an unsupported-version response. Otherwise the existing
    /// connection is torn down first. If that cannot finish synchronously, the request is queued
    /// and replayed once the disconnect completes.
    ///
    /// After version and monitor-page validation, the connection enters `Connecting` and the
    /// new parameters are sent to the notifier via `modify_connection`. The successful version
    /// response is sent later by `complete_initiate_contact`.
    pub fn initiate_contact(&mut self, request: InitiateContactRequest) {
        let vtl = self.inner.assigned_channels.vtl as u8;
        if request.target_vtl != vtl {
            self.notifier.forward_unhandled(request);
            return;
        }

        if request.target_sint != VMBUS_SINT {
            tracelimit::warn_ratelimited!(
                target_vtl = request.target_vtl,
                target_sint = request.target_sint,
                version = request.version_requested,
                "unsupported multiclient request",
            );

            // Reply on the SINT the guest asked about, not the default target.
            self.send_version_response_with_target(
                None,
                MessageTarget::Custom(ConnectionTarget {
                    vp: request.target_message_vp,
                    sint: request.target_sint,
                }),
            );

            return;
        }

        // Any existing connection must be fully torn down first. If it cannot be torn down
        // synchronously, this request is queued and replayed via `do_next_action`.
        if !self.request_disconnect(ConnectionAction::Reconnect {
            initiate_contact: request,
        }) {
            return;
        }

        let Some(version) = self.check_version_supported(&request) else {
            tracelimit::warn_ratelimited!(
                vtl,
                version = request.version_requested,
                client_id = ?request.client_id,
                "Guest requested unsupported version"
            );

            self.send_version_response(None);
            return;
        };

        tracelimit::info_ratelimited!(
            vtl,
            ?version,
            client_id = ?request.client_id,
            trusted = request.trusted,
            "Guest negotiated version"
        );

        let monitor_page = match request.monitor_page {
            MonitorPageRequest::Some(mp) => {
                if self.inner.require_server_allocated_mnf {
                    // Guest pages are ignored. Without server-specified page support, the guest
                    // ends up with no monitor pages at all.
                    if !version.feature_flags.server_specified_monitor_pages() {
                        tracelimit::warn_ratelimited!(
                            "guest-supplied monitor pages not supported; MNF will be disabled"
                        );
                    }

                    None
                } else {
                    Some(mp)
                }
            }
            MonitorPageRequest::None => None,
            MonitorPageRequest::Invalid => {
                self.send_version_response(Some(VersionResponseData::new(
                    version,
                    protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
                )));

                return;
            }
        };

        self.inner.state = ConnectionState::Connecting {
            info: ConnectionInfo {
                version,
                trusted: request.trusted,
                interrupt_page: request.interrupt_page,
                monitor_page: monitor_page.map(MonitorPageGpaInfo::from_guest_gpas),
                target_message_vp: request.target_message_vp,
                modifying: false,
                offers_sent: false,
                client_id: request.client_id,
                paused: false,
            },
            next_action: ConnectionAction::None,
        };

        // On success, the connection completes in `complete_initiate_contact`.
        if let Err(err) = self.notifier.modify_connection(ModifyConnectionRequest {
            version: Some(version),
            monitor_page: monitor_page.into(),
            interrupt_page: request.interrupt_page.into(),
            target_message_vp: Some(request.target_message_vp),
            notify_relay: true,
        }) {
            tracelimit::error_ratelimited!(?err, "server failed to change state");
            self.inner.state = ConnectionState::Disconnected;
            self.send_version_response(Some(VersionResponseData::new(
                version,
                protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
            )));
        }
    }
2320
    /// Finishes an `InitiateContact` once the notifier has processed the `modify_connection`
    /// request issued by `initiate_contact`.
    ///
    /// On success:
    /// - the negotiated feature flags are narrowed to those the relay reported, plus
    ///   `LOCAL_FEATURE_FLAGS`;
    /// - server-specified monitor pages, if provided, replace the guest's;
    /// - the connection becomes `Connected` and a successful version response is sent;
    /// - any action queued while connecting is then started.
    ///
    /// On failure, a failed (or unsupported) version response is sent and the connection returns
    /// to `Disconnected`.
    ///
    /// # Panics
    /// If the connection is not `Connecting`, or if `response` is
    /// `ModifyConnectionResponse::Modified`.
    pub(crate) fn complete_initiate_contact(&mut self, response: ModifyConnectionResponse) {
        let ConnectionState::Connecting {
            mut info,
            next_action,
        } = self.inner.state
        else {
            panic!("Invalid state for completing InitiateContact.");
        };

        // Flags kept regardless of what the relay reports.
        const LOCAL_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
            .with_client_id(true)
            .with_confidential_channels(true);

        let (relay_feature_flags, server_specified_monitor_page) = match response {
            ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                feature_flags,
                server_specified_monitor_page,
            ) => (feature_flags, server_specified_monitor_page),
            // The relay rejected the connection: report its state to the guest.
            ModifyConnectionResponse::Supported(
                connection_state,
                feature_flags,
                server_specified_monitor_page,
            ) => {
                tracelimit::error_ratelimited!(
                    ?connection_state,
                    "initiate contact failed because relay request failed"
                );

                info.version.feature_flags &= (feature_flags | LOCAL_FEATURE_FLAGS)
                    .with_server_specified_monitor_pages(server_specified_monitor_page.is_some());

                self.send_version_response(Some(VersionResponseData::new(
                    info.version,
                    connection_state,
                )));
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            ModifyConnectionResponse::Unsupported => {
                self.send_version_response(None);
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            ModifyConnectionResponse::Modified(_) => {
                panic!("Invalid response for completing InitiateContact.");
            }
        };

        // The relay may only supply monitor pages if the guest negotiated that feature.
        assert!(
            info.version.feature_flags.server_specified_monitor_pages()
                || server_specified_monitor_page.is_none()
        );

        info.version.feature_flags &= relay_feature_flags | LOCAL_FEATURE_FLAGS;

        // Advertise server-specified monitor pages only when the relay actually supplied them.
        if let Some(gpas) = server_specified_monitor_page {
            info.monitor_page = Some(MonitorPageGpaInfo::from_server_gpas(gpas));
            info.version
                .feature_flags
                .set_server_specified_monitor_pages(true);
        } else {
            info.version
                .feature_flags
                .set_server_specified_monitor_pages(false);
        }

        let version = info.version;
        self.inner.state = ConnectionState::Connected(info);

        self.send_version_response(Some(
            VersionResponseData::new(version, protocol::ConnectionState::SUCCESSFUL)
                .with_monitor_pages(server_specified_monitor_page),
        ));
        // Another request arrived while connecting: start it now, running it immediately if
        // the disconnect completes synchronously.
        if !matches!(next_action, ConnectionAction::None) && self.request_disconnect(next_action) {
            self.do_next_action(next_action);
        }
    }
2415
2416 fn check_version_supported(&self, request: &InitiateContactRequest) -> Option<VersionInfo> {
2418 let version = SUPPORTED_VERSIONS
2419 .iter()
2420 .find(|v| request.version_requested == **v as u32)
2421 .copied()?;
2422
2423 if let Some(max_version) = self.inner.max_version {
2425 if version as u32 > max_version.version {
2426 return None;
2427 }
2428 }
2429
2430 let supported_flags = if version >= Version::Copper {
2431 let max_supported_flags =
2433 SUPPORTED_FEATURE_FLAGS.with_confidential_channels(request.trusted);
2434
2435 if let Some(max_version) = self.inner.max_version {
2437 max_supported_flags & max_version.feature_flags
2438 } else {
2439 max_supported_flags
2440 }
2441 } else {
2442 FeatureFlags::new()
2443 };
2444
2445 let feature_flags = supported_flags & request.feature_flags.into();
2446
2447 assert!(version >= Version::Copper || feature_flags == FeatureFlags::new());
2448 if feature_flags.into_bits() != request.feature_flags {
2449 tracelimit::info_ratelimited!(
2452 supported = feature_flags.into_bits(),
2453 requested = request.feature_flags,
2454 "guest requested unsupported feature flags."
2455 );
2456 }
2457
2458 Some(VersionInfo {
2459 version,
2460 feature_flags,
2461 })
2462 }
2463
2464 fn send_version_response(&mut self, data: Option<VersionResponseData>) {
2465 self.send_version_response_with_target(data, MessageTarget::Default);
2466 }
2467
    /// Sends a version response to `target`.
    ///
    /// `None`, or a non-successful response for a version older than Win8, is sent as "version
    /// not supported" (all-zero response). Otherwise the response carries the connection state
    /// and one more value: the child connection ID for Win10 RS3.1 and later, or the selected
    /// version for older ones. Copper and later use the extended response with supported
    /// feature flags. When server-specified monitor pages are included, the further-extended
    /// form carrying their GPAs is used.
    fn send_version_response_with_target(
        &mut self,
        data: Option<VersionResponseData>,
        target: MessageTarget,
    ) {
        enum VersionResponseType {
            PreCopper,
            Copper,
            CopperWithServerMnf,
        }

        // The three wire formats nest inside one another: build the largest and send the
        // nested struct that matches the negotiated version.
        let mut response_copper_with_mnf = protocol::VersionResponse3::new_zeroed();
        let response_copper = &mut response_copper_with_mnf.version_response2;
        let response = &mut response_copper.version_response;
        let mut response_type = VersionResponseType::PreCopper;
        if let Some(data) = data {
            if data.state == protocol::ConnectionState::SUCCESSFUL
                || data.version.version >= Version::Win8
            {
                response.version_supported = 1;
                response.connection_state = data.state;
                response.selected_version_or_connection_id =
                    if data.version.version >= Version::Win10Rs3_1 {
                        self.inner.child_connection_id
                    } else {
                        data.version.version as u32
                    };

                if data.version.version >= Version::Copper {
                    response_copper.supported_features = data.version.feature_flags.into();
                    response_type = VersionResponseType::Copper;
                    if let Some(monitor_page) = data.monitor_pages {
                        assert!(data.version.feature_flags.server_specified_monitor_pages());
                        response_copper_with_mnf.child_to_parent_monitor_page_gpa =
                            monitor_page.child_to_parent;
                        response_copper_with_mnf.parent_to_child_monitor_page_gpa =
                            monitor_page.parent_to_child;
                        response_type = VersionResponseType::CopperWithServerMnf;
                    }
                }
            }
        }

        match response_type {
            VersionResponseType::PreCopper => {
                self.sender().send_message_with_target(response, target)
            }
            VersionResponseType::Copper => self
                .sender()
                .send_message_with_target(response_copper, target),
            VersionResponseType::CopperWithServerMnf => self
                .sender()
                .send_message_with_target(&response_copper_with_mnf, target),
        }
    }
2526
    /// Starts tearing down the current connection so that `new_action` can run.
    ///
    /// First, every channel is released on the guest's behalf and removed when
    /// `client_release_channel` says it can be dropped. Reserved channels are skipped unless
    /// this is a VM reset. Then, depending on the connection state:
    /// - disconnected: nothing more is needed, except that a reset must still wait until all
    ///   channels are released;
    /// - connected: a disconnect starts, and the connection-level reset goes to the notifier
    ///   immediately if all relevant channels are already released;
    /// - connecting or disconnecting: `new_action` replaces the previously queued action.
    ///
    /// Returns `true` if the connection is fully disconnected and `new_action` can be performed
    /// right away. Otherwise it runs later through `do_next_action`.
    ///
    /// # Panics
    /// If a reset is already in progress.
    fn request_disconnect(&mut self, new_action: ConnectionAction) -> bool {
        assert!(!self.is_resetting());

        let gpadls = &mut self.inner.gpadls;
        let vm_reset = matches!(new_action, ConnectionAction::Reset);
        self.inner.channels.retain(|offer_id, channel| {
            (!vm_reset && channel.state.is_reserved())
                || !Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    gpadls,
                    &mut self.inner.incomplete_gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    None,
                )
        });

        match &mut self.inner.state {
            ConnectionState::Disconnected => {
                if vm_reset {
                    // Some channels (possibly reserved ones) are still releasing; the reset
                    // completes once they are done.
                    if !self.are_channels_reset(true) {
                        self.inner.state = ConnectionState::Disconnecting {
                            next_action: ConnectionAction::Reset,
                            modify_sent: false,
                        };
                    }
                } else {
                    // Without a connection, only reserved channels can remain unreleased.
                    assert!(self.are_channels_reset(false));
                }
            }

            ConnectionState::Connected { .. } => {
                if self.are_channels_reset(vm_reset) {
                    self.notify_disconnect(new_action);
                } else {
                    self.inner.state = ConnectionState::Disconnecting {
                        next_action: new_action,
                        modify_sent: false,
                    };
                }
            }

            ConnectionState::Connecting { next_action, .. }
            | ConnectionState::Disconnecting { next_action, .. } => {
                *next_action = new_action;
            }
        }

        matches!(self.inner.state, ConnectionState::Disconnected)
    }
2589
2590 pub(crate) fn complete_disconnect(&mut self) {
2591 if let ConnectionState::Disconnecting {
2592 next_action,
2593 modify_sent,
2594 } = std::mem::replace(&mut self.inner.state, ConnectionState::Disconnected)
2595 {
2596 assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
2597 if !modify_sent {
2598 tracelimit::warn_ratelimited!("unexpected modify response");
2599 }
2600
2601 self.inner.state = ConnectionState::Disconnected;
2602 self.do_next_action(next_action);
2603 } else {
2604 unreachable!("not ready for disconnect");
2605 }
2606 }
2607
2608 fn do_next_action(&mut self, action: ConnectionAction) {
2609 match action {
2610 ConnectionAction::None => {}
2611 ConnectionAction::Reset => {
2612 self.complete_reset();
2613 }
2614 ConnectionAction::SendUnloadComplete => {
2615 self.complete_unload();
2616 }
2617 ConnectionAction::Reconnect { initiate_contact } => {
2618 self.initiate_contact(initiate_contact);
2619 }
2620 ConnectionAction::SendFailedVersionResponse => {
2621 self.send_version_response(None);
2624 }
2625 }
2626 }
2627
2628 fn handle_unload(&mut self) {
2630 tracing::debug!(
2631 vtl = self.inner.assigned_channels.vtl as u8,
2632 state = ?self.inner.state,
2633 "VmBus received unload request from guest",
2634 );
2635
2636 if self.request_disconnect(ConnectionAction::SendUnloadComplete) {
2637 self.complete_unload();
2638 }
2639 }
2640
2641 fn complete_unload(&mut self) {
2642 self.notifier.unload_complete();
2643 if let Some(version) = self.inner.delayed_max_version.take() {
2644 self.inner.set_compatibility_version(version, false);
2645 }
2646
2647 self.sender().send_message(&protocol::UnloadComplete {});
2648 tracelimit::info_ratelimited!("Vmbus disconnected");
2649 }
2650
    /// Handles a `RequestOffers` message: offers every non-reserved channel to
    /// the guest, then sends `AllOffersDelivered`.
    ///
    /// Each offered channel is prepared via `prepare_channel` and moved from
    /// `ClientReleased` to `Closed` before its offer is sent.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::OffersAlreadySent`] if offers were already sent
    /// on this connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection is not `Connected`, or if a non-reserved
    /// channel is not in the `ClientReleased` state.
    fn handle_request_offers(&mut self) -> Result<(), ChannelError> {
        let ConnectionState::Connected(info) = &mut self.inner.state else {
            unreachable!(
                "in unexpected state {:?}, should be prevented by Message::parse()",
                self.inner.state
            );
        };

        if info.offers_sent {
            return Err(ChannelError::OffersAlreadySent);
        }

        info.offers_sent = true;

        // Reserved channels are not re-offered.
        let mut sorted_channels: Vec<_> = self
            .inner
            .channels
            .iter_mut()
            .filter(|(_, channel)| !channel.state.is_reserved())
            .collect();

        // Channels without an explicit `offer_order` sort after those with one
        // (within the same interface when ordering by interface first).
        if self.inner.use_absolute_channel_order {
            sorted_channels.sort_unstable_by_key(|(_, channel)| {
                (
                    channel.offer.offer_order.unwrap_or(u64::MAX),
                    channel.offer.interface_id,
                    channel.offer.instance_id,
                )
            });
        } else {
            sorted_channels.sort_unstable_by_key(|(_, channel)| {
                (
                    channel.offer.interface_id,
                    channel.offer.offer_order.unwrap_or(u64::MAX),
                    channel.offer.instance_id,
                )
            });
        }

        for (offer_id, channel) in sorted_channels {
            assert!(matches!(channel.state, ChannelState::ClientReleased));

            // NOTE(review): presumably assigns the channel ID / monitor ID from
            // `assigned_channels` / `assigned_monitors` — confirm in `prepare_channel`.
            channel.prepare_channel(
                offer_id,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
            );

            channel.state = ChannelState::Closed;
            self.inner
                .pending_messages
                .sender(self.notifier, info.paused)
                .send_offer(channel, info);
        }
        self.sender().send_message(&protocol::AllOffersDelivered {});

        Ok(())
    }
2712
2713 #[must_use]
2716 fn gpadl_updated(
2717 mut sender: MessageSender<'_, N>,
2718 offer_id: OfferId,
2719 channel: &Channel,
2720 gpadl_id: GpadlId,
2721 gpadl: &Gpadl,
2722 ) -> bool {
2723 if channel.state.is_revoked() {
2724 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
2725 sender.send_gpadl_created(channel_id, gpadl_id, protocol::STATUS_UNSUCCESSFUL);
2726 false
2727 } else {
2728 sender.notifier.notify(
2730 offer_id,
2731 Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
2732 );
2733 true
2734 }
2735 }
2736
    /// Creates a GPADL from a guest `GpadlHeader` message.
    ///
    /// The GPADL is stored in `gpadls` under `(gpadl_id, offer_id)`. If `range`
    /// does not complete it, its ID is also recorded in `incomplete_gpadls` so
    /// later `GpadlBody` messages can find it. Otherwise it is reported via
    /// `gpadl_updated` right away, and dropped again if that returns `false`.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    ///
    /// - the channel ID is unknown,
    /// - the channel is reserved,
    /// - appending `range` fails, or
    /// - the GPADL ID is already used for this channel or by another
    ///   in-progress GPADL.
    fn handle_gpadl_header_core(
        &mut self,
        input: &protocol::GpadlHeader,
        range: &[u8],
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        if channel.state.is_reserved() {
            return Err(ChannelError::ChannelReserved);
        }

        // NOTE(review): `input.len` is presumably a byte count, so `/ 8` gives the
        // number of 64-bit entries `Gpadl::new` expects — confirm.
        let mut gpadl = Gpadl::new(input.count, input.len as usize / 8);
        // `done` is true when this header already carried the complete range.
        let done = gpadl.append(range)?;

        let gpadl = match self.inner.gpadls.entry((input.gpadl_id, offer_id)) {
            Entry::Vacant(entry) => entry.insert(gpadl),
            Entry::Occupied(_) => return Err(ChannelError::DuplicateGpadlId),
        };

        if !done {
            // `handle_gpadl_body` finds in-progress GPADLs by ID alone, so the ID
            // must be unique among in-progress GPADLs across all channels.
            match self.inner.incomplete_gpadls.entry(input.gpadl_id) {
                Entry::Vacant(entry) => {
                    entry.insert(offer_id);
                }
                Entry::Occupied(_) => {
                    // Undo the insertion into `gpadls` above before rejecting.
                    self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    tracelimit::error_ratelimited!(
                        channel_id = ?input.channel_id,
                        key = %channel.offer.key(),
                        gpadl_id = ?input.gpadl_id,
                        "duplicate in-progress gpadl ID",
                    );
                    return Err(ChannelError::DuplicateGpadlId);
                }
            }
        }

        // A complete GPADL is reported immediately; drop it if it was rejected
        // (revoked channel).
        if done
            && !Self::gpadl_updated(
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused()),
                offer_id,
                channel,
                input.gpadl_id,
                gpadl,
            )
        {
            self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
        }
        Ok(())
    }
2803
2804 fn handle_gpadl_header(&mut self, input: &protocol::GpadlHeader, range: &[u8]) {
2806 if let Err(err) = self.handle_gpadl_header_core(input, range) {
2807 tracelimit::warn_ratelimited!(
2808 err = &err as &dyn std::error::Error,
2809 channel_id = ?input.channel_id,
2810 key = %self.inner.channels.get_by_channel_id(&self.inner.assigned_channels, input.channel_id).map(|(_, c)| c.offer.key()).unwrap_or_default(),
2811 gpadl_id = ?input.gpadl_id,
2812 "error handling gpadl header"
2813 );
2814
2815 self.sender().send_gpadl_created(
2817 input.channel_id,
2818 input.gpadl_id,
2819 protocol::STATUS_UNSUCCESSFUL,
2820 );
2821 }
2822 }
2823
    /// Appends a guest `GpadlBody` fragment to an in-progress GPADL.
    ///
    /// The GPADL is found through `incomplete_gpadls`, keyed by GPADL ID only.
    /// When the fragment completes the GPADL, it is no longer tracked as
    /// in-progress and is reported via `gpadl_updated`; it is dropped if that
    /// returns `false`.
    ///
    /// If appending fails, the GPADL is discarded and the guest is sent
    /// `GpadlCreated` with `STATUS_UNSUCCESSFUL`. This case still returns `Ok`.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::UnknownGpadlId`] if no in-progress GPADL with
    /// this ID exists.
    fn handle_gpadl_body(
        &mut self,
        input: &protocol::GpadlBody,
        range: &[u8],
    ) -> Result<(), ChannelError> {
        let &offer_id = self
            .inner
            .incomplete_gpadls
            .get(&input.gpadl_id)
            .ok_or(ChannelError::UnknownGpadlId)?;
        let gpadl = self
            .inner
            .gpadls
            .get_mut(&(input.gpadl_id, offer_id))
            .ok_or(ChannelError::UnknownGpadlId)?;
        let channel = &mut self.inner.channels[offer_id];

        match gpadl.append(range) {
            Ok(done) => {
                if done {
                    // Fully received: stop tracking it as in-progress and report it.
                    self.inner.incomplete_gpadls.remove(&input.gpadl_id);
                    if !Self::gpadl_updated(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        input.gpadl_id,
                        gpadl,
                    ) {
                        self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    }
                }
            }
            Err(err) => {
                // Invalid fragment: forget the GPADL entirely and tell the guest
                // that creation failed.
                self.inner.incomplete_gpadls.remove(&input.gpadl_id);
                self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                let channel_id = channel.info.as_ref().expect("assigned").channel_id;
                tracelimit::warn_ratelimited!(
                    err = &err as &dyn std::error::Error,
                    channel_id = channel_id.0,
                    key = %channel.offer.key(),
                    gpadl_id = input.gpadl_id.0,
                    "error handling gpadl body"
                );
                self.sender().send_gpadl_created(
                    channel_id,
                    input.gpadl_id,
                    protocol::STATUS_UNSUCCESSFUL,
                );
            }
        }

        Ok(())
    }
2887
    /// Handles a `GpadlTeardown` request from the guest.
    ///
    /// Only `Accepted` GPADLs can be torn down.
    ///
    /// - For a revoked channel, the GPADL is dropped and `GpadlTorndown` is
    ///   sent immediately.
    /// - Otherwise the GPADL moves to `TearingDown` and the notifier is asked
    ///   to tear it down. The guest is answered later from
    ///   `gpadl_teardown_complete`.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    ///
    /// - the channel or GPADL ID is unknown,
    /// - the GPADL is not `Accepted`,
    /// - the channel is not assigned `input.channel_id`, or
    /// - the channel is reserved.
    fn handle_gpadl_teardown(
        &mut self,
        input: &protocol::GpadlTeardown,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        tracing::debug!(
            channel_id = input.channel_id.0,
            key = %channel.offer.key(),
            gpadl_id = input.gpadl_id.0,
            "Received GPADL teardown request"
        );

        let gpadl = self
            .inner
            .gpadls
            .get_mut(&(input.gpadl_id, offer_id))
            .ok_or(ChannelError::UnknownGpadlId)?;

        match gpadl.state {
            GpadlState::InProgress
            | GpadlState::Offered
            | GpadlState::OfferedTearingDown
            | GpadlState::TearingDown => {
                return Err(ChannelError::InvalidGpadlState);
            }
            GpadlState::Accepted => {
                // Defensive check that the channel is still assigned the ID the
                // guest used.
                if channel.info.as_ref().map(|info| info.channel_id) != Some(input.channel_id) {
                    return Err(ChannelError::WrongGpadlChannelId);
                }

                if channel.state.is_reserved() {
                    return Err(ChannelError::ChannelReserved);
                }

                if channel.state.is_revoked() {
                    tracing::trace!(
                        channel_id = input.channel_id.0,
                        key = %channel.offer.key(),
                        gpadl_id = input.gpadl_id.0,
                        "Gpadl teardown for revoked channel"
                    );

                    // The notifier is not involved for revoked channels: drop the
                    // GPADL and acknowledge immediately.
                    self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    self.sender().send_gpadl_torndown(input.gpadl_id);
                } else {
                    gpadl.state = GpadlState::TearingDown;
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id: input.gpadl_id,
                            post_restore: false,
                        },
                    );
                }
            }
        }
        Ok(())
    }
2954
2955 fn open_channel(
2958 &mut self,
2959 offer_id: OfferId,
2960 input: &OpenRequest,
2961 reserved_state: Option<ReservedState>,
2962 ) {
2963 let channel = &mut self.inner.channels[offer_id];
2964 assert!(matches!(channel.state, ChannelState::Closed));
2965
2966 channel.state = ChannelState::Opening {
2967 request: *input,
2968 reserved_state,
2969 };
2970
2971 let info = channel.info.as_ref().expect("assigned");
2974 self.notifier.notify(
2975 offer_id,
2976 Action::Open(
2977 OpenParams::from_request(
2978 info,
2979 input,
2980 channel.handled_monitor_info(),
2981 reserved_state.map(|state| state.target),
2982 ),
2983 self.inner.state.get_version().expect("must be connected"),
2984 ),
2985 );
2986 }
2987
2988 fn handle_open_channel(&mut self, input: &protocol::OpenChannel2) -> Result<(), ChannelError> {
2990 let (offer_id, channel) = self
2991 .inner
2992 .channels
2993 .get_by_channel_id_mut(&self.inner.assigned_channels, input.open_channel.channel_id)?;
2994
2995 let guest_specified_interrupt_info = self
2996 .inner
2997 .state
2998 .check_feature_flags(|ff| ff.guest_specified_signal_parameters())
2999 .then_some(SignalInfo {
3000 event_flag: input.event_flag,
3001 connection_id: input.connection_id,
3002 });
3003
3004 let flags = if self
3005 .inner
3006 .state
3007 .check_feature_flags(|ff| ff.channel_interrupt_redirection())
3008 {
3009 input.flags
3010 } else {
3011 Default::default()
3012 };
3013
3014 let request = OpenRequest {
3015 open_id: input.open_channel.open_id,
3016 ring_buffer_gpadl_id: input.open_channel.ring_buffer_gpadl_id,
3017 target_vp: protocol::vp_index_if_enabled(input.open_channel.target_vp),
3018 downstream_ring_buffer_page_offset: input
3019 .open_channel
3020 .downstream_ring_buffer_page_offset,
3021 user_data: input.open_channel.user_data,
3022 guest_specified_interrupt_info,
3023 flags,
3024 };
3025
3026 match channel.state {
3027 ChannelState::Closed => self.open_channel(offer_id, &request, None),
3028 ChannelState::Closing { params, .. } => {
3029 channel.state = ChannelState::ClosingReopen { params, request }
3033 }
3034 ChannelState::Revoked | ChannelState::Reoffered => {}
3035
3036 ChannelState::Open { .. }
3037 | ChannelState::Opening { .. }
3038 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelAlreadyOpen),
3039
3040 ChannelState::ClientReleased
3041 | ChannelState::ClosingClientRelease
3042 | ChannelState::OpeningClientRelease => unreachable!(),
3043 }
3044 Ok(())
3045 }
3046
3047 fn handle_close_channel(&mut self, input: &protocol::CloseChannel) -> Result<(), ChannelError> {
3049 let (offer_id, channel) = self
3050 .inner
3051 .channels
3052 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
3053
3054 match channel.state {
3055 ChannelState::Open {
3056 params,
3057 modify_state,
3058 reserved_state: None,
3059 } => {
3060 if modify_state.is_modifying() {
3061 tracelimit::warn_ratelimited!(
3062 key = %channel.offer.key(),
3063 ?modify_state,
3064 "Client is closing the channel with a modify in progress"
3065 )
3066 }
3067
3068 channel.state = ChannelState::Closing {
3069 params,
3070 reserved_state: None,
3071 };
3072 self.notifier.notify(offer_id, Action::Close);
3073 }
3074
3075 ChannelState::Open {
3076 reserved_state: Some(_),
3077 ..
3078 } => return Err(ChannelError::ChannelReserved),
3079
3080 ChannelState::Revoked | ChannelState::Reoffered => {}
3081
3082 ChannelState::Closed
3083 | ChannelState::Opening { .. }
3084 | ChannelState::Closing { .. }
3085 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),
3086
3087 ChannelState::ClientReleased
3088 | ChannelState::ClosingClientRelease
3089 | ChannelState::OpeningClientRelease => unreachable!(),
3090 }
3091
3092 Ok(())
3093 }
3094
3095 fn handle_open_reserved_channel(
3098 &mut self,
3099 input: &protocol::OpenReservedChannel,
3100 version: VersionInfo,
3101 ) -> Result<(), ChannelError> {
3102 let (offer_id, channel) = self
3103 .inner
3104 .channels
3105 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
3106
3107 let target = ConnectionTarget {
3108 vp: input.target_vp,
3109 sint: input.target_sint as u8,
3110 };
3111
3112 let reserved_state = Some(ReservedState { version, target });
3113
3114 let request = OpenRequest {
3115 ring_buffer_gpadl_id: input.ring_buffer_gpadl,
3116 target_vp: None,
3118 downstream_ring_buffer_page_offset: input.downstream_page_offset,
3119 open_id: 0,
3120 user_data: UserDefinedData::new_zeroed(),
3121 guest_specified_interrupt_info: None,
3122 flags: Default::default(),
3123 };
3124
3125 match channel.state {
3126 ChannelState::Closed => self.open_channel(offer_id, &request, reserved_state),
3127 ChannelState::Revoked | ChannelState::Reoffered => {}
3128
3129 ChannelState::Open { .. } | ChannelState::Opening { .. } => {
3130 return Err(ChannelError::ChannelAlreadyOpen);
3131 }
3132
3133 ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
3134 return Err(ChannelError::InvalidChannelState);
3135 }
3136
3137 ChannelState::ClientReleased
3138 | ChannelState::ClosingClientRelease
3139 | ChannelState::OpeningClientRelease => unreachable!(),
3140 }
3141 Ok(())
3142 }
3143
3144 fn handle_close_reserved_channel(
3147 &mut self,
3148 input: &protocol::CloseReservedChannel,
3149 ) -> Result<(), ChannelError> {
3150 let (offer_id, channel) = self
3151 .inner
3152 .channels
3153 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
3154
3155 match channel.state {
3156 ChannelState::Open {
3157 params,
3158 reserved_state: Some(mut resvd),
3159 ..
3160 } => {
3161 resvd.target.vp = input.target_vp;
3162 resvd.target.sint = input.target_sint as u8;
3163 channel.state = ChannelState::Closing {
3164 params,
3165 reserved_state: Some(resvd),
3166 };
3167 self.notifier.notify(offer_id, Action::Close);
3168 }
3169
3170 ChannelState::Open {
3171 reserved_state: None,
3172 ..
3173 } => return Err(ChannelError::ChannelNotReserved),
3174
3175 ChannelState::Revoked | ChannelState::Reoffered => {}
3176
3177 ChannelState::Closed
3178 | ChannelState::Opening { .. }
3179 | ChannelState::Closing { .. }
3180 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),
3181
3182 ChannelState::ClientReleased
3183 | ChannelState::ClosingClientRelease
3184 | ChannelState::OpeningClientRelease => unreachable!(),
3185 }
3186
3187 Ok(())
3188 }
3189
    /// Releases the guest's use of a channel (on `RelIdReleased`, or for every
    /// channel while disconnecting).
    ///
    /// This is an associated function that takes the individual pieces of
    /// server state it needs. That lets `request_disconnect` call it from
    /// inside `channels.retain`.
    ///
    /// Steps:
    ///
    /// 1. Every GPADL of `offer_id` is dropped, torn down, or marked for
    ///    teardown, depending on its state.
    /// 2. The channel moves to a released state.
    /// 3. `channel.release_channel` is called with `assigned_channels` and
    ///    `assigned_monitors`.
    ///
    /// If `info` is `Some` and the channel was `Reoffered`, the channel is
    /// instead reset to `Closed` and offered to the guest again right away. In
    /// that case the function returns `false` early, skipping
    /// `release_channel`.
    ///
    /// Returns `true` if the caller should remove the channel (it had been
    /// revoked), `false` otherwise.
    #[must_use]
    fn client_release_channel(
        mut sender: MessageSender<'_, N>,
        offer_id: OfferId,
        channel: &mut Channel,
        gpadls: &mut GpadlMap,
        incomplete_gpadls: &mut IncompleteGpadlMap,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
        info: Option<&ConnectionInfo>,
    ) -> bool {
        tracelimit::info_ratelimited!(?offer_id, key = %channel.offer.key(), "client released channel");
        gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
            // Only GPADLs belonging to this channel are affected.
            if gpadl_offer_id != offer_id {
                return true;
            }
            match gpadl.state {
                // Still being assembled: drop it and its in-progress entry.
                GpadlState::InProgress => {
                    incomplete_gpadls.remove(&gpadl_id);
                    false
                }
                // Creation is pending with the notifier; `gpadl_create_complete`
                // starts the teardown once creation finishes.
                GpadlState::Offered => {
                    gpadl.state = GpadlState::OfferedTearingDown;
                    true
                }
                GpadlState::Accepted => {
                    if channel.state.is_revoked() {
                        false
                    } else {
                        gpadl.state = GpadlState::TearingDown;
                        sender.notifier.notify(
                            offer_id,
                            Action::TeardownGpadl {
                                gpadl_id,
                                post_restore: false,
                            },
                        );
                        true
                    }
                }
                // Teardown already in progress.
                GpadlState::OfferedTearingDown | GpadlState::TearingDown => true,
            }
        });

        let remove = match &mut channel.state {
            ChannelState::Closed => {
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Reoffered => {
                if let Some(info) = info {
                    // Connected: offer the channel to the guest again immediately.
                    channel.state = ChannelState::Closed;
                    channel.restore_state = RestoreState::New;
                    sender.send_offer(channel, info);
                    return false;
                }
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Revoked => {
                channel.state = ChannelState::ClientReleased;
                true
            }
            ChannelState::Opening { .. } => {
                channel.state = ChannelState::OpeningClientRelease;
                false
            }
            ChannelState::Open { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                sender.notifier.notify(offer_id, Action::Close);
                false
            }
            ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                false
            }

            ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease
            | ChannelState::ClientReleased => false,
        };

        assert!(channel.state.is_released());

        channel.release_channel(offer_id, assigned_channels, assigned_monitors);
        remove
    }
3283
3284 fn handle_rel_id_released(
3286 &mut self,
3287 input: &protocol::RelIdReleased,
3288 ) -> Result<(), ChannelError> {
3289 let channel_id = input.channel_id;
3290 let (offer_id, channel) = self
3291 .inner
3292 .channels
3293 .get_by_channel_id_mut(&self.inner.assigned_channels, channel_id)?;
3294
3295 match channel.state {
3296 ChannelState::Closed
3297 | ChannelState::Revoked
3298 | ChannelState::Closing { .. }
3299 | ChannelState::Reoffered => {
3300 if Self::client_release_channel(
3301 self.inner
3302 .pending_messages
3303 .sender(self.notifier, self.inner.state.is_paused()),
3304 offer_id,
3305 channel,
3306 &mut self.inner.gpadls,
3307 &mut self.inner.incomplete_gpadls,
3308 &mut self.inner.assigned_channels,
3309 &mut self.inner.assigned_monitors,
3310 self.inner.state.get_connected_info(),
3311 ) {
3312 self.inner.channels.remove(offer_id);
3313 }
3314
3315 self.check_disconnected();
3316 }
3317
3318 ChannelState::Opening { .. }
3319 | ChannelState::Open { .. }
3320 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::InvalidChannelState),
3321
3322 ChannelState::ClientReleased
3323 | ChannelState::OpeningClientRelease
3324 | ChannelState::ClosingClientRelease => unreachable!(),
3325 }
3326 Ok(())
3327 }
3328
3329 fn handle_tl_connect_request(&mut self, request: protocol::TlConnectRequest2) {
3332 let version = self
3333 .inner
3334 .state
3335 .get_version()
3336 .expect("must be connected")
3337 .version;
3338
3339 let hosted_silo_unaware = version < Version::Win10Rs5;
3340 self.notifier
3341 .notify_hvsock(&HvsockConnectRequest::from_message(
3342 request,
3343 hosted_silo_unaware,
3344 ));
3345 }
3346
3347 pub fn send_tl_connect_result(&mut self, result: HvsockConnectResult) {
3349 if !result.success && self.inner.state.check_version(Version::Win10Rs3_0) {
3353 self.sender().send_message(&protocol::TlConnectResult {
3357 service_id: result.service_id,
3358 endpoint_id: result.endpoint_id,
3359 status: protocol::STATUS_CONNECTION_REFUSED,
3360 })
3361 }
3362 }
3363
3364 fn handle_modify_channel(
3367 &mut self,
3368 request: &protocol::ModifyChannel,
3369 ) -> Result<(), ChannelError> {
3370 let result = self.modify_channel(request);
3371 if result.is_err() {
3372 self.send_modify_channel_response(request.channel_id, protocol::STATUS_UNSUCCESSFUL);
3373 }
3374
3375 result
3376 }
3377
    /// Validates a guest `ModifyChannel` request and starts retargeting the
    /// channel's interrupt to `request.target_vp`.
    ///
    /// If a modify is already in flight:
    ///
    /// - On `Iron` or later, the new request is logged and otherwise ignored.
    /// - On older versions, the target is stored in `pending_target_vp` and
    ///   replayed by `modify_channel_complete`.
    ///
    /// # Errors
    ///
    /// - `InvalidTargetVp` if the target is `VP_INDEX_DISABLE_INTERRUPT`.
    /// - An error if the channel ID is unknown.
    /// - `InvalidChannelState` unless the channel is open and not reserved.
    /// - `InterruptsDisabled` if the channel has no target VP.
    fn modify_channel(&mut self, request: &protocol::ModifyChannel) -> Result<(), ChannelError> {
        if request.target_vp == protocol::VP_INDEX_DISABLE_INTERRUPT {
            return Err(ChannelError::InvalidTargetVp);
        }

        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, request.channel_id)?;

        let (open_request, modify_state) = match &mut channel.state {
            ChannelState::Open {
                params,
                modify_state,
                reserved_state: None,
            } => (params, modify_state),
            _ => return Err(ChannelError::InvalidChannelState),
        };

        if open_request.target_vp.is_none() {
            return Err(ChannelError::InterruptsDisabled);
        }

        if let ModifyState::Modifying { pending_target_vp } = modify_state {
            if self.inner.state.check_version(Version::Iron) {
                // Iron and later clients are expected to wait for the response.
                tracelimit::warn_ratelimited!(
                    key = %channel.offer.key(),
                    "Client sent new ModifyChannel before receiving ModifyChannelResponse."
                );
            } else {
                // Older clients get no response, so keep the latest target and
                // apply it when the current modify completes.
                *pending_target_vp = Some(request.target_vp);
            }
        } else {
            self.notifier.notify(
                offer_id,
                Action::Modify {
                    target_vp: request.target_vp,
                },
            );

            open_request.target_vp = Some(request.target_vp);
            *modify_state = ModifyState::Modifying {
                pending_target_vp: None,
            };
        }

        Ok(())
    }
3433
3434 pub fn modify_channel_complete(&mut self, offer_id: OfferId, status: i32) {
3441 let channel = &mut self.inner.channels[offer_id];
3442
3443 if let ChannelState::Open {
3444 params,
3445 modify_state: ModifyState::Modifying { pending_target_vp },
3446 reserved_state: None,
3447 } = channel.state
3448 {
3449 channel.state = ChannelState::Open {
3450 params,
3451 modify_state: ModifyState::NotModifying,
3452 reserved_state: None,
3453 };
3454
3455 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
3457 let key = channel.offer.key();
3458 self.send_modify_channel_response(channel_id, status);
3459
3460 if let Some(target_vp) = pending_target_vp {
3462 let request = protocol::ModifyChannel {
3463 channel_id,
3464 target_vp,
3465 };
3466
3467 if let Err(error) = self.handle_modify_channel(&request) {
3468 tracelimit::warn_ratelimited!(?error, %key, "Pending ModifyChannel request failed.")
3469 }
3470 }
3471 }
3472 }
3473
3474 fn send_modify_channel_response(&mut self, channel_id: ChannelId, status: i32) {
3475 if self.inner.state.check_version(Version::Iron) {
3476 self.sender()
3477 .send_message(&protocol::ModifyChannelResponse { channel_id, status });
3478 }
3479 }
3480
3481 fn handle_modify_connection(&mut self, request: protocol::ModifyConnection) {
3482 if let Err(err) = self.modify_connection(request) {
3483 tracelimit::error_ratelimited!(?err, "modifying connection failed");
3484 self.complete_modify_connection(ModifyConnectionResponse::Modified(
3485 protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
3486 ));
3487 }
3488 }
3489
3490 fn modify_connection(&mut self, request: protocol::ModifyConnection) -> anyhow::Result<()> {
3491 let ConnectionState::Connected(info) = &mut self.inner.state else {
3492 anyhow::bail!(
3493 "Invalid state for ModifyConnection request: {:?}",
3494 self.inner.state
3495 );
3496 };
3497
3498 if info.modifying {
3499 anyhow::bail!(
3500 "Duplicate ModifyConnection request, state: {:?}",
3501 self.inner.state
3502 );
3503 }
3504
3505 if matches!(
3506 info.monitor_page,
3507 Some(MonitorPageGpaInfo {
3508 server_allocated: true,
3509 ..
3510 })
3511 ) {
3512 anyhow::bail!("Cannot modify server-allocated monitor pages");
3513 }
3514
3515 if (request.child_to_parent_monitor_page_gpa == 0)
3516 != (request.parent_to_child_monitor_page_gpa == 0)
3517 {
3518 anyhow::bail!("Guest must specify either both or no monitor pages, {request:?}");
3519 }
3520
3521 let monitor_page = (request.child_to_parent_monitor_page_gpa != 0).then_some(
3522 MonitorPageGpaInfo::from_guest_gpas(MonitorPageGpas {
3523 child_to_parent: request.child_to_parent_monitor_page_gpa,
3524 parent_to_child: request.parent_to_child_monitor_page_gpa,
3525 }),
3526 );
3527
3528 info.modifying = true;
3529 info.monitor_page = monitor_page;
3530 tracing::debug!("modifying connection parameters.");
3531 self.notifier.modify_connection(request.into())?;
3532
3533 Ok(())
3534 }
3535
    /// Called when the notifier has finished a connection-level modify request.
    ///
    /// The response is routed by the current connection state:
    ///
    /// - `Connecting`: passed to `complete_initiate_contact`.
    /// - `Disconnecting`: completes the disconnect.
    /// - `Connected`: answers the in-flight guest `ModifyConnection` request
    ///   with a `ModifyConnectionResponse`.
    ///
    /// # Panics
    ///
    /// Panics if:
    ///
    /// - the connection is `Disconnected`,
    /// - a `Connected`-state response is not
    ///   `ModifyConnectionResponse::Modified`, or
    /// - no guest modify request is in flight.
    pub fn complete_modify_connection(&mut self, response: ModifyConnectionResponse) {
        tracing::debug!(?response, "modifying connection parameters complete");

        match &mut self.inner.state {
            ConnectionState::Connecting { .. } => self.complete_initiate_contact(response),
            ConnectionState::Disconnecting { .. } => self.complete_disconnect(),
            ConnectionState::Connected(info) => {
                let ModifyConnectionResponse::Modified(connection_state) = response else {
                    panic!(
                        "Relay should not return {:?} for a modify request with no version.",
                        response
                    );
                };

                if !info.modifying {
                    panic!(
                        "ModifyConnection response while not modifying, state: {:?}",
                        self.inner.state
                    );
                }

                info.modifying = false;
                self.sender()
                    .send_message(&protocol::ModifyConnectionResponse { connection_state });
            }
            _ => panic!(
                "Invalid state for ModifyConnection response: {:?}",
                self.inner.state
            ),
        }
    }
3570
3571 fn handle_pause(&mut self) {
3572 tracelimit::info_ratelimited!("pausing sending messages");
3573 self.sender().send_message(&protocol::PauseResponse {});
3574 let ConnectionState::Connected(info) = &mut self.inner.state else {
3575 unreachable!(
3576 "in unexpected state {:?}, should be prevented by Message::parse()",
3577 self.inner.state
3578 );
3579 };
3580 info.paused = true;
3581 }
3582
    /// Parses and dispatches a VMBus control message received from the guest.
    ///
    /// The message is parsed against the negotiated version, if any. Once the
    /// connection is trusted, untrusted messages are rejected. While the
    /// connection is paused, only `Resume`, `Unload`, and
    /// `InitiateContact`/`InitiateContact2` are accepted, and accepting one of
    /// them clears the paused state.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    ///
    /// - parsing fails,
    /// - an untrusted message arrives on a trusted connection,
    /// - a disallowed message arrives while paused, or
    /// - a fallible handler fails.
    ///
    /// Some handlers (unload, GPADL header, hvsock connect, modify connection,
    /// pause) report their own failures and never return an error here.
    ///
    /// # Panics
    ///
    /// Panics if:
    ///
    /// - the server is resetting, or
    /// - the guest sends a message that only the server sends. Such messages
    ///   are presumably filtered out by `Message::parse`.
    pub fn handle_synic_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
        assert!(!self.is_resetting());

        let version = self.inner.state.get_version();
        let msg = Message::parse(&message.data, version)?;
        tracing::trace!(?msg, message.trusted, "received vmbus message");
        // A trusted connection only accepts trusted messages.
        if self.inner.state.is_trusted() && !message.trusted {
            tracelimit::warn_ratelimited!(?msg, "Received untrusted message");
            return Err(ChannelError::UntrustedMessage);
        }

        match &mut self.inner.state {
            ConnectionState::Connected(info) if info.paused => {
                if !matches!(
                    msg,
                    Message::Resume(..)
                        | Message::Unload(..)
                        | Message::InitiateContact { .. }
                        | Message::InitiateContact2 { .. }
                ) {
                    tracelimit::warn_ratelimited!(?msg, "Received message while paused");
                    return Err(ChannelError::Paused);
                }
                tracelimit::info_ratelimited!("resuming sending messages");
                info.paused = false;
            }
            _ => {}
        }

        match msg {
            Message::InitiateContact2(input, ..) => {
                self.handle_initiate_contact(&input, &message, true)?
            }
            Message::InitiateContact(input, ..) => {
                self.handle_initiate_contact(&input.into(), &message, false)?
            }
            Message::Unload(..) => self.handle_unload(),
            Message::RequestOffers(..) => self.handle_request_offers()?,
            Message::GpadlHeader(input, range) => self.handle_gpadl_header(&input, range),
            Message::GpadlBody(input, range) => self.handle_gpadl_body(&input, range)?,
            Message::GpadlTeardown(input, ..) => self.handle_gpadl_teardown(&input)?,
            Message::OpenChannel(input, ..) => self.handle_open_channel(&input.into())?,
            Message::OpenChannel2(input, ..) => self.handle_open_channel(&input)?,
            Message::CloseChannel(input, ..) => self.handle_close_channel(&input)?,
            Message::RelIdReleased(input, ..) => self.handle_rel_id_released(&input)?,
            Message::TlConnectRequest(input, ..) => self.handle_tl_connect_request(input.into()),
            Message::TlConnectRequest2(input, ..) => self.handle_tl_connect_request(input),
            Message::ModifyChannel(input, ..) => self.handle_modify_channel(&input)?,
            Message::ModifyConnection(input, ..) => self.handle_modify_connection(input),
            Message::OpenReservedChannel(input, ..) => self.handle_open_reserved_channel(
                &input,
                version.expect("version validated by Message::parse"),
            )?,
            Message::CloseReservedChannel(input, ..) => {
                self.handle_close_reserved_channel(&input)?
            }
            Message::Pause(protocol::Pause, ..) => self.handle_pause(),
            // The paused state was already cleared above.
            Message::Resume(protocol::Resume, ..) => {}
            // Messages that the server sends to the guest, never the reverse.
            Message::OfferChannel(..)
            | Message::RescindChannelOffer(..)
            | Message::AllOffersDelivered(..)
            | Message::OpenResult(..)
            | Message::GpadlCreated(..)
            | Message::GpadlTorndown(..)
            | Message::VersionResponse(..)
            | Message::VersionResponse2(..)
            | Message::VersionResponse3(..)
            | Message::UnloadComplete(..)
            | Message::CloseReservedChannelResponse(..)
            | Message::TlConnectResult(..)
            | Message::ModifyChannelResponse(..)
            | Message::ModifyConnectionResponse(..)
            | Message::PauseResponse(..) => {
                unreachable!("Server received client message {:?}", msg);
            }
        }
        Ok(())
    }
3668
    /// Called when the notifier has finished creating a GPADL.
    ///
    /// A non-negative `status` is treated as success.
    ///
    /// - `Offered`: the status is forwarded to the guest via `GpadlCreated`. On
    ///   success the GPADL becomes `Accepted`; otherwise it is dropped.
    /// - `OfferedTearingDown` (teardown was requested while creation was
    ///   pending, see `client_release_channel`): a successful creation
    ///   immediately starts a teardown, and a failed one drops the GPADL.
    /// - Unknown GPADLs and GPADLs in any other state are logged and ignored.
    ///
    /// `check_disconnected` is called whenever a GPADL is dropped.
    pub fn gpadl_create_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId, status: i32) {
        let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
            tracelimit::error_ratelimited!(
                ?offer_id,
                key = %self.inner.channels[offer_id].offer.key(),
                ?gpadl_id,
                "invalid gpadl ID for channel"
            );
            return;
        };
        // `retain` is false when the GPADL should be dropped.
        let retain = match gpadl.state {
            GpadlState::InProgress | GpadlState::TearingDown | GpadlState::Accepted => {
                tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
                return;
            }
            GpadlState::Offered => {
                let channel_id = self.inner.channels[offer_id]
                    .info
                    .as_ref()
                    .expect("assigned")
                    .channel_id;
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_gpadl_created(channel_id, gpadl_id, status);
                if status >= 0 {
                    gpadl.state = GpadlState::Accepted;
                    true
                } else {
                    false
                }
            }
            GpadlState::OfferedTearingDown => {
                // The guest is not told about creation here; a created GPADL is
                // torn down right away.
                if status >= 0 {
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: false,
                        },
                    );
                    gpadl.state = GpadlState::TearingDown;
                    true
                } else {
                    false
                }
            }
        };
        if !retain {
            self.inner
                .gpadls
                .remove(&(gpadl_id, offer_id))
                .expect("gpadl validated above");

            self.check_disconnected();
        }
    }
3728
    /// Called when the notifier has finished tearing down a GPADL.
    ///
    /// A GPADL in the `TearingDown` state is removed. `GpadlTorndown` is sent
    /// to the guest unless the channel has already been released, and then
    /// `check_disconnected` is called. Unknown GPADLs and GPADLs in any other
    /// state are logged and ignored.
    pub fn gpadl_teardown_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
        let channel = &mut self.inner.channels[offer_id];
        let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
            tracelimit::error_ratelimited!(
                ?offer_id,
                key = %channel.offer.key(),
                ?gpadl_id,
                "invalid gpadl ID for channel"
            );
            return;
        };
        tracing::debug!(
            offer_id = offer_id.0,
            key = %channel.offer.key(),
            gpadl_id = gpadl_id.0,
            "Gpadl teardown complete"
        );
        match gpadl.state {
            GpadlState::InProgress
            | GpadlState::Offered
            | GpadlState::OfferedTearingDown
            | GpadlState::Accepted => {
                tracelimit::error_ratelimited!(?offer_id, key = %channel.offer.key(), ?gpadl_id, ?gpadl, "invalid gpadl state");
            }
            GpadlState::TearingDown => {
                // Skip the guest notification if the channel was already released.
                if !channel.state.is_released() {
                    self.sender().send_gpadl_torndown(gpadl_id);
                }
                self.inner
                    .gpadls
                    .remove(&(gpadl_id, offer_id))
                    .expect("gpadl validated above");

                self.check_disconnected();
            }
        }
    }
3767
3768 fn sender(&mut self) -> MessageSender<'_, N> {
3773 self.inner
3774 .pending_messages
3775 .sender(self.notifier, self.inner.state.is_paused())
3776 }
3777}
3778
3779fn revoke<N: Notifier>(
3780 mut sender: MessageSender<'_, N>,
3781 offer_id: OfferId,
3782 channel: &mut Channel,
3783 gpadls: &mut GpadlMap,
3784) -> bool {
3785 let info = match channel.state {
3786 ChannelState::Closed
3787 | ChannelState::Open { .. }
3788 | ChannelState::Opening { .. }
3789 | ChannelState::Closing { .. }
3790 | ChannelState::ClosingReopen { .. } => {
3791 channel.state = ChannelState::Revoked;
3792 Some(channel.info.as_ref().expect("assigned"))
3793 }
3794 ChannelState::Reoffered => {
3795 channel.state = ChannelState::Revoked;
3796 None
3797 }
3798 ChannelState::ClientReleased
3799 | ChannelState::OpeningClientRelease
3800 | ChannelState::ClosingClientRelease => None,
3801 ChannelState::Revoked => return true,
3803 };
3804 let retain = !channel.state.is_released();
3805
3806 gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
3808 if gpadl_offer_id != offer_id {
3809 return true;
3810 }
3811
3812 match gpadl.state {
3813 GpadlState::InProgress => true,
3814 GpadlState::Offered => {
3815 if let Some(info) = info {
3816 sender.send_gpadl_created(
3817 info.channel_id,
3818 gpadl_id,
3819 protocol::STATUS_UNSUCCESSFUL,
3820 );
3821 }
3822 false
3823 }
3824 GpadlState::OfferedTearingDown => false,
3825 GpadlState::Accepted => true,
3826 GpadlState::TearingDown => {
3827 if info.is_some() {
3828 sender.send_gpadl_torndown(gpadl_id);
3829 }
3830 false
3831 }
3832 }
3833 });
3834 if let Some(info) = info {
3835 sender.send_rescind(info);
3836 }
3837 if channel.restore_state != RestoreState::New {
3839 channel.restore_state = RestoreState::Restored;
3840 }
3841 retain
3842}
3843
/// FIFO queue of outgoing messages that could not be sent immediately.
///
/// `MessageSender::send_message` appends to the back of this queue whenever
/// the queue is non-empty, the sender is paused, or the notifier fails to send.
struct PendingMessages(VecDeque<OutgoingMessage>);
3845
3846impl PendingMessages {
3847 fn sender<'a, N: Notifier>(
3849 &'a mut self,
3850 notifier: &'a mut N,
3851 is_paused: bool,
3852 ) -> MessageSender<'a, N> {
3853 MessageSender {
3854 notifier,
3855 pending_messages: self,
3856 is_paused,
3857 }
3858 }
3859}
3860
/// Sends VMBus protocol messages, queuing them in [`PendingMessages`] when
/// they cannot be delivered immediately.
struct MessageSender<'a, N> {
    /// Used to attempt delivery of each message.
    notifier: &'a mut N,
    /// Queue that undeliverable default-target messages are appended to.
    pending_messages: &'a mut PendingMessages,
    /// When true, default-target messages are always queued instead of sent.
    is_paused: bool,
}
3868
3869impl<N: Notifier> MessageSender<'_, N> {
3870 fn send_message<
3872 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3873 >(
3874 &mut self,
3875 msg: &T,
3876 ) {
3877 let message = OutgoingMessage::new(msg);
3878
3879 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3880 if !self.pending_messages.0.is_empty()
3882 || self.is_paused
3883 || !self.notifier.send_message(&message, MessageTarget::Default)
3884 {
3885 tracing::trace!("message queued");
3886 self.pending_messages.0.push_back(message);
3888 }
3889 }
3890
3891 fn send_message_with_target<
3893 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3894 >(
3895 &mut self,
3896 msg: &T,
3897 target: MessageTarget,
3898 ) {
3899 if target == MessageTarget::Default {
3900 self.send_message(msg);
3901 } else {
3902 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3903 let message = OutgoingMessage::new(msg);
3906 if !self.notifier.send_message(&message, target) {
3907 tracelimit::warn_ratelimited!(?target, "failed to send message");
3908 }
3909 }
3910 }
3911
3912 fn send_offer(&mut self, channel: &mut Channel, connection_info: &ConnectionInfo) {
3914 let info = channel.info.as_ref().expect("assigned");
3915 let mut flags = channel.offer.flags;
3916 if !connection_info
3917 .version
3918 .feature_flags
3919 .confidential_channels()
3920 {
3921 flags.set_confidential_ring_buffer(false);
3922 flags.set_confidential_external_memory(false);
3923 }
3924
3925 let monitor_id = connection_info.monitor_page.and(info.monitor_id);
3932 let msg = protocol::OfferChannel {
3933 interface_id: channel.offer.interface_id,
3934 instance_id: channel.offer.instance_id,
3935 rsvd: [0; 4],
3936 flags,
3937 mmio_megabytes: channel.offer.mmio_megabytes,
3938 user_defined: channel.offer.user_defined,
3939 subchannel_index: channel.offer.subchannel_index,
3940 mmio_megabytes_optional: channel.offer.mmio_megabytes_optional,
3941 channel_id: info.channel_id,
3942 monitor_id: monitor_id.unwrap_or(MonitorId::INVALID).0,
3943 monitor_allocated: monitor_id.is_some().into(),
3944 is_dedicated: 1,
3947 connection_id: info.connection_id,
3948 };
3949 tracing::info!(
3950 channel_id = msg.channel_id.0,
3951 connection_id = msg.connection_id,
3952 key = %channel.offer.key(),
3953 "sending offer to guest"
3954 );
3955
3956 self.send_message(&msg);
3957 }
3958
3959 fn send_open_result(
3960 &mut self,
3961 channel_id: ChannelId,
3962 open_request: &OpenRequest,
3963 result: i32,
3964 target: MessageTarget,
3965 ) {
3966 self.send_message_with_target(
3967 &protocol::OpenResult {
3968 channel_id,
3969 open_id: open_request.open_id,
3970 status: result as u32,
3971 },
3972 target,
3973 );
3974 }
3975
3976 fn send_gpadl_created(&mut self, channel_id: ChannelId, gpadl_id: GpadlId, status: i32) {
3977 self.send_message(&protocol::GpadlCreated {
3978 channel_id,
3979 gpadl_id,
3980 status,
3981 });
3982 }
3983
3984 fn send_gpadl_torndown(&mut self, gpadl_id: GpadlId) {
3985 self.send_message(&protocol::GpadlTorndown { gpadl_id });
3986 }
3987
3988 fn send_rescind(&mut self, info: &OfferedInfo) {
3989 tracing::info!(
3990 channel_id = info.channel_id.0,
3991 "rescinding channel from guest"
3992 );
3993
3994 self.send_message(&protocol::RescindChannelOffer {
3995 channel_id: info.channel_id,
3996 });
3997 }
3998}
3999
/// Values describing a version response: the version information, the
/// connection state, and optional monitor page GPAs.
struct VersionResponseData {
    /// Version information to report.
    version: VersionInfo,
    /// Connection state to report.
    state: protocol::ConnectionState,
    /// Monitor page GPAs, if any. `new` leaves this as `None`, and
    /// `with_monitor_pages` sets it.
    monitor_pages: Option<MonitorPageGpas>,
}
4006
4007impl VersionResponseData {
4008 fn new(version: VersionInfo, state: protocol::ConnectionState) -> Self {
4010 VersionResponseData {
4011 version,
4012 state,
4013 monitor_pages: None,
4014 }
4015 }
4016
4017 fn with_monitor_pages(mut self, monitor_pages: Option<MonitorPageGpas>) -> Self {
4019 self.monitor_pages = monitor_pages;
4020 self
4021 }
4022}