1pub mod saved_state;
5#[cfg(test)]
6mod tests;
7
8use crate::Guid;
9use crate::SynicMessage;
10use crate::monitor::AssignedMonitors;
11use crate::protocol::Version;
12use hvdef::Vtl;
13use inspect::Inspect;
14pub use saved_state::RestoreError;
15pub use saved_state::SavedState;
16pub use saved_state::SavedStateData;
17use slab::Slab;
18use std::cmp::min;
19use std::collections::VecDeque;
20use std::collections::hash_map::Entry;
21use std::collections::hash_map::HashMap;
22use std::fmt::Display;
23use std::ops::Index;
24use std::ops::IndexMut;
25use std::task::Poll;
26use std::task::ready;
27use std::time::Duration;
28use thiserror::Error;
29use vmbus_channel::bus::ChannelType;
30use vmbus_channel::bus::GpadlRequest;
31use vmbus_channel::bus::OfferKey;
32use vmbus_channel::bus::OfferParams;
33use vmbus_channel::bus::OpenData;
34use vmbus_channel::bus::RestoredGpadl;
35use vmbus_core::HvsockConnectRequest;
36use vmbus_core::HvsockConnectResult;
37use vmbus_core::MaxVersionInfo;
38use vmbus_core::OutgoingMessage;
39use vmbus_core::VMBUS_SINT;
40use vmbus_core::VersionInfo;
41use vmbus_core::protocol;
42use vmbus_core::protocol::ChannelId;
43use vmbus_core::protocol::ConnectionId;
44use vmbus_core::protocol::FeatureFlags;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::Message;
47use vmbus_core::protocol::OfferFlags;
48use vmbus_core::protocol::UserDefinedData;
49use vmbus_ring::gparange;
50use vmcore::monitor::MonitorId;
51use vmcore::synic::MonitorInfo;
52use vmcore::synic::MonitorPageGpas;
53use zerocopy::FromZeros;
54use zerocopy::Immutable;
55use zerocopy::IntoBytes;
56use zerocopy::KnownLayout;
57
58#[derive(Debug, Error)]
60pub enum ChannelError {
61 #[error("unknown channel ID")]
62 UnknownChannelId,
63 #[error("unknown GPADL ID")]
64 UnknownGpadlId,
65 #[error("parse error")]
66 ParseError(#[from] protocol::ParseError),
67 #[error("invalid gpa range")]
68 InvalidGpaRange(#[source] gparange::Error),
69 #[error("duplicate GPADL ID")]
70 DuplicateGpadlId,
71 #[error("GPADL is already complete")]
72 GpadlAlreadyComplete,
73 #[error("GPADL channel ID mismatch")]
74 WrongGpadlChannelId,
75 #[error("trying to open an open channel")]
76 ChannelAlreadyOpen,
77 #[error("trying to close a closed channel")]
78 ChannelNotOpen,
79 #[error("invalid GPADL state for operation")]
80 InvalidGpadlState,
81 #[error("invalid channel state for operation")]
82 InvalidChannelState,
83 #[error("channel ID has already been released")]
84 ChannelReleased,
85 #[error("channel offers have already been sent")]
86 OffersAlreadySent,
87 #[error("invalid operation on reserved channel")]
88 ChannelReserved,
89 #[error("invalid operation on non-reserved channel")]
90 ChannelNotReserved,
91 #[error("received untrusted message for trusted connection")]
92 UntrustedMessage,
93 #[error("received a non-resuming message while paused")]
94 Paused,
95 #[error("invalid target VP")]
96 InvalidTargetVp,
97 #[error("interrupts are disabled for this channel")]
98 InterruptsDisabled,
99}
100
101#[derive(Debug, Error)]
102pub enum OfferError {
103 #[error("the channel ID {} is not valid for this operation", (.0).0)]
104 InvalidChannelId(ChannelId),
105 #[error("the channel ID {} is already in use", (.0).0)]
106 ChannelIdInUse(ChannelId),
107 #[error("offer {0} already exists")]
108 AlreadyExists(OfferKey),
109 #[error("specified resources do not match those of the existing saved or revoked offer")]
110 IncompatibleResources,
111 #[error("too many channels have been offered")]
112 TooManyChannels,
113 #[error("mismatched monitor ID from saved state; expected {0:?}, actual {1:?}")]
114 MismatchedMonitorId(Option<MonitorId>, MonitorId),
115}
116
/// Identifier for an offered channel.
///
/// Wraps the `usize` key under which the channel is stored in the
/// `ChannelList` slab.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OfferId(usize);
120
121type IncompleteGpadlMap = HashMap<GpadlId, OfferId>;
122
123type GpadlMap = HashMap<(GpadlId, OfferId), Gpadl>;
124
125pub struct Server {
127 state: ConnectionState,
128 channels: ChannelList,
129 assigned_channels: AssignedChannels,
130 assigned_monitors: AssignedMonitors,
131 gpadls: GpadlMap,
132 incomplete_gpadls: IncompleteGpadlMap,
133 child_connection_id: u32,
134 max_version: Option<MaxVersionInfo>,
135 delayed_max_version: Option<MaxVersionInfo>,
136 pending_messages: PendingMessages,
139 require_server_allocated_mnf: bool,
143 use_absolute_channel_order: bool,
144}
145
146pub struct ServerWithNotifier<'a, T> {
147 inner: &'a mut Server,
148 notifier: &'a mut T,
149}
150
151impl<T> Drop for ServerWithNotifier<'_, T> {
152 fn drop(&mut self) {
153 self.inner.validate();
154 }
155}
156
157impl<T: Notifier> Inspect for ServerWithNotifier<'_, T> {
158 fn inspect(&self, req: inspect::Request<'_>) {
159 let mut resp = req.respond();
160 let (state, info, next_action) = match &self.inner.state {
161 ConnectionState::Disconnected => ("disconnected", None, None),
162 ConnectionState::Connecting { info, .. } => ("connecting", Some(info), None),
163 ConnectionState::Connected(info) => (
164 if info.offers_sent {
165 "connected"
166 } else {
167 "negotiated"
168 },
169 Some(info),
170 None,
171 ),
172 ConnectionState::Disconnecting { next_action, .. } => {
173 ("disconnecting", None, Some(next_action))
174 }
175 };
176
177 resp.field("connection_info", info);
178 let next_action = next_action.map(|a| match a {
179 ConnectionAction::None => "disconnect",
180 ConnectionAction::Reset => "reset",
181 ConnectionAction::SendUnloadComplete => "unload",
182 ConnectionAction::Reconnect { .. } => "reconnect",
183 ConnectionAction::SendFailedVersionResponse => "send_version_response",
184 });
185 resp.field("state", state)
186 .field("next_action", next_action)
187 .field(
188 "assigned_monitors_bitmap",
189 format_args!("{:x}", self.inner.assigned_monitors.bitmap()),
190 )
191 .child("channels", |req| {
192 let mut resp = req.respond();
193 self.inner
194 .channels
195 .inspect(self.notifier, self.inner.get_version(), &mut resp);
196 for ((gpadl_id, offer_id), gpadl) in &self.inner.gpadls {
197 let channel = &self.inner.channels[*offer_id];
198 resp.field(
199 &channel_inspect_path(
200 &channel.offer,
201 format_args!("/gpadls/{}", gpadl_id.0),
202 ),
203 gpadl,
204 );
205 }
206 });
207 }
208}
209
210#[derive(Debug, Copy, Clone, Inspect)]
212struct MonitorPageGpaInfo {
213 gpas: MonitorPageGpas,
214 server_allocated: bool,
215}
216
217impl MonitorPageGpaInfo {
218 fn from_guest_gpas(gpas: MonitorPageGpas) -> Self {
220 Self {
221 gpas,
222 server_allocated: false,
223 }
224 }
225
226 fn from_server_gpas(gpas: MonitorPageGpas) -> Self {
228 Self {
229 gpas,
230 server_allocated: true,
231 }
232 }
233}
234
235#[derive(Debug, Copy, Clone, Inspect)]
236struct ConnectionInfo {
237 version: VersionInfo,
238 trusted: bool,
241 offers_sent: bool,
242 interrupt_page: Option<u64>,
243 monitor_page: Option<MonitorPageGpaInfo>,
244 target_message_vp: u32,
245 modifying: bool,
246 client_id: Guid,
247 paused: bool,
248}
249
250#[derive(Debug)]
252enum ConnectionState {
253 Disconnected,
254 Disconnecting {
255 next_action: ConnectionAction,
256 modify_sent: bool,
257 },
258 Connecting {
259 info: ConnectionInfo,
260 next_action: ConnectionAction,
261 },
262 Connected(ConnectionInfo),
263}
264
265impl ConnectionState {
266 fn check_version(&self, min_version: Version) -> bool {
268 matches!(self, ConnectionState::Connected(info) if info.version.version >= min_version)
269 }
270
271 fn check_feature_flags(&self, flags: impl Fn(FeatureFlags) -> bool) -> bool {
274 matches!(self, ConnectionState::Connected(info) if flags(info.version.feature_flags))
275 }
276
277 fn get_version(&self) -> Option<VersionInfo> {
278 if let ConnectionState::Connected(info) = self {
279 Some(info.version)
280 } else {
281 None
282 }
283 }
284
285 fn get_connected_info(&self) -> Option<&ConnectionInfo> {
287 if let ConnectionState::Connected(info) = self {
288 Some(info)
289 } else {
290 None
291 }
292 }
293
294 fn is_trusted(&self) -> bool {
295 match self {
296 ConnectionState::Connected(info) => info.trusted,
297 ConnectionState::Connecting { info, .. } => info.trusted,
298 _ => false,
299 }
300 }
301
302 fn is_paused(&self) -> bool {
303 if let ConnectionState::Connected(info) = self {
304 info.paused
305 } else {
306 false
307 }
308 }
309}
310
311#[derive(Debug, Copy, Clone)]
312enum ConnectionAction {
313 None,
314 Reset,
315 SendUnloadComplete,
316 Reconnect {
317 initiate_contact: InitiateContactRequest,
318 },
319 SendFailedVersionResponse,
320}
321
322#[derive(PartialEq, Eq, Debug, Copy, Clone)]
323pub enum MonitorPageRequest {
324 None,
325 Some(MonitorPageGpas),
326 Invalid,
327}
328
329#[derive(PartialEq, Eq, Debug, Copy, Clone)]
330pub struct InitiateContactRequest {
331 pub version_requested: u32,
332 pub target_message_vp: u32,
333 pub monitor_page: MonitorPageRequest,
334 pub target_sint: u8,
335 pub target_vtl: u8,
336 pub feature_flags: u32,
337 pub interrupt_page: Option<u64>,
338 pub client_id: Guid,
339 pub trusted: bool,
340}
341
342#[derive(Debug, Copy, Clone)]
343pub struct OpenRequest {
344 pub open_id: u32,
345 pub ring_buffer_gpadl_id: GpadlId,
346 pub target_vp: Option<u32>,
347 pub downstream_ring_buffer_page_offset: u32,
348 pub user_data: UserDefinedData,
349 pub guest_specified_interrupt_info: Option<SignalInfo>,
350 pub flags: protocol::OpenChannelFlags,
351}
352
/// A requested change to an optional setting.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Update<T: std::fmt::Debug + Copy + Clone> {
    /// Leave the setting as it is.
    Unchanged,
    /// Clear the setting.
    Reset,
    /// Replace the setting with the given value.
    Set(T),
}

impl<T: std::fmt::Debug + Copy + Clone> From<Option<T>> for Update<T> {
    /// Maps `None` to [`Update::Reset`] and `Some(v)` to [`Update::Set`];
    /// [`Update::Unchanged`] is never produced.
    fn from(value: Option<T>) -> Self {
        value.map_or(Self::Reset, Self::Set)
    }
}
368
369#[derive(Debug, Copy, Clone, Eq, PartialEq)]
370pub struct ModifyConnectionRequest {
371 pub version: Option<VersionInfo>,
372 pub monitor_page: Update<MonitorPageGpas>,
373 pub interrupt_page: Update<u64>,
374 pub target_message_vp: Option<u32>,
375 pub notify_relay: bool,
376}
377
378impl Default for ModifyConnectionRequest {
380 fn default() -> Self {
381 Self {
382 version: None,
383 monitor_page: Update::Unchanged,
384 interrupt_page: Update::Unchanged,
385 target_message_vp: None,
386 notify_relay: true,
387 }
388 }
389}
390
391impl From<protocol::ModifyConnection> for ModifyConnectionRequest {
392 fn from(value: protocol::ModifyConnection) -> Self {
393 let monitor_page = if value.parent_to_child_monitor_page_gpa != 0 {
394 Update::Set(MonitorPageGpas {
395 parent_to_child: value.parent_to_child_monitor_page_gpa,
396 child_to_parent: value.child_to_parent_monitor_page_gpa,
397 })
398 } else {
399 Update::Reset
400 };
401
402 Self {
403 monitor_page,
404 ..Default::default()
405 }
406 }
407}
408
409#[derive(Debug, Copy, Clone)]
411pub enum ModifyConnectionResponse {
412 Supported(
418 protocol::ConnectionState,
419 FeatureFlags,
420 Option<MonitorPageGpas>,
421 ),
422 Unsupported,
424 Modified(protocol::ConnectionState),
427}
428
/// Whether an open channel currently has a modification in progress.
#[derive(Debug, Copy, Clone)]
pub enum ModifyState {
    /// No modification is in progress.
    NotModifying,
    /// A modification is in progress, optionally with a queued target VP.
    Modifying { pending_target_vp: Option<u32> },
}

impl ModifyState {
    /// Returns `true` for [`ModifyState::Modifying`], whatever its payload.
    pub fn is_modifying(&self) -> bool {
        match self {
            ModifyState::Modifying { .. } => true,
            ModifyState::NotModifying => false,
        }
    }
}
440
/// Event flag and connection ID used to signal a channel.
#[derive(Debug, Copy, Clone)]
pub struct SignalInfo {
    /// Event flag to use.
    pub event_flag: u16,
    /// Connection ID to use.
    pub connection_id: u32,
}
446
/// Restore status of a channel; inspected as "new", "restoring", "unmatched"
/// or "restored".
// NOTE(review): the transitions between these states are driven by code that
// is not visible in this chunk.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum RestoreState {
    /// Inspected as "new".
    New,
    /// Inspected as "restoring".
    Restoring,
    /// Inspected as "unmatched".
    Unmatched,
    /// Inspected as "restored".
    Restored,
}
461
462#[derive(Debug, Clone)]
464enum ChannelState {
465 ClientReleased,
469
470 Closed,
472
473 Opening {
476 request: OpenRequest,
477 reserved_state: Option<ReservedState>,
478 },
479
480 Open {
482 params: OpenRequest,
483 modify_state: ModifyState,
484 reserved_state: Option<ReservedState>,
485 },
486
487 Closing {
489 params: OpenRequest,
490 reserved_state: Option<ReservedState>,
491 },
492
493 ClosingReopen {
496 params: OpenRequest,
497 request: OpenRequest,
498 },
499
500 Revoked,
502
503 Reoffered,
506
507 ClosingClientRelease,
510
511 OpeningClientRelease,
514}
515
516impl ChannelState {
517 fn is_released(&self) -> bool {
520 match self {
521 ChannelState::Closed
522 | ChannelState::Opening { .. }
523 | ChannelState::Open { .. }
524 | ChannelState::Closing { .. }
525 | ChannelState::ClosingReopen { .. }
526 | ChannelState::Revoked
527 | ChannelState::Reoffered => false,
528
529 ChannelState::ClientReleased
530 | ChannelState::ClosingClientRelease
531 | ChannelState::OpeningClientRelease => true,
532 }
533 }
534
535 fn is_revoked(&self) -> bool {
537 match self {
538 ChannelState::Revoked | ChannelState::Reoffered => true,
539
540 ChannelState::ClientReleased
541 | ChannelState::Closed
542 | ChannelState::Opening { .. }
543 | ChannelState::Open { .. }
544 | ChannelState::Closing { .. }
545 | ChannelState::ClosingReopen { .. }
546 | ChannelState::ClosingClientRelease
547 | ChannelState::OpeningClientRelease => false,
548 }
549 }
550
551 fn is_reserved(&self) -> bool {
552 match self {
553 ChannelState::Open {
555 reserved_state: Some(_),
556 ..
557 }
558 | ChannelState::Opening {
559 reserved_state: Some(_),
560 ..
561 }
562 | ChannelState::Closing {
563 reserved_state: Some(_),
564 ..
565 } => true,
566
567 ChannelState::Opening { .. }
568 | ChannelState::Open { .. }
569 | ChannelState::Closing { .. }
570 | ChannelState::ClientReleased
571 | ChannelState::Closed
572 | ChannelState::ClosingReopen { .. }
573 | ChannelState::Revoked
574 | ChannelState::Reoffered
575 | ChannelState::ClosingClientRelease
576 | ChannelState::OpeningClientRelease => false,
577 }
578 }
579}
580
581impl Display for ChannelState {
582 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583 let state = match self {
584 Self::ClientReleased => "ClientReleased",
585 Self::Closed => "Closed",
586 Self::Opening { .. } => "Opening",
587 Self::Open { .. } => "Open",
588 Self::Closing { .. } => "Closing",
589 Self::ClosingReopen { .. } => "ClosingReopen",
590 Self::Revoked => "Revoked",
591 Self::Reoffered => "Reoffered",
592 Self::ClosingClientRelease => "ClosingClientRelease",
593 Self::OpeningClientRelease => "OpeningClientRelease",
594 };
595 write!(f, "{}", state)
596 }
597}
598
599#[derive(Debug, Clone, Default, mesh::MeshPayload)]
601pub enum MnfUsage {
602 #[default]
604 Disabled,
605 Enabled { latency: Duration },
607 Relayed { monitor_id: u8 },
610}
611
612impl MnfUsage {
613 pub fn is_enabled(&self) -> bool {
614 matches!(self, Self::Enabled { .. })
615 }
616
617 pub fn is_relayed(&self) -> bool {
618 matches!(self, Self::Relayed { .. })
619 }
620
621 pub fn enabled_and_then<T>(&self, f: impl FnOnce(Duration) -> Option<T>) -> Option<T> {
622 if let Self::Enabled { latency } = self {
623 f(*latency)
624 } else {
625 None
626 }
627 }
628}
629
630impl From<Option<Duration>> for MnfUsage {
631 fn from(value: Option<Duration>) -> Self {
632 match value {
633 None => Self::Disabled,
634 Some(latency) => Self::Enabled { latency },
635 }
636 }
637}
638
639#[derive(Debug, Clone, Default, mesh::MeshPayload)]
640pub struct OfferParamsInternal {
641 pub interface_name: String,
643 pub instance_id: Guid,
644 pub interface_id: Guid,
645 pub mmio_megabytes: u16,
646 pub mmio_megabytes_optional: u16,
647 pub subchannel_index: u16,
648 pub use_mnf: MnfUsage,
649 pub offer_order: Option<u64>,
650 pub flags: OfferFlags,
651 pub user_defined: UserDefinedData,
652}
653
654impl OfferParamsInternal {
655 pub fn key(&self) -> OfferKey {
657 OfferKey {
658 interface_id: self.interface_id,
659 instance_id: self.instance_id,
660 subchannel_index: self.subchannel_index,
661 }
662 }
663}
664
665impl From<OfferParams> for OfferParamsInternal {
666 fn from(value: OfferParams) -> Self {
667 let mut user_defined = UserDefinedData::new_zeroed();
668
669 let mut flags = OfferFlags::new()
672 .with_confidential_ring_buffer(true)
673 .with_confidential_external_memory(value.allow_confidential_external_memory);
674
675 match value.channel_type {
676 ChannelType::Device { pipe_packets } => {
677 if pipe_packets {
678 flags.set_named_pipe_mode(true);
679 user_defined.as_pipe_params_mut().pipe_type = protocol::PipeType::MESSAGE;
680 }
681 }
682 ChannelType::Interface {
683 user_defined: interface_user_defined,
684 } => {
685 flags.set_enumerate_device_interface(true);
686 user_defined = interface_user_defined;
687 }
688 ChannelType::Pipe { message_mode } => {
689 flags.set_enumerate_device_interface(true);
690 flags.set_named_pipe_mode(true);
691 user_defined.as_pipe_params_mut().pipe_type = if message_mode {
692 protocol::PipeType::MESSAGE
693 } else {
694 protocol::PipeType::BYTE
695 };
696 }
697 ChannelType::HvSocket {
698 is_connect,
699 is_for_container,
700 silo_id,
701 } => {
702 flags.set_enumerate_device_interface(true);
703 flags.set_tlnpi_provider(true);
704 flags.set_named_pipe_mode(true);
705 *user_defined.as_hvsock_params_mut() = protocol::HvsockUserDefinedParameters::new(
706 is_connect,
707 is_for_container,
708 silo_id,
709 );
710 }
711 };
712
713 Self {
714 interface_name: value.interface_name,
715 instance_id: value.instance_id,
716 interface_id: value.interface_id,
717 mmio_megabytes: value.mmio_megabytes,
718 mmio_megabytes_optional: value.mmio_megabytes_optional,
719 subchannel_index: value.subchannel_index,
720 use_mnf: value.mnf_interrupt_latency.into(),
721 offer_order: value.offer_order,
722 user_defined,
723 flags,
724 }
725 }
726}
727
728#[derive(Debug, Copy, Clone, Inspect, PartialEq, Eq)]
729pub struct ConnectionTarget {
730 pub vp: u32,
731 pub sint: u8,
732}
733
734#[derive(Debug, Copy, Clone, PartialEq, Eq)]
735pub enum MessageTarget {
736 Default,
737 ReservedChannel(OfferId, ConnectionTarget),
738 Custom(ConnectionTarget),
739}
740
741impl MessageTarget {
742 pub fn for_offer(offer_id: OfferId, reserved_state: &Option<ReservedState>) -> Self {
743 if let Some(state) = reserved_state {
744 Self::ReservedChannel(offer_id, state.target)
745 } else {
746 Self::Default
747 }
748 }
749}
750
751#[derive(Debug, Copy, Clone)]
752pub struct ReservedState {
753 version: VersionInfo,
754 target: ConnectionTarget,
755}
756
757#[derive(Debug)]
759struct Channel {
760 info: Option<OfferedInfo>,
761 offer: OfferParamsInternal,
762 state: ChannelState,
763 restore_state: RestoreState,
764}
765
766#[derive(Debug, Copy, Clone)]
767struct OfferedInfo {
768 channel_id: ChannelId,
769 connection_id: u32,
770 monitor_id: Option<MonitorId>,
771}
772
773impl Channel {
774 fn inspect_state(&self, resp: &mut inspect::Response<'_>) {
775 let mut target_vp = None;
776 let mut event_flag = None;
777 let mut connection_id = None;
778 let mut reserved_target = None;
779 let state = match &self.state {
780 ChannelState::ClientReleased => "client_released",
781 ChannelState::Closed => "closed",
782 ChannelState::Opening { reserved_state, .. } => {
783 reserved_target = reserved_state.map(|state| state.target);
784 "opening"
785 }
786 ChannelState::Open {
787 params,
788 reserved_state,
789 ..
790 } => {
791 target_vp = Some(params.target_vp);
792 if let Some(id) = params.guest_specified_interrupt_info {
793 event_flag = Some(id.event_flag);
794 connection_id = Some(id.connection_id);
795 }
796 reserved_target = reserved_state.map(|state| state.target);
797 "open"
798 }
799 ChannelState::Closing { reserved_state, .. } => {
800 reserved_target = reserved_state.map(|state| state.target);
801 "closing"
802 }
803 ChannelState::ClosingReopen { .. } => "closing_reopen",
804 ChannelState::Revoked => "revoked",
805 ChannelState::Reoffered => "reoffered",
806 ChannelState::ClosingClientRelease => "closing_client_release",
807 ChannelState::OpeningClientRelease => "opening_client_release",
808 };
809 let restore_state = match self.restore_state {
810 RestoreState::New => "new",
811 RestoreState::Restoring => "restoring",
812 RestoreState::Restored => "restored",
813 RestoreState::Unmatched => "unmatched",
814 };
815 if let Some(info) = &self.info {
816 resp.field("channel_id", info.channel_id.0)
817 .field("offered_connection_id", info.connection_id)
818 .field("monitor_id", info.monitor_id.map(|id| id.0));
819 }
820 resp.field("state", state)
821 .field("restore_state", restore_state)
822 .field("interface_name", self.offer.interface_name.clone())
823 .display("instance_id", &self.offer.instance_id)
824 .display("interface_id", &self.offer.interface_id)
825 .field("mmio_megabytes", self.offer.mmio_megabytes)
826 .field("target_vp", target_vp)
827 .field("guest_specified_event_flag", event_flag)
828 .field("guest_specified_connection_id", connection_id)
829 .field("reserved_connection_target", reserved_target)
830 .binary("offer_flags", self.offer.flags.into_bits());
831 }
832
833 fn handled_monitor_info(&self) -> Option<MonitorInfo> {
843 self.offer.use_mnf.enabled_and_then(|latency| {
844 if self.state.is_reserved() {
845 None
846 } else {
847 self.info.and_then(|info| {
848 info.monitor_id.map(|monitor_id| MonitorInfo {
849 monitor_id,
850 latency,
851 })
852 })
853 }
854 })
855 }
856
857 fn prepare_channel(
860 &mut self,
861 offer_id: OfferId,
862 assigned_channels: &mut AssignedChannels,
863 assigned_monitors: &mut AssignedMonitors,
864 ) {
865 assert!(self.info.is_none());
866
867 let entry = assigned_channels
869 .allocate()
870 .expect("there are enough channel IDs for everything in ChannelList");
871
872 let channel_id = entry.id();
873 entry.insert(offer_id);
874 let connection_id = ConnectionId::new(channel_id.0, assigned_channels.vtl, VMBUS_SINT);
875
876 let monitor_id = match self.offer.use_mnf {
881 MnfUsage::Enabled { .. } => {
882 let monitor_id = assigned_monitors.assign_monitor();
883 if monitor_id.is_none() {
884 tracelimit::warn_ratelimited!("Out of monitor IDs.");
885 }
886
887 monitor_id
888 }
889 MnfUsage::Relayed { monitor_id } => Some(MonitorId(monitor_id)),
890 MnfUsage::Disabled => None,
891 };
892
893 self.info = Some(OfferedInfo {
894 channel_id,
895 connection_id: connection_id.0,
896 monitor_id,
897 });
898 }
899
900 fn release_channel(
902 &mut self,
903 offer_id: OfferId,
904 assigned_channels: &mut AssignedChannels,
905 assigned_monitors: &mut AssignedMonitors,
906 ) {
907 if let Some(info) = self.info.take() {
908 assigned_channels.free(info.channel_id, offer_id);
909
910 if let Some(monitor_id) = info.monitor_id {
912 if self.offer.use_mnf.is_enabled() {
913 assigned_monitors.release_monitor(monitor_id);
914 }
915 }
916 }
917 }
918}
919
920#[derive(Debug)]
921struct AssignedChannels {
922 assignments: Vec<Option<OfferId>>,
923 vtl: Vtl,
924 reserved_offset: usize,
925 count_in_reserved_range: usize,
927}
928
929impl AssignedChannels {
930 fn new(vtl: Vtl, channel_id_offset: u16) -> Self {
931 Self {
932 assignments: vec![None; MAX_CHANNELS],
933 vtl,
934 reserved_offset: channel_id_offset as usize,
935 count_in_reserved_range: 0,
936 }
937 }
938
939 fn allowable_channel_count(&self) -> usize {
940 MAX_CHANNELS - self.reserved_offset + self.count_in_reserved_range
941 }
942
943 fn get(&self, channel_id: ChannelId) -> Option<OfferId> {
944 self.assignments
945 .get(Self::index(channel_id))
946 .copied()
947 .flatten()
948 }
949
950 fn set(&mut self, channel_id: ChannelId) -> Result<AssignmentEntry<'_>, OfferError> {
951 let index = Self::index(channel_id);
952 if self
953 .assignments
954 .get(index)
955 .ok_or(OfferError::InvalidChannelId(channel_id))?
956 .is_some()
957 {
958 return Err(OfferError::ChannelIdInUse(channel_id));
959 }
960 Ok(AssignmentEntry { list: self, index })
961 }
962
963 fn allocate(&mut self) -> Option<AssignmentEntry<'_>> {
964 let index = self.reserved_offset
965 + self.assignments[self.reserved_offset..]
966 .iter()
967 .position(|x| x.is_none())?;
968 Some(AssignmentEntry { list: self, index })
969 }
970
971 fn free(&mut self, channel_id: ChannelId, offer_id: OfferId) {
972 let index = Self::index(channel_id);
973 let slot = &mut self.assignments[index];
974 assert_eq!(slot.take(), Some(offer_id));
975 if index < self.reserved_offset {
976 self.count_in_reserved_range -= 1;
977 }
978 }
979
980 fn index(channel_id: ChannelId) -> usize {
981 channel_id.0.wrapping_sub(1) as usize
982 }
983}
984
985struct AssignmentEntry<'a> {
986 list: &'a mut AssignedChannels,
987 index: usize,
988}
989
990impl AssignmentEntry<'_> {
991 pub fn id(&self) -> ChannelId {
992 ChannelId(self.index as u32 + 1)
993 }
994
995 pub fn insert(self, offer_id: OfferId) {
996 assert!(
997 self.list.assignments[self.index]
998 .replace(offer_id)
999 .is_none()
1000 );
1001
1002 if self.index < self.list.reserved_offset {
1003 self.list.count_in_reserved_range += 1;
1004 }
1005 }
1006}
1007
1008struct ChannelList {
1009 channels: Slab<Channel>,
1010}
1011
1012fn channel_inspect_path(offer: &OfferParamsInternal, suffix: std::fmt::Arguments<'_>) -> String {
1013 if offer.subchannel_index == 0 {
1014 format!("{}{}", offer.instance_id, suffix)
1015 } else {
1016 format!(
1017 "{}/subchannels/{}{}",
1018 offer.instance_id, offer.subchannel_index, suffix
1019 )
1020 }
1021}
1022
1023impl ChannelList {
1024 fn inspect(
1025 &self,
1026 notifier: &impl Notifier,
1027 version: Option<VersionInfo>,
1028 resp: &mut inspect::Response<'_>,
1029 ) {
1030 for (offer_id, channel) in self.iter() {
1031 resp.child(
1032 &channel_inspect_path(&channel.offer, format_args!("")),
1033 |req| {
1034 let mut resp = req.respond();
1035 channel.inspect_state(&mut resp);
1036
1037 resp.merge(inspect::adhoc(|req| {
1041 if !matches!(channel.state, ChannelState::Revoked) {
1042 notifier.inspect(version, offer_id, req);
1043 }
1044 }));
1045 },
1046 );
1047 }
1048 }
1049}
1050
/// Number of channel ID slots; sizes the table created by
/// `AssignedChannels::new`.
pub const MAX_CHANNELS: usize = 2047;
1054
1055impl ChannelList {
1056 fn new() -> Self {
1057 Self {
1058 channels: Slab::new(),
1059 }
1060 }
1061
1062 fn len(&self) -> usize {
1064 self.channels.len()
1065 }
1066
1067 fn offer(&mut self, new_channel: Channel) -> OfferId {
1069 OfferId(self.channels.insert(new_channel))
1070 }
1071
1072 fn remove(&mut self, offer_id: OfferId) {
1074 let channel = self.channels.remove(offer_id.0);
1075 assert!(channel.info.is_none());
1076 }
1077
1078 fn get_by_channel_id_mut(
1080 &mut self,
1081 assigned_channels: &AssignedChannels,
1082 channel_id: ChannelId,
1083 ) -> Result<(OfferId, &mut Channel), ChannelError> {
1084 let offer_id = assigned_channels
1085 .get(channel_id)
1086 .ok_or(ChannelError::UnknownChannelId)?;
1087 let channel = &mut self[offer_id];
1088 if channel.state.is_released() {
1089 return Err(ChannelError::ChannelReleased);
1090 }
1091 assert_eq!(
1092 channel.info.as_ref().map(|info| info.channel_id),
1093 Some(channel_id)
1094 );
1095 Ok((offer_id, channel))
1096 }
1097
1098 fn get_by_channel_id(
1100 &self,
1101 assigned_channels: &AssignedChannels,
1102 channel_id: ChannelId,
1103 ) -> Result<(OfferId, &Channel), ChannelError> {
1104 let offer_id = assigned_channels
1105 .get(channel_id)
1106 .ok_or(ChannelError::UnknownChannelId)?;
1107 let channel = &self[offer_id];
1108 if channel.state.is_released() {
1109 return Err(ChannelError::ChannelReleased);
1110 }
1111 assert_eq!(
1112 channel.info.as_ref().map(|info| info.channel_id),
1113 Some(channel_id)
1114 );
1115 Ok((offer_id, channel))
1116 }
1117
1118 fn get_by_key_mut(&mut self, key: &OfferKey) -> Option<(OfferId, &mut Channel)> {
1121 for (offer_id, channel) in self.iter_mut() {
1122 if channel.offer.instance_id == key.instance_id
1123 && channel.offer.interface_id == key.interface_id
1124 && channel.offer.subchannel_index == key.subchannel_index
1125 {
1126 return Some((offer_id, channel));
1127 }
1128 }
1129 None
1130 }
1131
1132 fn iter(&self) -> impl Iterator<Item = (OfferId, &Channel)> {
1134 self.channels
1135 .iter()
1136 .map(|(id, channel)| (OfferId(id), channel))
1137 }
1138
1139 fn iter_mut(&mut self) -> impl Iterator<Item = (OfferId, &mut Channel)> {
1141 self.channels
1142 .iter_mut()
1143 .map(|(id, channel)| (OfferId(id), channel))
1144 }
1145
1146 fn retain<F>(&mut self, mut f: F)
1148 where
1149 F: FnMut(OfferId, &mut Channel) -> bool,
1150 {
1151 self.channels.retain(|id, channel| {
1152 let retain = f(OfferId(id), channel);
1153 if !retain {
1154 assert!(channel.info.is_none());
1155 }
1156 retain
1157 })
1158 }
1159}
1160
1161impl Index<OfferId> for ChannelList {
1162 type Output = Channel;
1163
1164 fn index(&self, offer_id: OfferId) -> &Self::Output {
1165 &self.channels[offer_id.0]
1166 }
1167}
1168
1169impl IndexMut<OfferId> for ChannelList {
1170 fn index_mut(&mut self, offer_id: OfferId) -> &mut Self::Output {
1171 &mut self.channels[offer_id.0]
1172 }
1173}
1174
1175#[derive(Debug, Inspect)]
1177struct Gpadl {
1178 count: u16,
1179 #[inspect(skip)]
1180 buf: Vec<u64>,
1181 state: GpadlState,
1182}
1183
1184#[derive(Debug, Copy, Clone, PartialEq, Eq, Inspect)]
1185enum GpadlState {
1186 InProgress,
1188 Offered,
1190 OfferedTearingDown,
1193 Accepted,
1195 TearingDown,
1197}
1198
1199impl Gpadl {
1200 fn new(count: u16, len: usize) -> Self {
1203 Self {
1204 state: GpadlState::InProgress,
1205 count,
1206 buf: Vec::with_capacity(len),
1207 }
1208 }
1209
1210 fn append(&mut self, data: &[u8]) -> Result<bool, ChannelError> {
1212 if self.state == GpadlState::InProgress {
1213 let buf = &mut self.buf;
1214 let len = min(data.len() & !7, (buf.capacity() - buf.len()) * 8);
1219 let data = &data[..len];
1220 let start = buf.len();
1221 buf.resize(buf.len() + data.len() / 8, 0);
1222 buf[start..].as_mut_bytes().copy_from_slice(data);
1223 Ok(if buf.len() == buf.capacity() {
1224 gparange::validate_gpa_ranges(self.count as usize, buf)
1225 .map_err(ChannelError::InvalidGpaRange)?;
1226 self.state = GpadlState::Offered;
1227 true
1228 } else {
1229 false
1230 })
1231 } else {
1232 Err(ChannelError::GpadlAlreadyComplete)
1233 }
1234 }
1235}
1236
1237#[derive(Debug, Copy, Clone)]
1239pub struct OpenParams {
1240 pub open_data: OpenData,
1241 pub connection_id: u32,
1242 pub event_flag: u16,
1243 pub monitor_info: Option<MonitorInfo>,
1244 pub flags: protocol::OpenChannelFlags,
1245 pub reserved_target: Option<ConnectionTarget>,
1246 pub channel_id: ChannelId,
1247}
1248
1249impl OpenParams {
1250 fn from_request(
1251 info: &OfferedInfo,
1252 request: &OpenRequest,
1253 monitor_info: Option<MonitorInfo>,
1254 reserved_target: Option<ConnectionTarget>,
1255 ) -> Self {
1256 let (event_flag, connection_id) = if let Some(id) = request.guest_specified_interrupt_info {
1259 (id.event_flag, id.connection_id)
1260 } else {
1261 (info.channel_id.0 as u16, info.connection_id)
1262 };
1263
1264 Self {
1265 open_data: OpenData {
1266 target_vp: request.target_vp,
1267 ring_offset: request.downstream_ring_buffer_page_offset,
1268 ring_gpadl_id: request.ring_buffer_gpadl_id,
1269 user_data: request.user_data,
1270 event_flag,
1271 connection_id,
1272 },
1273 connection_id,
1274 event_flag,
1275 monitor_info: request.target_vp.and(monitor_info),
1277 flags: request.flags.with_unused(0),
1278 reserved_target,
1279 channel_id: info.channel_id,
1280 }
1281 }
1282}
1283
1284#[derive(Debug)]
1286pub enum Action {
1287 Open(OpenParams, VersionInfo),
1288 Close,
1289 Gpadl(GpadlId, u16, Vec<u64>),
1290 TeardownGpadl {
1291 gpadl_id: GpadlId,
1292 post_restore: bool,
1293 },
1294 Modify {
1295 target_vp: u32,
1296 },
1297}
1298
1299static SUPPORTED_VERSIONS: &[Version] = &[
1301 Version::V1,
1302 Version::Win7,
1303 Version::Win8,
1304 Version::Win8_1,
1305 Version::Win10,
1306 Version::Win10Rs3_0,
1307 Version::Win10Rs3_1,
1308 Version::Win10Rs4,
1309 Version::Win10Rs5,
1310 Version::Iron,
1311 Version::Copper,
1312];
1313
1314const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
1317 .with_guest_specified_signal_parameters(true)
1318 .with_channel_interrupt_redirection(true)
1319 .with_modify_connection(true)
1320 .with_client_id(true)
1321 .with_pause_resume(true)
1322 .with_server_specified_monitor_pages(true);
1323
/// Callbacks the server uses to request actions and send messages.
pub trait Notifier: Send {
    /// Requests `action` for the channel identified by `offer_id`.
    fn notify(&mut self, offer_id: OfferId, action: Action);

    /// Hands off an initiate contact request that this server does not handle.
    ///
    /// `initiate_contact` calls this when the request's target VTL differs from the
    /// server's own VTL.
    fn forward_unhandled(&mut self, request: InitiateContactRequest);

    /// Requests a change to the connection state (version, monitor page, interrupt page,
    /// target message VP).
    ///
    /// The server treats an error here as a failed connection attempt when connecting, and
    /// as a fatal condition (`expect`) when resetting state during disconnect.
    fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()>;

    /// Inspects a channel. The default implementation ignores all arguments.
    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
        let _ = (version, offer_id, req);
    }

    /// Sends `message` to `target`.
    // NOTE(review): the meaning of the returned `bool` is not visible in this section
    // (presumably whether the message was sent) — confirm against the implementations.
    #[must_use]
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool;

    /// Notifies of an hvsock connect request.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest);

    /// Called by `complete_reset` once the server has finished resetting.
    fn reset_complete(&mut self);

    /// Notifies that an unload has completed.
    fn unload_complete(&mut self);
}
1358
1359impl Server {
1360 pub fn new(
1362 vtl: Vtl,
1363 child_connection_id: u32,
1364 channel_id_offset: u16,
1365 use_absolute_channel_order: bool,
1366 ) -> Self {
1367 Server {
1368 state: ConnectionState::Disconnected,
1369 channels: ChannelList::new(),
1370 assigned_channels: AssignedChannels::new(vtl, channel_id_offset),
1371 assigned_monitors: AssignedMonitors::new(),
1372 gpadls: Default::default(),
1373 incomplete_gpadls: Default::default(),
1374 child_connection_id,
1375 max_version: None,
1376 delayed_max_version: None,
1377 pending_messages: PendingMessages(VecDeque::new()),
1378 require_server_allocated_mnf: false,
1379 use_absolute_channel_order,
1380 }
1381 }
1382
    /// Pairs the server with a notifier, returning a wrapper that borrows both for `'a`.
    ///
    /// The channel invariants are checked first via `validate` (which only does work in
    /// builds with debug assertions).
    pub fn with_notifier<'a, T: Notifier>(
        &'a mut self,
        notifier: &'a mut T,
    ) -> ServerWithNotifier<'a, T> {
        self.validate();
        ServerWithNotifier {
            inner: self,
            notifier,
        }
    }
1394
    /// Sets whether guest-supplied monitor pages are rejected.
    ///
    /// When set, `initiate_contact` does not use monitor pages provided by the guest.
    pub fn set_require_server_allocated_mnf(&mut self, require: bool) {
        self.require_server_allocated_mnf = require;
    }
1400
1401 fn validate(&self) {
1402 #[cfg(debug_assertions)]
1403 for (_, channel) in self.channels.iter() {
1404 let should_have_info = !channel.state.is_released();
1405 if channel.info.is_some() != should_have_info {
1406 panic!("channel invariant violation: {channel:?}");
1407 }
1408 }
1409 }
1410
1411 pub fn set_compatibility_version(&mut self, version: MaxVersionInfo, delay: bool) {
1413 if delay {
1414 self.delayed_max_version = Some(version)
1415 } else {
1416 tracing::info!(?version, "Limiting VmBus connections to version");
1417 self.max_version = Some(version);
1418 }
1419 }
1420
1421 pub fn channel_gpadls(&self, offer_id: OfferId) -> Vec<RestoredGpadl> {
1422 self.gpadls
1423 .iter()
1424 .filter_map(|(&(gpadl_id, gpadl_offer_id), gpadl)| {
1425 if offer_id != gpadl_offer_id {
1426 return None;
1427 }
1428 let accepted = match gpadl.state {
1429 GpadlState::Offered | GpadlState::OfferedTearingDown => false,
1430 GpadlState::Accepted => true,
1431 GpadlState::InProgress | GpadlState::TearingDown => return None,
1432 };
1433 Some(RestoredGpadl {
1434 request: GpadlRequest {
1435 id: gpadl_id,
1436 count: gpadl.count,
1437 buf: gpadl.buf.clone(),
1438 },
1439 accepted,
1440 })
1441 })
1442 .collect()
1443 }
1444
    /// Returns the version info reported by the current connection state, if any.
    pub fn get_version(&self) -> Option<VersionInfo> {
        self.state.get_version()
    }
1448
1449 pub fn get_restore_open_params(&self, offer_id: OfferId) -> Result<OpenParams, RestoreError> {
1450 let channel = &self.channels[offer_id];
1451
1452 match channel.restore_state {
1454 RestoreState::New => {
1455 return Err(RestoreError::MissingChannel(channel.offer.key()));
1459 }
1460 RestoreState::Restoring => {}
1461 RestoreState::Unmatched => unreachable!(),
1462 RestoreState::Restored => {
1463 return Err(RestoreError::AlreadyRestored(channel.offer.key()));
1464 }
1465 }
1466
1467 let info = channel
1468 .info
1469 .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;
1470
1471 let (request, reserved_state) = match channel.state {
1472 ChannelState::Closed => {
1473 return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
1474 }
1475 ChannelState::Closing { params, .. } | ChannelState::ClosingReopen { params, .. } => {
1476 (params, None)
1477 }
1478 ChannelState::Opening {
1479 request,
1480 reserved_state,
1481 } => (request, reserved_state),
1482 ChannelState::Open {
1483 params,
1484 reserved_state,
1485 ..
1486 } => (params, reserved_state),
1487 ChannelState::ClientReleased | ChannelState::Reoffered => {
1488 return Err(RestoreError::MissingChannel(channel.offer.key()));
1489 }
1490 ChannelState::Revoked
1491 | ChannelState::ClosingClientRelease
1492 | ChannelState::OpeningClientRelease => unreachable!(),
1493 };
1494
1495 Ok(OpenParams::from_request(
1496 &info,
1497 &request,
1498 channel.handled_monitor_info(),
1499 reserved_state.map(|state| state.target),
1500 ))
1501 }
1502
1503 pub fn has_pending_messages(&self) -> bool {
1505 !self.pending_messages.0.is_empty() && !self.state.is_paused()
1506 }
1507
1508 pub fn poll_flush_pending_messages(
1510 &mut self,
1511 mut send: impl FnMut(&OutgoingMessage) -> Poll<()>,
1512 ) -> Poll<()> {
1513 if !self.state.is_paused() {
1514 while let Some(message) = self.pending_messages.0.front() {
1515 ready!(send(message));
1516 self.pending_messages.0.pop_front();
1517 }
1518 }
1519
1520 Poll::Ready(())
1521 }
1522}
1523
1524impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
    /// Finishes restoring a single channel, reconciling its saved state with whether the
    /// caller considers it `open`.
    ///
    /// On success the channel is marked `Restored`, except for a channel with restore state
    /// `New` and `open == false`, which is left untouched.
    ///
    /// # Errors
    ///
    /// * `MissingChannel` if `open` is set but the channel is not being restored, if the
    ///   channel has no info, or if it has been released by the client (or, when `open`, is
    ///   waiting to be reoffered).
    /// * `AlreadyRestored` if the channel was already restored.
    /// * `DuplicateMonitorId` if the channel's monitor ID could not be claimed.
    /// * `MismatchedOpenState` if `open` disagrees with a `Closed` or `Open` saved state.
    ///
    /// # Panics
    ///
    /// Panics if the channel is `Unmatched`, or in the `Revoked`, `ClosingClientRelease` or
    /// `OpeningClientRelease` state; also if a reopen is needed while there is no connection
    /// version.
    pub fn restore_channel(&mut self, offer_id: OfferId, open: bool) -> Result<(), RestoreError> {
        let channel = &mut self.inner.channels[offer_id];

        match channel.restore_state {
            RestoreState::New => {
                // The channel is not being restored; that is only a problem if the caller
                // expects it to be open.
                if open {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                } else {
                    return Ok(());
                }
            }
            RestoreState::Restoring => {}
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        // Claim the channel's monitor ID; failing to claim it is reported as a duplicate.
        if let Some(monitor_info) = channel.handled_monitor_info() {
            if !self
                .inner
                .assigned_monitors
                .claim_monitor(monitor_info.monitor_id)
            {
                return Err(RestoreError::DuplicateMonitorId(monitor_info.monitor_id.0));
            }
        }

        if open {
            match channel.state {
                ChannelState::Closed => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                    // The caller has the channel open but a close was in progress, so
                    // request the close again.
                    self.notifier.notify(offer_id, Action::Close);
                }
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    // The caller already has the channel open, so the pending open is
                    // completed successfully and the result is sent.
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused())
                        .send_open_result(
                            info.channel_id,
                            &request,
                            protocol::STATUS_SUCCESS,
                            MessageTarget::for_offer(offer_id, &reserved_state),
                        );
                    channel.state = ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    };
                }
                ChannelState::Open { .. } => {}
                ChannelState::ClientReleased | ChannelState::Reoffered => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            };
        } else {
            match channel.state {
                ChannelState::Closed => {}
                // A reoffered channel is left in that state.
                ChannelState::Reoffered => {}
                ChannelState::Closing { .. } => {
                    // The caller does not have the channel open, so the close is complete.
                    channel.state = ChannelState::Closed;
                }
                ChannelState::ClosingReopen { request, .. } => {
                    // The close part is complete; issue the queued reopen (without a
                    // reserved target).
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                None,
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                    channel.state = ChannelState::Opening {
                        request,
                        reserved_state: None,
                    };
                }
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    // An open was in progress; request it again. The state stays `Opening`.
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                reserved_state.map(|state| state.target),
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                }
                ChannelState::Open { .. } => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                ChannelState::ClientReleased => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            }
        }

        channel.restore_state = RestoreState::Restored;
        Ok(())
    }
1663
    /// Cleans up every channel that was not fully restored, and re-issues GPADL requests
    /// that were outstanding.
    ///
    /// Per restore state:
    /// * `Restored` channels are left alone.
    /// * `New` channels in the `ClientReleased` state are offered to the guest, but only if
    ///   the connection is established and offers have already been sent.
    /// * `Restoring` channels (matched by an offer but never restored) are revoked and
    ///   marked `Reoffered`.
    /// * `Unmatched` channels (never matched by an offer) are revoked.
    ///
    /// # Panics
    ///
    /// Panics if revoking a channel reports that it should not be retained, or if a GPADL
    /// is in the `OfferedTearingDown` state.
    pub fn revoke_unclaimed_channels(&mut self) {
        for (offer_id, channel) in self.inner.channels.iter_mut() {
            match channel.restore_state {
                RestoreState::Restored => {
                }
                RestoreState::New => {
                    if let ConnectionState::Connected(info) = &self.inner.state {
                        // Only send an offer if the guest has been sent offers and this
                        // channel is not currently held by the guest.
                        if info.offers_sent && matches!(channel.state, ChannelState::ClientReleased)
                        {
                            channel.prepare_channel(
                                offer_id,
                                &mut self.inner.assigned_channels,
                                &mut self.inner.assigned_monitors,
                            );
                            channel.state = ChannelState::Closed;
                            self.inner
                                .pending_messages
                                .sender(self.notifier, self.inner.state.is_paused())
                                .send_offer(channel, info);
                        }
                    }
                }
                RestoreState::Restoring => {
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                    // Override the state set by `revoke` so the channel is reoffered.
                    channel.state = ChannelState::Reoffered;
                }
                RestoreState::Unmatched => {
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                }
            }
        }

        // Re-issue the notifications for GPADLs that were offered or being torn down.
        for (&(gpadl_id, offer_id), gpadl) in self.inner.gpadls.iter_mut() {
            match gpadl.state {
                GpadlState::InProgress | GpadlState::Accepted => {}
                GpadlState::Offered => {
                    self.notifier.notify(
                        offer_id,
                        Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
                    );
                }
                GpadlState::TearingDown => {
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: true,
                        },
                    );
                }
                GpadlState::OfferedTearingDown => unreachable!(),
            }
        }

        // Revoking channels may have completed a pending disconnect.
        self.check_disconnected();
    }
1749
    /// Starts a reset of the server.
    ///
    /// If the disconnect completes immediately, the reset is completed here as well;
    /// otherwise it is completed later via the `ConnectionAction::Reset` next action.
    ///
    /// # Panics
    ///
    /// Panics if a reset is already in progress.
    pub fn reset(&mut self) {
        assert!(!self.is_resetting());
        if self.request_disconnect(ConnectionAction::Reset) {
            self.complete_reset();
        }
    }
1760
1761 fn complete_reset(&mut self) {
1762 for (_, channel) in self.inner.channels.iter_mut() {
1764 channel.restore_state = RestoreState::New;
1765 }
1766 self.inner.pending_messages.0.clear();
1767 self.notifier.reset_complete();
1768 }
1769
    /// Registers a channel offer and returns its offer ID.
    ///
    /// If a channel with the same key already exists, the offer replaces it only when the
    /// existing channel is an unmatched restored channel (which then becomes `Restoring`) or
    /// has been revoked (which then becomes `Reoffered`). Otherwise a new channel is
    /// created, and it is offered to the guest immediately if the connection is established
    /// and offers have already been sent.
    ///
    /// # Errors
    ///
    /// * `AlreadyExists` if a channel with the same key exists and cannot be replaced.
    /// * `MismatchedMonitorId` if a relayed monitor ID differs from the one stored for the
    ///   matched channel.
    /// * `TooManyChannels` if the channel limit has been reached.
    pub fn offer_channel(&mut self, offer: OfferParamsInternal) -> Result<OfferId, OfferError> {
        if let Some((offer_id, channel)) = self.inner.channels.get_by_key_mut(&offer.key()) {
            // Only unmatched or revoked channels may be replaced by a new offer.
            if channel.restore_state != RestoreState::Unmatched
                && !matches!(channel.state, ChannelState::Revoked)
            {
                return Err(OfferError::AlreadyExists(offer.key()));
            }

            let info = channel.info.expect("assigned");
            if channel.restore_state == RestoreState::Unmatched {
                tracing::debug!(
                    offer_id = offer_id.0,
                    key = %channel.offer.key(),
                    "matched channel"
                );

                assert!(!matches!(channel.state, ChannelState::Revoked));
                channel.restore_state = RestoreState::Restoring;

                // NOTE(review): on a monitor ID mismatch this returns after the restore
                // state was already changed to `Restoring`, while the old offer is kept —
                // confirm that is intended.
                if let MnfUsage::Relayed { monitor_id } = offer.use_mnf {
                    if info.monitor_id != Some(MonitorId(monitor_id)) {
                        return Err(OfferError::MismatchedMonitorId(
                            info.monitor_id,
                            MonitorId(monitor_id),
                        ));
                    }
                }
            } else {
                // The existing channel was revoked; remember the new offer for reoffering.
                channel.state = ChannelState::Reoffered;
                tracing::info!(?offer_id, key = %channel.offer.key(), "channel marked for reoffer");
            }

            channel.offer = offer;
            return Ok(offer_id);
        }

        // A brand new channel starts `Closed` (and gets offered below) only if offers have
        // already been sent on an established connection; otherwise it starts released.
        let mut connected_info = None;
        let state = match &self.inner.state {
            ConnectionState::Connected(info) => {
                if info.offers_sent {
                    connected_info = Some(info);
                    ChannelState::Closed
                } else {
                    ChannelState::ClientReleased
                }
            }
            ConnectionState::Connecting { .. }
            | ConnectionState::Disconnecting { .. }
            | ConnectionState::Disconnected => ChannelState::ClientReleased,
        };

        if self.inner.channels.len() >= self.inner.assigned_channels.allowable_channel_count() {
            return Err(OfferError::TooManyChannels);
        }

        // Capture values for logging before `offer` is moved into the channel.
        let key = offer.key();
        let confidential_ring_buffer = offer.flags.confidential_ring_buffer();
        let confidential_external_memory = offer.flags.confidential_external_memory();
        let channel = Channel {
            info: None,
            offer,
            state,
            restore_state: RestoreState::New,
        };

        let offer_id = self.inner.channels.offer(channel);
        if let Some(info) = connected_info {
            let channel = &mut self.inner.channels[offer_id];
            channel.prepare_channel(
                offer_id,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
            );

            self.inner
                .pending_messages
                .sender(self.notifier, self.inner.state.is_paused())
                .send_offer(channel, info);
        }

        tracing::info!(?offer_id, %key, confidential_ring_buffer, confidential_external_memory, "new channel");
        Ok(offer_id)
    }
1867
    /// Revokes the channel identified by `offer_id`.
    ///
    /// The channel is removed from the channel list when `revoke` reports that it does not
    /// need to be retained. Afterwards, any pending disconnect is re-evaluated.
    pub fn revoke_channel(&mut self, offer_id: OfferId) {
        let channel = &mut self.inner.channels[offer_id];
        let retain = revoke(
            self.inner
                .pending_messages
                .sender(self.notifier, self.inner.state.is_paused()),
            offer_id,
            channel,
            &mut self.inner.gpadls,
        );
        if !retain {
            self.inner.channels.remove(offer_id);
        }

        self.check_disconnected();
    }
1885
    /// Completes a channel open with `result`, where a non-negative value means success.
    ///
    /// * In the `Opening` state the open result is sent, and the channel becomes `Open` on
    ///   success or `Closed` on failure.
    /// * In the `OpeningClientRelease` state a successful open is immediately followed by a
    ///   close request; a failed open releases the channel and re-evaluates any pending
    ///   disconnect.
    /// * In any other state the completion is logged as invalid and ignored.
    pub fn open_complete(&mut self, offer_id: OfferId, result: i32) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::debug!(offer_id = offer_id.0, key = %channel.offer.key(), result, "open complete");

        match channel.state {
            ChannelState::Opening {
                request,
                reserved_state,
            } => {
                let channel_id = channel.info.expect("assigned").channel_id;
                if result >= 0 {
                    tracelimit::info_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "opened channel"
                    );
                } else {
                    tracelimit::error_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "failed to open channel"
                    );
                }

                // The result is reported regardless of whether the open succeeded.
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_open_result(
                        channel_id,
                        &request,
                        result,
                        MessageTarget::for_offer(offer_id, &reserved_state),
                    );
                channel.state = if result >= 0 {
                    ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    }
                } else {
                    ChannelState::Closed
                };
            }
            ChannelState::OpeningClientRelease => {
                tracing::info!(
                    offer_id = offer_id.0,
                    key = %channel.offer.key(),
                    result,
                    "opened channel (client released)"
                );

                if result >= 0 {
                    // The open succeeded after the client released the channel, so it has
                    // to be closed again.
                    channel.state = ChannelState::ClosingClientRelease;
                    self.notifier.notify(offer_id, Action::Close);
                } else {
                    channel.state = ChannelState::ClientReleased;
                    self.check_disconnected();
                }
            }

            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::ClosingClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid open complete")
            }
        }
    }
1964
1965 fn are_channels_reset(&self, include_reserved: bool) -> bool {
1968 self.inner.gpadls.keys().all(|(_, offer_id)| {
1969 !include_reserved && self.inner.channels[*offer_id].state.is_reserved()
1970 }) && self.inner.channels.iter().all(|(_, channel)| {
1971 matches!(channel.state, ChannelState::ClientReleased)
1972 || (!include_reserved && channel.state.is_reserved())
1973 })
1974 }
1975
    /// Advances a pending disconnect once all channels are reset.
    ///
    /// This only acts in the `Disconnecting` state before the modify request has been sent;
    /// reserved channels must also be reset if the pending action is a reset. In every other
    /// state this does nothing.
    fn check_disconnected(&mut self) {
        match self.inner.state {
            ConnectionState::Disconnecting {
                next_action,
                modify_sent: false,
            } => {
                if self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)) {
                    self.notify_disconnect(next_action);
                }
            }
            ConnectionState::Disconnecting {
                modify_sent: true, ..
            }
            | ConnectionState::Disconnected
            | ConnectionState::Connected { .. }
            | ConnectionState::Connecting { .. } => (),
        }
    }
1997
    /// Moves to the `Disconnecting` state with `modify_sent` set, and asks the notifier to
    /// reset the monitor and interrupt pages.
    ///
    /// `next_action` is preserved in the new state so it can run once the disconnect
    /// completes.
    ///
    /// # Panics
    ///
    /// Panics if the notifier fails the request. In builds with debug assertions, also
    /// panics if the channels are not reset.
    fn notify_disconnect(&mut self, next_action: ConnectionAction) {
        debug_assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
        self.inner.state = ConnectionState::Disconnecting {
            next_action,
            modify_sent: true,
        };

        self.notifier
            .modify_connection(ModifyConnectionRequest {
                monitor_page: Update::Reset,
                interrupt_page: Update::Reset,
                ..Default::default()
            })
            .expect("resetting state should not fail");
    }
2016
2017 fn is_resetting(&self) -> bool {
2020 matches!(
2021 &self.inner.state,
2022 ConnectionState::Connecting {
2023 next_action: ConnectionAction::Reset,
2024 ..
2025 } | ConnectionState::Disconnecting {
2026 next_action: ConnectionAction::Reset,
2027 ..
2028 }
2029 )
2030 }
2031
    /// Completes a channel close.
    ///
    /// * `Closing` with reserved state: the channel becomes `Closed`. While connected, a
    ///   close-reserved-channel response is sent to the reserved target; otherwise the
    ///   channel is released, and removed if the release reports it should be.
    /// * `Closing` without reserved state: the channel becomes `Closed`.
    /// * `ClosingClientRelease`: the channel becomes `ClientReleased` and any pending
    ///   disconnect is re-evaluated.
    /// * `ClosingReopen`: the channel becomes `Closed` and is then opened again with the
    ///   saved request.
    /// * Any other state: the completion is logged as invalid and ignored.
    pub fn close_complete(&mut self, offer_id: OfferId) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::info!(offer_id = offer_id.0, key = %channel.offer.key(), "closed channel");
        match channel.state {
            ChannelState::Closing {
                reserved_state: Some(reserved_state),
                ..
            } => {
                channel.state = ChannelState::Closed;
                if matches!(self.inner.state, ConnectionState::Connected { .. }) {
                    let channel_id = channel.info.expect("assigned").channel_id;
                    self.send_close_reserved_channel_response(
                        channel_id,
                        offer_id,
                        reserved_state.target,
                    );
                } else {
                    // Not connected: release the reserved channel instead of responding.
                    if Self::client_release_channel(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                        &mut self.inner.incomplete_gpadls,
                        &mut self.inner.assigned_channels,
                        &mut self.inner.assigned_monitors,
                        None,
                        false,
                    ) {
                        self.inner.channels.remove(offer_id);
                    }
                }
            }
            ChannelState::Closing { .. } => {
                channel.state = ChannelState::Closed;
            }
            ChannelState::ClosingClientRelease => {
                channel.state = ChannelState::ClientReleased;
                self.check_disconnected();
            }
            ChannelState::ClosingReopen { request, .. } => {
                channel.state = ChannelState::Closed;
                self.open_channel(offer_id, &request, None);
            }

            ChannelState::Closed
            | ChannelState::ClientReleased
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::OpeningClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid close complete")
            }
        }
    }
2092
2093 fn send_close_reserved_channel_response(
2094 &mut self,
2095 channel_id: ChannelId,
2096 offer_id: OfferId,
2097 target: ConnectionTarget,
2098 ) {
2099 self.sender().send_message_with_target(
2100 &protocol::CloseReservedChannelResponse { channel_id },
2101 MessageTarget::ReservedChannel(offer_id, target),
2102 );
2103 }
2104
2105 fn handle_initiate_contact(
2108 &mut self,
2109 input: &protocol::InitiateContact2,
2110 message: &SynicMessage,
2111 includes_client_id: bool,
2112 ) -> Result<(), ChannelError> {
2113 let target_info =
2114 protocol::TargetInfo::from(input.initiate_contact.interrupt_page_or_target_info);
2115
2116 let target_sint = if message.multiclient
2117 && input.initiate_contact.version_requested >= Version::Win10Rs3_1 as u32
2118 {
2119 target_info.sint()
2120 } else {
2121 VMBUS_SINT
2122 };
2123
2124 let target_vtl = if message.multiclient
2125 && input.initiate_contact.version_requested >= Version::Win10Rs4 as u32
2126 {
2127 target_info.vtl()
2128 } else {
2129 0
2130 };
2131
2132 let feature_flags = if input.initiate_contact.version_requested >= Version::Copper as u32 {
2133 target_info.feature_flags()
2134 } else {
2135 0
2136 };
2137
2138 let target_message_vp =
2143 if input.initiate_contact.version_requested >= Version::Win8_1 as u32 {
2144 input.initiate_contact.target_message_vp
2145 } else {
2146 0
2147 };
2148
2149 let interrupt_page = (input.initiate_contact.version_requested < Version::Win8 as u32
2156 && input.initiate_contact.interrupt_page_or_target_info != 0)
2157 .then_some(input.initiate_contact.interrupt_page_or_target_info);
2158
2159 let monitor_page = if (input.initiate_contact.parent_to_child_monitor_page_gpa == 0)
2162 != (input.initiate_contact.child_to_parent_monitor_page_gpa == 0)
2163 {
2164 MonitorPageRequest::Invalid
2165 } else if input.initiate_contact.parent_to_child_monitor_page_gpa != 0 {
2166 MonitorPageRequest::Some(MonitorPageGpas {
2167 parent_to_child: input.initiate_contact.parent_to_child_monitor_page_gpa,
2168 child_to_parent: input.initiate_contact.child_to_parent_monitor_page_gpa,
2169 })
2170 } else {
2171 MonitorPageRequest::None
2172 };
2173
2174 let client_id = if FeatureFlags::from(feature_flags).client_id() {
2177 if includes_client_id {
2178 input.client_id
2179 } else {
2180 return Err(ChannelError::ParseError(
2181 protocol::ParseError::MessageTooSmall(Some(
2182 protocol::MessageType::INITIATE_CONTACT,
2183 )),
2184 ));
2185 }
2186 } else {
2187 Guid::ZERO
2188 };
2189
2190 let request = InitiateContactRequest {
2191 version_requested: input.initiate_contact.version_requested,
2192 target_message_vp,
2193 monitor_page,
2194 target_sint,
2195 target_vtl,
2196 feature_flags,
2197 interrupt_page,
2198 client_id,
2199 trusted: message.trusted,
2200 };
2201 self.initiate_contact(request);
2202 Ok(())
2203 }
2204
    /// Processes an initiate contact request from the guest.
    ///
    /// Steps, in order:
    /// 1. A request for a different VTL is forwarded to the notifier and not handled here.
    /// 2. A request for a SINT other than `VMBUS_SINT` gets an unsupported version response
    ///    sent to the requested VP and SINT.
    /// 3. Any existing connection is disconnected first. If that cannot complete
    ///    immediately, the request is queued as the next action and this returns.
    /// 4. The version is negotiated; unsupported versions get an unsupported response.
    /// 5. The monitor page request is validated; an invalid one gets a failure response.
    /// 6. The state becomes `Connecting` and the notifier is asked to modify the
    ///    connection. On failure the state reverts to `Disconnected` and a failure response
    ///    is sent.
    pub fn initiate_contact(&mut self, request: InitiateContactRequest) {
        let vtl = self.inner.assigned_channels.vtl as u8;
        if request.target_vtl != vtl {
            self.notifier.forward_unhandled(request);
            return;
        }

        if request.target_sint != VMBUS_SINT {
            tracelimit::warn_ratelimited!(
                target_vtl = request.target_vtl,
                target_sint = request.target_sint,
                version = request.version_requested,
                "unsupported multiclient request",
            );

            // Reply on the SINT the guest asked for rather than the default target.
            self.send_version_response_with_target(
                None,
                MessageTarget::Custom(ConnectionTarget {
                    vp: request.target_message_vp,
                    sint: request.target_sint,
                }),
            );

            return;
        }

        // If the disconnect is still pending, the request has been stored as the next
        // action and is replayed via `do_next_action`.
        if !self.request_disconnect(ConnectionAction::Reconnect {
            initiate_contact: request,
        }) {
            return;
        }

        let Some(version) = self.check_version_supported(&request) else {
            tracelimit::warn_ratelimited!(
                vtl,
                version = request.version_requested,
                client_id = ?request.client_id,
                "Guest requested unsupported version"
            );

            self.send_version_response(None);
            return;
        };

        tracelimit::info_ratelimited!(
            vtl,
            ?version,
            client_id = ?request.client_id,
            trusted = request.trusted,
            "Guest negotiated version"
        );

        let monitor_page = match request.monitor_page {
            MonitorPageRequest::Some(mp) => {
                if self.inner.require_server_allocated_mnf {
                    // Guest-supplied pages are dropped; warn if the guest cannot use
                    // server-specified pages instead.
                    if !version.feature_flags.server_specified_monitor_pages() {
                        tracelimit::warn_ratelimited!(
                            "guest-supplied monitor pages not supported; MNF will be disabled"
                        );
                    }

                    None
                } else {
                    Some(mp)
                }
            }
            MonitorPageRequest::None => None,
            MonitorPageRequest::Invalid => {
                self.send_version_response(Some(VersionResponseData::new(
                    version,
                    protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
                )));

                return;
            }
        };

        self.inner.state = ConnectionState::Connecting {
            info: ConnectionInfo {
                version,
                trusted: request.trusted,
                interrupt_page: request.interrupt_page,
                monitor_page: monitor_page.map(MonitorPageGpaInfo::from_guest_gpas),
                target_message_vp: request.target_message_vp,
                modifying: false,
                offers_sent: false,
                client_id: request.client_id,
                paused: false,
            },
            next_action: ConnectionAction::None,
        };

        if let Err(err) = self.notifier.modify_connection(ModifyConnectionRequest {
            version: Some(version),
            monitor_page: monitor_page.into(),
            interrupt_page: request.interrupt_page.into(),
            target_message_vp: Some(request.target_message_vp),
            notify_relay: true,
        }) {
            tracelimit::error_ratelimited!(?err, "server failed to change state");
            self.inner.state = ConnectionState::Disconnected;
            self.send_version_response(Some(VersionResponseData::new(
                version,
                protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
            )));
        }
    }
2322
    /// Completes an initiate contact request using the response to the modify connection
    /// request issued by `initiate_contact`.
    ///
    /// On success the state becomes `Connected` and a successful version response is sent,
    /// including server-specified monitor pages if the response carried any. If another
    /// action was queued while connecting, a disconnect for it is requested right away. On
    /// failure, or if the version is unsupported, a failure or unsupported response is sent
    /// and the state becomes `Disconnected`.
    ///
    /// # Panics
    ///
    /// Panics if the state is not `Connecting`, if the response is `Modified`, or if a
    /// server-specified monitor page is returned without the matching feature flag having
    /// been negotiated.
    pub(crate) fn complete_initiate_contact(&mut self, response: ModifyConnectionResponse) {
        let ConnectionState::Connecting {
            mut info,
            next_action,
        } = self.inner.state
        else {
            panic!("Invalid state for completing InitiateContact.");
        };

        // Flags that are kept even if the response's feature flags do not include them.
        const LOCAL_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
            .with_client_id(true)
            .with_confidential_channels(true);

        let (relay_feature_flags, server_specified_monitor_page) = match response {
            ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                feature_flags,
                server_specified_monitor_page,
            ) => (feature_flags, server_specified_monitor_page),
            // Supported version, but any state other than `SUCCESSFUL` is a failure.
            ModifyConnectionResponse::Supported(
                connection_state,
                feature_flags,
                server_specified_monitor_page,
            ) => {
                tracelimit::error_ratelimited!(
                    ?connection_state,
                    "initiate contact failed because relay request failed"
                );

                // The failure response still reports feature flags, so restrict them to
                // what is supported.
                info.version.feature_flags &= (feature_flags | LOCAL_FEATURE_FLAGS)
                    .with_server_specified_monitor_pages(server_specified_monitor_page.is_some());

                self.send_version_response(Some(VersionResponseData::new(
                    info.version,
                    connection_state,
                )));
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            ModifyConnectionResponse::Unsupported => {
                self.send_version_response(None);
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            ModifyConnectionResponse::Modified(_) => {
                panic!("Invalid response for completing InitiateContact.");
            }
        };

        assert!(
            info.version.feature_flags.server_specified_monitor_pages()
                || server_specified_monitor_page.is_none()
        );

        info.version.feature_flags &= relay_feature_flags | LOCAL_FEATURE_FLAGS;

        // The server-specified monitor pages flag is reported exactly when a page was
        // actually provided.
        if let Some(gpas) = server_specified_monitor_page {
            info.monitor_page = Some(MonitorPageGpaInfo::from_server_gpas(gpas));
            info.version
                .feature_flags
                .set_server_specified_monitor_pages(true);
        } else {
            info.version
                .feature_flags
                .set_server_specified_monitor_pages(false);
        }

        let version = info.version;
        self.inner.state = ConnectionState::Connected(info);

        self.send_version_response(Some(
            VersionResponseData::new(version, protocol::ConnectionState::SUCCESSFUL)
                .with_monitor_pages(server_specified_monitor_page),
        ));
        // Run any action that was queued while the connection was being established.
        if !matches!(next_action, ConnectionAction::None) && self.request_disconnect(next_action) {
            self.do_next_action(next_action);
        }
    }
2417
2418 fn check_version_supported(&self, request: &InitiateContactRequest) -> Option<VersionInfo> {
2420 let version = SUPPORTED_VERSIONS
2421 .iter()
2422 .find(|v| request.version_requested == **v as u32)
2423 .copied()?;
2424
2425 if let Some(max_version) = self.inner.max_version {
2427 if version as u32 > max_version.version {
2428 return None;
2429 }
2430 }
2431
2432 let supported_flags = if version >= Version::Copper {
2433 let max_supported_flags =
2435 SUPPORTED_FEATURE_FLAGS.with_confidential_channels(request.trusted);
2436
2437 if let Some(max_version) = self.inner.max_version {
2439 max_supported_flags & max_version.feature_flags
2440 } else {
2441 max_supported_flags
2442 }
2443 } else {
2444 FeatureFlags::new()
2445 };
2446
2447 let feature_flags = supported_flags & request.feature_flags.into();
2448
2449 assert!(version >= Version::Copper || feature_flags == FeatureFlags::new());
2450 if feature_flags.into_bits() != request.feature_flags {
2451 tracelimit::info_ratelimited!(
2454 supported = feature_flags.into_bits(),
2455 requested = request.feature_flags,
2456 "guest requested unsupported feature flags."
2457 );
2458 }
2459
2460 Some(VersionInfo {
2461 version,
2462 feature_flags,
2463 })
2464 }
2465
    /// Sends a version response to the default message target.
    ///
    /// `None` produces a response that reports the version as unsupported.
    fn send_version_response(&mut self, data: Option<VersionResponseData>) {
        self.send_version_response_with_target(data, MessageTarget::Default);
    }
2469
    /// Builds and sends a version response to `target`.
    ///
    /// The response is left zeroed (version not supported) when `data` is `None`, or when
    /// it reports a failure for a version older than `Win8`. Otherwise the response is
    /// filled in, and its size depends on the version and data:
    /// * before `Copper`: the basic version response;
    /// * `Copper` and above: the extended response carrying the supported feature flags;
    /// * `Copper` and above with monitor pages in `data`: the further extended response that
    ///   also carries the monitor page GPAs.
    ///
    /// # Panics
    ///
    /// Panics if `data` contains monitor pages while the server-specified monitor pages
    /// feature flag is not set.
    fn send_version_response_with_target(
        &mut self,
        data: Option<VersionResponseData>,
        target: MessageTarget,
    ) {
        enum VersionResponseType {
            PreCopper,
            Copper,
            CopperWithServerMnf,
        }

        // The three response layouts are nested, so a single zeroed instance of the largest
        // is filled in and the appropriate prefix is sent.
        let mut response_copper_with_mnf = protocol::VersionResponse3::new_zeroed();
        let response_copper = &mut response_copper_with_mnf.version_response2;
        let response = &mut response_copper.version_response;
        let mut response_type = VersionResponseType::PreCopper;
        if let Some(data) = data {
            if data.state == protocol::ConnectionState::SUCCESSFUL
                || data.version.version >= Version::Win8
            {
                response.version_supported = 1;
                response.connection_state = data.state;
                // From `Win10Rs3_1` on this field carries the child connection ID instead
                // of the selected version.
                response.selected_version_or_connection_id =
                    if data.version.version >= Version::Win10Rs3_1 {
                        self.inner.child_connection_id
                    } else {
                        data.version.version as u32
                    };

                if data.version.version >= Version::Copper {
                    response_copper.supported_features = data.version.feature_flags.into();
                    response_type = VersionResponseType::Copper;
                    if let Some(monitor_page) = data.monitor_pages {
                        assert!(data.version.feature_flags.server_specified_monitor_pages());
                        response_copper_with_mnf.child_to_parent_monitor_page_gpa =
                            monitor_page.child_to_parent;
                        response_copper_with_mnf.parent_to_child_monitor_page_gpa =
                            monitor_page.parent_to_child;
                        response_type = VersionResponseType::CopperWithServerMnf;
                    }
                }
            }
        }

        match response_type {
            VersionResponseType::PreCopper => {
                self.sender().send_message_with_target(response, target)
            }
            VersionResponseType::Copper => self
                .sender()
                .send_message_with_target(response_copper, target),
            VersionResponseType::CopperWithServerMnf => self
                .sender()
                .send_message_with_target(&response_copper_with_mnf, target),
        }
    }
2528
    /// Starts disconnecting from the guest, with `new_action` to run once disconnected.
    ///
    /// All channels are released first, except that reserved channels are only released
    /// when `new_action` is a reset. The connection state is then updated:
    /// * `Disconnected`: for a reset with channels still in use, moves to `Disconnecting`;
    ///   otherwise stays `Disconnected`.
    /// * `Connected`: notifies the disconnect if all channels are already reset, otherwise
    ///   moves to `Disconnecting` to wait for them.
    /// * `Connecting` / `Disconnecting`: only replaces the queued next action.
    ///
    /// Returns `true` if the state is `Disconnected` on return, in which case the caller is
    /// responsible for performing `new_action`.
    ///
    /// # Panics
    ///
    /// Panics if a reset is already in progress, or if the server is disconnected, the
    /// action is not a reset, and some channel is not reset.
    fn request_disconnect(&mut self, new_action: ConnectionAction) -> bool {
        assert!(!self.is_resetting());

        let gpadls = &mut self.inner.gpadls;
        let vm_reset = matches!(new_action, ConnectionAction::Reset);
        // A channel is dropped from the list only if releasing it reports that it should
        // be removed.
        self.inner.channels.retain(|offer_id, channel| {
            (!vm_reset && channel.state.is_reserved())
                || !Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    gpadls,
                    &mut self.inner.incomplete_gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    None,
                    vm_reset,
                )
        });

        match &mut self.inner.state {
            ConnectionState::Disconnected => {
                if vm_reset {
                    if !self.are_channels_reset(true) {
                        self.inner.state = ConnectionState::Disconnecting {
                            next_action: ConnectionAction::Reset,
                            modify_sent: false,
                        };
                    }
                } else {
                    assert!(self.are_channels_reset(false));
                }
            }

            ConnectionState::Connected { .. } => {
                if self.are_channels_reset(vm_reset) {
                    self.notify_disconnect(new_action);
                } else {
                    self.inner.state = ConnectionState::Disconnecting {
                        next_action: new_action,
                        modify_sent: false,
                    };
                }
            }

            ConnectionState::Connecting { next_action, .. }
            | ConnectionState::Disconnecting { next_action, .. } => {
                *next_action = new_action;
            }
        }

        matches!(self.inner.state, ConnectionState::Disconnected)
    }
2592
2593 pub(crate) fn complete_disconnect(&mut self) {
2594 if let ConnectionState::Disconnecting {
2595 next_action,
2596 modify_sent,
2597 } = std::mem::replace(&mut self.inner.state, ConnectionState::Disconnected)
2598 {
2599 assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
2600 if !modify_sent {
2601 tracelimit::warn_ratelimited!("unexpected modify response");
2602 }
2603
2604 self.inner.state = ConnectionState::Disconnected;
2605 self.do_next_action(next_action);
2606 } else {
2607 unreachable!("not ready for disconnect");
2608 }
2609 }
2610
2611 fn do_next_action(&mut self, action: ConnectionAction) {
2612 match action {
2613 ConnectionAction::None => {}
2614 ConnectionAction::Reset => {
2615 self.complete_reset();
2616 }
2617 ConnectionAction::SendUnloadComplete => {
2618 self.complete_unload();
2619 }
2620 ConnectionAction::Reconnect { initiate_contact } => {
2621 self.initiate_contact(initiate_contact);
2622 }
2623 ConnectionAction::SendFailedVersionResponse => {
2624 self.send_version_response(None);
2627 }
2628 }
2629 }
2630
2631 fn handle_unload(&mut self) {
2633 tracing::debug!(
2634 vtl = self.inner.assigned_channels.vtl as u8,
2635 state = ?self.inner.state,
2636 "VmBus received unload request from guest",
2637 );
2638
2639 if self.request_disconnect(ConnectionAction::SendUnloadComplete) {
2640 self.complete_unload();
2641 }
2642 }
2643
2644 fn complete_unload(&mut self) {
2645 self.notifier.unload_complete();
2646 if let Some(version) = self.inner.delayed_max_version.take() {
2647 self.inner.set_compatibility_version(version, false);
2648 }
2649
2650 self.sender().send_message(&protocol::UnloadComplete {});
2651 tracelimit::info_ratelimited!("Vmbus disconnected");
2652 }
2653
2654 fn handle_request_offers(&mut self) -> Result<(), ChannelError> {
2656 let ConnectionState::Connected(info) = &mut self.inner.state else {
2657 unreachable!(
2658 "in unexpected state {:?}, should be prevented by Message::parse()",
2659 self.inner.state
2660 );
2661 };
2662
2663 if info.offers_sent {
2664 return Err(ChannelError::OffersAlreadySent);
2665 }
2666
2667 info.offers_sent = true;
2668
2669 let mut sorted_channels: Vec<_> = self
2672 .inner
2673 .channels
2674 .iter_mut()
2675 .filter(|(_, channel)| !channel.state.is_reserved())
2676 .collect();
2677
2678 if self.inner.use_absolute_channel_order {
2679 sorted_channels.sort_unstable_by_key(|(_, channel)| {
2680 (
2681 channel.offer.offer_order.unwrap_or(u64::MAX),
2682 channel.offer.interface_id,
2683 channel.offer.instance_id,
2684 )
2685 });
2686 } else {
2687 sorted_channels.sort_unstable_by_key(|(_, channel)| {
2688 (
2689 channel.offer.interface_id,
2690 channel.offer.offer_order.unwrap_or(u64::MAX),
2691 channel.offer.instance_id,
2692 )
2693 });
2694 }
2695
2696 for (offer_id, channel) in sorted_channels {
2697 assert!(matches!(channel.state, ChannelState::ClientReleased));
2698
2699 channel.prepare_channel(
2700 offer_id,
2701 &mut self.inner.assigned_channels,
2702 &mut self.inner.assigned_monitors,
2703 );
2704
2705 channel.state = ChannelState::Closed;
2706 self.inner
2707 .pending_messages
2708 .sender(self.notifier, info.paused)
2709 .send_offer(channel, info);
2710 }
2711 self.sender().send_message(&protocol::AllOffersDelivered {});
2712
2713 Ok(())
2714 }
2715
2716 #[must_use]
2719 fn gpadl_updated(
2720 mut sender: MessageSender<'_, N>,
2721 offer_id: OfferId,
2722 channel: &Channel,
2723 gpadl_id: GpadlId,
2724 gpadl: &Gpadl,
2725 ) -> bool {
2726 if channel.state.is_revoked() {
2727 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
2728 sender.send_gpadl_created(channel_id, gpadl_id, protocol::STATUS_UNSUCCESSFUL);
2729 false
2730 } else {
2731 sender.notifier.notify(
2733 offer_id,
2734 Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
2735 );
2736 true
2737 }
2738 }
2739
2740 fn handle_gpadl_header_core(
2742 &mut self,
2743 input: &protocol::GpadlHeader,
2744 range: &[u8],
2745 ) -> Result<(), ChannelError> {
2746 let (offer_id, channel) = self
2748 .inner
2749 .channels
2750 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
2751
2752 if channel.state.is_reserved() {
2755 return Err(ChannelError::ChannelReserved);
2756 }
2757
2758 let mut gpadl = Gpadl::new(input.count, input.len as usize / 8);
2760 let done = gpadl.append(range)?;
2761
2762 let gpadl = match self.inner.gpadls.entry((input.gpadl_id, offer_id)) {
2764 Entry::Vacant(entry) => entry.insert(gpadl),
2765 Entry::Occupied(_) => return Err(ChannelError::DuplicateGpadlId),
2766 };
2767
2768 if !done {
2774 match self.inner.incomplete_gpadls.entry(input.gpadl_id) {
2775 Entry::Vacant(entry) => {
2776 entry.insert(offer_id);
2777 }
2778 Entry::Occupied(_) => {
2779 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2780 tracelimit::error_ratelimited!(
2781 channel_id = ?input.channel_id,
2782 key = %channel.offer.key(),
2783 gpadl_id = ?input.gpadl_id,
2784 "duplicate in-progress gpadl ID",
2785 );
2786 return Err(ChannelError::DuplicateGpadlId);
2787 }
2788 }
2789 }
2790
2791 if done
2792 && !Self::gpadl_updated(
2793 self.inner
2794 .pending_messages
2795 .sender(self.notifier, self.inner.state.is_paused()),
2796 offer_id,
2797 channel,
2798 input.gpadl_id,
2799 gpadl,
2800 )
2801 {
2802 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2803 }
2804 Ok(())
2805 }
2806
2807 fn handle_gpadl_header(&mut self, input: &protocol::GpadlHeader, range: &[u8]) {
2809 if let Err(err) = self.handle_gpadl_header_core(input, range) {
2810 tracelimit::warn_ratelimited!(
2811 err = &err as &dyn std::error::Error,
2812 channel_id = ?input.channel_id,
2813 key = %self.inner.channels.get_by_channel_id(&self.inner.assigned_channels, input.channel_id).map(|(_, c)| c.offer.key()).unwrap_or_default(),
2814 gpadl_id = ?input.gpadl_id,
2815 "error handling gpadl header"
2816 );
2817
2818 self.sender().send_gpadl_created(
2820 input.channel_id,
2821 input.gpadl_id,
2822 protocol::STATUS_UNSUCCESSFUL,
2823 );
2824 }
2825 }
2826
2827 fn handle_gpadl_body(
2833 &mut self,
2834 input: &protocol::GpadlBody,
2835 range: &[u8],
2836 ) -> Result<(), ChannelError> {
2837 let &offer_id = self
2841 .inner
2842 .incomplete_gpadls
2843 .get(&input.gpadl_id)
2844 .ok_or(ChannelError::UnknownGpadlId)?;
2845 let gpadl = self
2846 .inner
2847 .gpadls
2848 .get_mut(&(input.gpadl_id, offer_id))
2849 .ok_or(ChannelError::UnknownGpadlId)?;
2850 let channel = &mut self.inner.channels[offer_id];
2851
2852 match gpadl.append(range) {
2853 Ok(done) => {
2854 if done {
2855 self.inner.incomplete_gpadls.remove(&input.gpadl_id);
2856 if !Self::gpadl_updated(
2857 self.inner
2858 .pending_messages
2859 .sender(self.notifier, self.inner.state.is_paused()),
2860 offer_id,
2861 channel,
2862 input.gpadl_id,
2863 gpadl,
2864 ) {
2865 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2866 }
2867 }
2868 }
2869 Err(err) => {
2870 self.inner.incomplete_gpadls.remove(&input.gpadl_id);
2871 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2872 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
2873 tracelimit::warn_ratelimited!(
2874 err = &err as &dyn std::error::Error,
2875 channel_id = channel_id.0,
2876 key = %channel.offer.key(),
2877 gpadl_id = input.gpadl_id.0,
2878 "error handling gpadl body"
2879 );
2880 self.sender().send_gpadl_created(
2881 channel_id,
2882 input.gpadl_id,
2883 protocol::STATUS_UNSUCCESSFUL,
2884 );
2885 }
2886 }
2887
2888 Ok(())
2889 }
2890
2891 fn handle_gpadl_teardown(
2893 &mut self,
2894 input: &protocol::GpadlTeardown,
2895 ) -> Result<(), ChannelError> {
2896 let (offer_id, channel) = self
2897 .inner
2898 .channels
2899 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
2900
2901 tracing::debug!(
2902 channel_id = input.channel_id.0,
2903 key = %channel.offer.key(),
2904 gpadl_id = input.gpadl_id.0,
2905 "Received GPADL teardown request"
2906 );
2907
2908 let gpadl = self
2909 .inner
2910 .gpadls
2911 .get_mut(&(input.gpadl_id, offer_id))
2912 .ok_or(ChannelError::UnknownGpadlId)?;
2913
2914 match gpadl.state {
2915 GpadlState::InProgress
2916 | GpadlState::Offered
2917 | GpadlState::OfferedTearingDown
2918 | GpadlState::TearingDown => {
2919 return Err(ChannelError::InvalidGpadlState);
2920 }
2921 GpadlState::Accepted => {
2922 if channel.info.as_ref().map(|info| info.channel_id) != Some(input.channel_id) {
2923 return Err(ChannelError::WrongGpadlChannelId);
2924 }
2925
2926 if channel.state.is_reserved() {
2930 return Err(ChannelError::ChannelReserved);
2931 }
2932
2933 if channel.state.is_revoked() {
2934 tracing::trace!(
2935 channel_id = input.channel_id.0,
2936 key = %channel.offer.key(),
2937 gpadl_id = input.gpadl_id.0,
2938 "Gpadl teardown for revoked channel"
2939 );
2940
2941 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2942 self.sender().send_gpadl_torndown(input.gpadl_id);
2943 } else {
2944 gpadl.state = GpadlState::TearingDown;
2945 self.notifier.notify(
2946 offer_id,
2947 Action::TeardownGpadl {
2948 gpadl_id: input.gpadl_id,
2949 post_restore: false,
2950 },
2951 );
2952 }
2953 }
2954 }
2955 Ok(())
2956 }
2957
2958 fn open_channel(
2961 &mut self,
2962 offer_id: OfferId,
2963 input: &OpenRequest,
2964 reserved_state: Option<ReservedState>,
2965 ) {
2966 let channel = &mut self.inner.channels[offer_id];
2967 assert!(matches!(channel.state, ChannelState::Closed));
2968
2969 channel.state = ChannelState::Opening {
2970 request: *input,
2971 reserved_state,
2972 };
2973
2974 let info = channel.info.as_ref().expect("assigned");
2977 self.notifier.notify(
2978 offer_id,
2979 Action::Open(
2980 OpenParams::from_request(
2981 info,
2982 input,
2983 channel.handled_monitor_info(),
2984 reserved_state.map(|state| state.target),
2985 ),
2986 self.inner.state.get_version().expect("must be connected"),
2987 ),
2988 );
2989 }
2990
2991 fn handle_open_channel(&mut self, input: &protocol::OpenChannel2) -> Result<(), ChannelError> {
2993 let (offer_id, channel) = self
2994 .inner
2995 .channels
2996 .get_by_channel_id_mut(&self.inner.assigned_channels, input.open_channel.channel_id)?;
2997
2998 let guest_specified_interrupt_info = self
2999 .inner
3000 .state
3001 .check_feature_flags(|ff| ff.guest_specified_signal_parameters())
3002 .then_some(SignalInfo {
3003 event_flag: input.event_flag,
3004 connection_id: input.connection_id,
3005 });
3006
3007 let flags = if self
3008 .inner
3009 .state
3010 .check_feature_flags(|ff| ff.channel_interrupt_redirection())
3011 {
3012 input.flags
3013 } else {
3014 Default::default()
3015 };
3016
3017 let request = OpenRequest {
3018 open_id: input.open_channel.open_id,
3019 ring_buffer_gpadl_id: input.open_channel.ring_buffer_gpadl_id,
3020 target_vp: protocol::vp_index_if_enabled(input.open_channel.target_vp),
3021 downstream_ring_buffer_page_offset: input
3022 .open_channel
3023 .downstream_ring_buffer_page_offset,
3024 user_data: input.open_channel.user_data,
3025 guest_specified_interrupt_info,
3026 flags,
3027 };
3028
3029 match channel.state {
3030 ChannelState::Closed => self.open_channel(offer_id, &request, None),
3031 ChannelState::Closing { params, .. } => {
3032 channel.state = ChannelState::ClosingReopen { params, request }
3036 }
3037 ChannelState::Revoked | ChannelState::Reoffered => {}
3038
3039 ChannelState::Open { .. }
3040 | ChannelState::Opening { .. }
3041 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelAlreadyOpen),
3042
3043 ChannelState::ClientReleased
3044 | ChannelState::ClosingClientRelease
3045 | ChannelState::OpeningClientRelease => unreachable!(),
3046 }
3047 Ok(())
3048 }
3049
3050 fn handle_close_channel(&mut self, input: &protocol::CloseChannel) -> Result<(), ChannelError> {
3052 let (offer_id, channel) = self
3053 .inner
3054 .channels
3055 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
3056
3057 match channel.state {
3058 ChannelState::Open {
3059 params,
3060 modify_state,
3061 reserved_state: None,
3062 } => {
3063 if modify_state.is_modifying() {
3064 tracelimit::warn_ratelimited!(
3065 key = %channel.offer.key(),
3066 ?modify_state,
3067 "Client is closing the channel with a modify in progress"
3068 )
3069 }
3070
3071 channel.state = ChannelState::Closing {
3072 params,
3073 reserved_state: None,
3074 };
3075 self.notifier.notify(offer_id, Action::Close);
3076 }
3077
3078 ChannelState::Open {
3079 reserved_state: Some(_),
3080 ..
3081 } => return Err(ChannelError::ChannelReserved),
3082
3083 ChannelState::Revoked | ChannelState::Reoffered => {}
3084
3085 ChannelState::Closed
3086 | ChannelState::Opening { .. }
3087 | ChannelState::Closing { .. }
3088 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),
3089
3090 ChannelState::ClientReleased
3091 | ChannelState::ClosingClientRelease
3092 | ChannelState::OpeningClientRelease => unreachable!(),
3093 }
3094
3095 Ok(())
3096 }
3097
3098 fn handle_open_reserved_channel(
3101 &mut self,
3102 input: &protocol::OpenReservedChannel,
3103 version: VersionInfo,
3104 ) -> Result<(), ChannelError> {
3105 let (offer_id, channel) = self
3106 .inner
3107 .channels
3108 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
3109
3110 let target = ConnectionTarget {
3111 vp: input.target_vp,
3112 sint: input.target_sint as u8,
3113 };
3114
3115 let reserved_state = Some(ReservedState { version, target });
3116
3117 let request = OpenRequest {
3118 ring_buffer_gpadl_id: input.ring_buffer_gpadl,
3119 target_vp: None,
3121 downstream_ring_buffer_page_offset: input.downstream_page_offset,
3122 open_id: 0,
3123 user_data: UserDefinedData::new_zeroed(),
3124 guest_specified_interrupt_info: None,
3125 flags: Default::default(),
3126 };
3127
3128 match channel.state {
3129 ChannelState::Closed => self.open_channel(offer_id, &request, reserved_state),
3130 ChannelState::Revoked | ChannelState::Reoffered => {}
3131
3132 ChannelState::Open { .. } | ChannelState::Opening { .. } => {
3133 return Err(ChannelError::ChannelAlreadyOpen);
3134 }
3135
3136 ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
3137 return Err(ChannelError::InvalidChannelState);
3138 }
3139
3140 ChannelState::ClientReleased
3141 | ChannelState::ClosingClientRelease
3142 | ChannelState::OpeningClientRelease => unreachable!(),
3143 }
3144 Ok(())
3145 }
3146
3147 fn handle_close_reserved_channel(
3150 &mut self,
3151 input: &protocol::CloseReservedChannel,
3152 ) -> Result<(), ChannelError> {
3153 let (offer_id, channel) = self
3154 .inner
3155 .channels
3156 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
3157
3158 match channel.state {
3159 ChannelState::Open {
3160 params,
3161 reserved_state: Some(mut resvd),
3162 ..
3163 } => {
3164 resvd.target.vp = input.target_vp;
3165 resvd.target.sint = input.target_sint as u8;
3166 channel.state = ChannelState::Closing {
3167 params,
3168 reserved_state: Some(resvd),
3169 };
3170 self.notifier.notify(offer_id, Action::Close);
3171 }
3172
3173 ChannelState::Open {
3174 reserved_state: None,
3175 ..
3176 } => return Err(ChannelError::ChannelNotReserved),
3177
3178 ChannelState::Revoked | ChannelState::Reoffered => {}
3179
3180 ChannelState::Closed
3181 | ChannelState::Opening { .. }
3182 | ChannelState::Closing { .. }
3183 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),
3184
3185 ChannelState::ClientReleased
3186 | ChannelState::ClosingClientRelease
3187 | ChannelState::OpeningClientRelease => unreachable!(),
3188 }
3189
3190 Ok(())
3191 }
3192
3193 #[must_use]
3197 fn client_release_channel(
3198 mut sender: MessageSender<'_, N>,
3199 offer_id: OfferId,
3200 channel: &mut Channel,
3201 gpadls: &mut GpadlMap,
3202 incomplete_gpadls: &mut IncompleteGpadlMap,
3203 assigned_channels: &mut AssignedChannels,
3204 assigned_monitors: &mut AssignedMonitors,
3205 info: Option<&ConnectionInfo>,
3206 vm_reset: bool,
3207 ) -> bool {
3208 tracelimit::info_ratelimited!(?offer_id, key = %channel.offer.key(), "client released channel");
3209 gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
3211 if gpadl_offer_id != offer_id {
3212 return true;
3213 }
3214 match gpadl.state {
3215 GpadlState::InProgress => {
3216 incomplete_gpadls.remove(&gpadl_id);
3217 false
3218 }
3219 GpadlState::Offered => {
3220 gpadl.state = GpadlState::OfferedTearingDown;
3221 true
3222 }
3223 GpadlState::Accepted => {
3224 if channel.state.is_revoked() {
3225 false
3227 } else {
3228 gpadl.state = GpadlState::TearingDown;
3229 sender.notifier.notify(
3230 offer_id,
3231 Action::TeardownGpadl {
3232 gpadl_id,
3233 post_restore: false,
3234 },
3235 );
3236 true
3237 }
3238 }
3239 GpadlState::OfferedTearingDown | GpadlState::TearingDown => true,
3240 }
3241 });
3242
3243 let remove = match &mut channel.state {
3244 ChannelState::Closed => {
3245 channel.state = ChannelState::ClientReleased;
3246 false
3247 }
3248 ChannelState::Reoffered => {
3249 if let Some(info) = info {
3250 channel.state = ChannelState::Closed;
3251 channel.restore_state = RestoreState::New;
3252 sender.send_offer(channel, info);
3253 return false;
3255 }
3256 channel.state = ChannelState::ClientReleased;
3257 false
3258 }
3259 ChannelState::Revoked => {
3260 channel.state = ChannelState::ClientReleased;
3261 true
3262 }
3263 ChannelState::Opening { .. } => {
3264 if vm_reset {
3282 channel.state = ChannelState::ClientReleased;
3283 } else {
3284 channel.state = ChannelState::OpeningClientRelease;
3285 }
3286 false
3287 }
3288 ChannelState::Open { .. } => {
3289 channel.state = ChannelState::ClosingClientRelease;
3290 sender.notifier.notify(offer_id, Action::Close);
3291 false
3292 }
3293 ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
3294 channel.state = ChannelState::ClosingClientRelease;
3295 false
3296 }
3297
3298 ChannelState::ClosingClientRelease
3299 | ChannelState::OpeningClientRelease
3300 | ChannelState::ClientReleased => false,
3301 };
3302
3303 assert!(channel.state.is_released());
3304
3305 channel.release_channel(offer_id, assigned_channels, assigned_monitors);
3306 remove
3307 }
3308
3309 fn handle_rel_id_released(
3311 &mut self,
3312 input: &protocol::RelIdReleased,
3313 ) -> Result<(), ChannelError> {
3314 let channel_id = input.channel_id;
3315 let (offer_id, channel) = self
3316 .inner
3317 .channels
3318 .get_by_channel_id_mut(&self.inner.assigned_channels, channel_id)?;
3319
3320 match channel.state {
3321 ChannelState::Closed
3322 | ChannelState::Revoked
3323 | ChannelState::Closing { .. }
3324 | ChannelState::Reoffered => {
3325 if Self::client_release_channel(
3326 self.inner
3327 .pending_messages
3328 .sender(self.notifier, self.inner.state.is_paused()),
3329 offer_id,
3330 channel,
3331 &mut self.inner.gpadls,
3332 &mut self.inner.incomplete_gpadls,
3333 &mut self.inner.assigned_channels,
3334 &mut self.inner.assigned_monitors,
3335 self.inner.state.get_connected_info(),
3336 false,
3337 ) {
3338 self.inner.channels.remove(offer_id);
3339 }
3340
3341 self.check_disconnected();
3342 }
3343
3344 ChannelState::Opening { .. }
3345 | ChannelState::Open { .. }
3346 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::InvalidChannelState),
3347
3348 ChannelState::ClientReleased
3349 | ChannelState::OpeningClientRelease
3350 | ChannelState::ClosingClientRelease => unreachable!(),
3351 }
3352 Ok(())
3353 }
3354
3355 fn handle_tl_connect_request(&mut self, request: protocol::TlConnectRequest2) {
3358 let version = self
3359 .inner
3360 .state
3361 .get_version()
3362 .expect("must be connected")
3363 .version;
3364
3365 let hosted_silo_unaware = version < Version::Win10Rs5;
3366 self.notifier
3367 .notify_hvsock(&HvsockConnectRequest::from_message(
3368 request,
3369 hosted_silo_unaware,
3370 ));
3371 }
3372
3373 pub fn send_tl_connect_result(&mut self, result: HvsockConnectResult) {
3375 if !result.success && self.inner.state.check_version(Version::Win10Rs3_0) {
3379 self.sender().send_message(&protocol::TlConnectResult {
3383 service_id: result.service_id,
3384 endpoint_id: result.endpoint_id,
3385 status: protocol::STATUS_CONNECTION_REFUSED,
3386 })
3387 }
3388 }
3389
3390 fn handle_modify_channel(
3393 &mut self,
3394 request: &protocol::ModifyChannel,
3395 ) -> Result<(), ChannelError> {
3396 let result = self.modify_channel(request);
3397 if result.is_err() {
3398 self.send_modify_channel_response(request.channel_id, protocol::STATUS_UNSUCCESSFUL);
3399 }
3400
3401 result
3402 }
3403
3404 fn modify_channel(&mut self, request: &protocol::ModifyChannel) -> Result<(), ChannelError> {
3406 if request.target_vp == protocol::VP_INDEX_DISABLE_INTERRUPT {
3408 return Err(ChannelError::InvalidTargetVp);
3409 }
3410
3411 let (offer_id, channel) = self
3412 .inner
3413 .channels
3414 .get_by_channel_id_mut(&self.inner.assigned_channels, request.channel_id)?;
3415
3416 let (open_request, modify_state) = match &mut channel.state {
3417 ChannelState::Open {
3418 params,
3419 modify_state,
3420 reserved_state: None,
3421 } => (params, modify_state),
3422 _ => return Err(ChannelError::InvalidChannelState),
3423 };
3424
3425 if open_request.target_vp.is_none() {
3426 return Err(ChannelError::InterruptsDisabled);
3427 }
3428
3429 if let ModifyState::Modifying { pending_target_vp } = modify_state {
3430 if self.inner.state.check_version(Version::Iron) {
3431 tracelimit::warn_ratelimited!(
3434 key = %channel.offer.key(),
3435 "Client sent new ModifyChannel before receiving ModifyChannelResponse."
3436 );
3437 } else {
3438 *pending_target_vp = Some(request.target_vp);
3441 }
3442 } else {
3443 self.notifier.notify(
3444 offer_id,
3445 Action::Modify {
3446 target_vp: request.target_vp,
3447 },
3448 );
3449
3450 open_request.target_vp = Some(request.target_vp);
3452 *modify_state = ModifyState::Modifying {
3453 pending_target_vp: None,
3454 };
3455 }
3456
3457 Ok(())
3458 }
3459
3460 pub fn modify_channel_complete(&mut self, offer_id: OfferId, status: i32) {
3467 let channel = &mut self.inner.channels[offer_id];
3468
3469 if let ChannelState::Open {
3470 params,
3471 modify_state: ModifyState::Modifying { pending_target_vp },
3472 reserved_state: None,
3473 } = channel.state
3474 {
3475 channel.state = ChannelState::Open {
3476 params,
3477 modify_state: ModifyState::NotModifying,
3478 reserved_state: None,
3479 };
3480
3481 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
3483 let key = channel.offer.key();
3484 self.send_modify_channel_response(channel_id, status);
3485
3486 if let Some(target_vp) = pending_target_vp {
3488 let request = protocol::ModifyChannel {
3489 channel_id,
3490 target_vp,
3491 };
3492
3493 if let Err(error) = self.handle_modify_channel(&request) {
3494 tracelimit::warn_ratelimited!(?error, %key, "Pending ModifyChannel request failed.")
3495 }
3496 }
3497 }
3498 }
3499
3500 fn send_modify_channel_response(&mut self, channel_id: ChannelId, status: i32) {
3501 if self.inner.state.check_version(Version::Iron) {
3502 self.sender()
3503 .send_message(&protocol::ModifyChannelResponse { channel_id, status });
3504 }
3505 }
3506
3507 fn handle_modify_connection(&mut self, request: protocol::ModifyConnection) {
3508 if let Err(err) = self.modify_connection(request) {
3509 tracelimit::error_ratelimited!(?err, "modifying connection failed");
3510 self.complete_modify_connection(ModifyConnectionResponse::Modified(
3511 protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
3512 ));
3513 }
3514 }
3515
3516 fn modify_connection(&mut self, request: protocol::ModifyConnection) -> anyhow::Result<()> {
3517 let ConnectionState::Connected(info) = &mut self.inner.state else {
3518 anyhow::bail!(
3519 "Invalid state for ModifyConnection request: {:?}",
3520 self.inner.state
3521 );
3522 };
3523
3524 if info.modifying {
3525 anyhow::bail!(
3526 "Duplicate ModifyConnection request, state: {:?}",
3527 self.inner.state
3528 );
3529 }
3530
3531 if matches!(
3532 info.monitor_page,
3533 Some(MonitorPageGpaInfo {
3534 server_allocated: true,
3535 ..
3536 })
3537 ) {
3538 anyhow::bail!("Cannot modify server-allocated monitor pages");
3539 }
3540
3541 if (request.child_to_parent_monitor_page_gpa == 0)
3542 != (request.parent_to_child_monitor_page_gpa == 0)
3543 {
3544 anyhow::bail!("Guest must specify either both or no monitor pages, {request:?}");
3545 }
3546
3547 let monitor_page = (request.child_to_parent_monitor_page_gpa != 0).then_some(
3548 MonitorPageGpaInfo::from_guest_gpas(MonitorPageGpas {
3549 child_to_parent: request.child_to_parent_monitor_page_gpa,
3550 parent_to_child: request.parent_to_child_monitor_page_gpa,
3551 }),
3552 );
3553
3554 info.modifying = true;
3555 info.monitor_page = monitor_page;
3556 tracing::debug!("modifying connection parameters.");
3557 self.notifier.modify_connection(request.into())?;
3558
3559 Ok(())
3560 }
3561
3562 pub fn complete_modify_connection(&mut self, response: ModifyConnectionResponse) {
3563 tracing::debug!(?response, "modifying connection parameters complete");
3564
3565 match &mut self.inner.state {
3569 ConnectionState::Connecting { .. } => self.complete_initiate_contact(response),
3570 ConnectionState::Disconnecting { .. } => self.complete_disconnect(),
3571 ConnectionState::Connected(info) => {
3572 let ModifyConnectionResponse::Modified(connection_state) = response else {
3573 panic!(
3574 "Relay should not return {:?} for a modify request with no version.",
3575 response
3576 );
3577 };
3578
3579 if !info.modifying {
3580 panic!(
3581 "ModifyConnection response while not modifying, state: {:?}",
3582 self.inner.state
3583 );
3584 }
3585
3586 info.modifying = false;
3587 self.sender()
3588 .send_message(&protocol::ModifyConnectionResponse { connection_state });
3589 }
3590 _ => panic!(
3591 "Invalid state for ModifyConnection response: {:?}",
3592 self.inner.state
3593 ),
3594 }
3595 }
3596
3597 fn handle_pause(&mut self) {
3598 tracelimit::info_ratelimited!("pausing sending messages");
3599 self.sender().send_message(&protocol::PauseResponse {});
3600 let ConnectionState::Connected(info) = &mut self.inner.state else {
3601 unreachable!(
3602 "in unexpected state {:?}, should be prevented by Message::parse()",
3603 self.inner.state
3604 );
3605 };
3606 info.paused = true;
3607 }
3608
3609 pub fn handle_synic_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
3611 assert!(!self.is_resetting());
3612
3613 let version = self.inner.state.get_version();
3614 let msg = Message::parse(&message.data, version)?;
3615 tracing::trace!(?msg, message.trusted, "received vmbus message");
3616 if self.inner.state.is_trusted() && !message.trusted {
3621 tracelimit::warn_ratelimited!(?msg, "Received untrusted message");
3622 return Err(ChannelError::UntrustedMessage);
3623 }
3624
3625 match &mut self.inner.state {
3627 ConnectionState::Connected(info) if info.paused => {
3628 if !matches!(
3629 msg,
3630 Message::Resume(..)
3631 | Message::Unload(..)
3632 | Message::InitiateContact { .. }
3633 | Message::InitiateContact2 { .. }
3634 ) {
3635 tracelimit::warn_ratelimited!(?msg, "Received message while paused");
3636 return Err(ChannelError::Paused);
3637 }
3638 tracelimit::info_ratelimited!("resuming sending messages");
3639 info.paused = false;
3640 }
3641 _ => {}
3642 }
3643
3644 match msg {
3645 Message::InitiateContact2(input, ..) => {
3646 self.handle_initiate_contact(&input, &message, true)?
3647 }
3648 Message::InitiateContact(input, ..) => {
3649 self.handle_initiate_contact(&input.into(), &message, false)?
3650 }
3651 Message::Unload(..) => self.handle_unload(),
3652 Message::RequestOffers(..) => self.handle_request_offers()?,
3653 Message::GpadlHeader(input, range) => self.handle_gpadl_header(&input, range),
3654 Message::GpadlBody(input, range) => self.handle_gpadl_body(&input, range)?,
3655 Message::GpadlTeardown(input, ..) => self.handle_gpadl_teardown(&input)?,
3656 Message::OpenChannel(input, ..) => self.handle_open_channel(&input.into())?,
3657 Message::OpenChannel2(input, ..) => self.handle_open_channel(&input)?,
3658 Message::CloseChannel(input, ..) => self.handle_close_channel(&input)?,
3659 Message::RelIdReleased(input, ..) => self.handle_rel_id_released(&input)?,
3660 Message::TlConnectRequest(input, ..) => self.handle_tl_connect_request(input.into()),
3661 Message::TlConnectRequest2(input, ..) => self.handle_tl_connect_request(input),
3662 Message::ModifyChannel(input, ..) => self.handle_modify_channel(&input)?,
3663 Message::ModifyConnection(input, ..) => self.handle_modify_connection(input),
3664 Message::OpenReservedChannel(input, ..) => self.handle_open_reserved_channel(
3665 &input,
3666 version.expect("version validated by Message::parse"),
3667 )?,
3668 Message::CloseReservedChannel(input, ..) => {
3669 self.handle_close_reserved_channel(&input)?
3670 }
3671 Message::Pause(protocol::Pause, ..) => self.handle_pause(),
3672 Message::Resume(protocol::Resume, ..) => {}
3673 Message::OfferChannel(..)
3675 | Message::RescindChannelOffer(..)
3676 | Message::AllOffersDelivered(..)
3677 | Message::OpenResult(..)
3678 | Message::GpadlCreated(..)
3679 | Message::GpadlTorndown(..)
3680 | Message::VersionResponse(..)
3681 | Message::VersionResponse2(..)
3682 | Message::VersionResponse3(..)
3683 | Message::UnloadComplete(..)
3684 | Message::CloseReservedChannelResponse(..)
3685 | Message::TlConnectResult(..)
3686 | Message::ModifyChannelResponse(..)
3687 | Message::ModifyConnectionResponse(..)
3688 | Message::PauseResponse(..) => {
3689 unreachable!("Server received client message {:?}", msg);
3690 }
3691 }
3692 Ok(())
3693 }
3694
3695 pub fn gpadl_create_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId, status: i32) {
3697 let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
3698 tracelimit::error_ratelimited!(
3699 ?offer_id,
3700 key = %self.inner.channels[offer_id].offer.key(),
3701 ?gpadl_id,
3702 "invalid gpadl ID for channel"
3703 );
3704 return;
3705 };
3706 let retain = match gpadl.state {
3707 GpadlState::InProgress | GpadlState::TearingDown | GpadlState::Accepted => {
3708 tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
3709 return;
3710 }
3711 GpadlState::Offered => {
3712 let channel_id = self.inner.channels[offer_id]
3713 .info
3714 .as_ref()
3715 .expect("assigned")
3716 .channel_id;
3717 self.inner
3718 .pending_messages
3719 .sender(self.notifier, self.inner.state.is_paused())
3720 .send_gpadl_created(channel_id, gpadl_id, status);
3721 if status >= 0 {
3722 gpadl.state = GpadlState::Accepted;
3723 true
3724 } else {
3725 false
3726 }
3727 }
3728 GpadlState::OfferedTearingDown => {
3729 if status >= 0 {
3730 self.notifier.notify(
3732 offer_id,
3733 Action::TeardownGpadl {
3734 gpadl_id,
3735 post_restore: false,
3736 },
3737 );
3738 gpadl.state = GpadlState::TearingDown;
3739 true
3740 } else {
3741 false
3742 }
3743 }
3744 };
3745 if !retain {
3746 self.inner
3747 .gpadls
3748 .remove(&(gpadl_id, offer_id))
3749 .expect("gpadl validated above");
3750
3751 self.check_disconnected();
3752 }
3753 }
3754
3755 pub fn gpadl_teardown_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
3757 let channel = &mut self.inner.channels[offer_id];
3758 let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
3759 tracelimit::error_ratelimited!(
3760 ?offer_id,
3761 key = %channel.offer.key(),
3762 ?gpadl_id,
3763 "invalid gpadl ID for channel"
3764 );
3765 return;
3766 };
3767 tracing::debug!(
3768 offer_id = offer_id.0,
3769 key = %channel.offer.key(),
3770 gpadl_id = gpadl_id.0,
3771 "Gpadl teardown complete"
3772 );
3773 match gpadl.state {
3774 GpadlState::InProgress
3775 | GpadlState::Offered
3776 | GpadlState::OfferedTearingDown
3777 | GpadlState::Accepted => {
3778 tracelimit::error_ratelimited!(?offer_id, key = %channel.offer.key(), ?gpadl_id, ?gpadl, "invalid gpadl state");
3779 }
3780 GpadlState::TearingDown => {
3781 if !channel.state.is_released() {
3782 self.sender().send_gpadl_torndown(gpadl_id);
3783 }
3784 self.inner
3785 .gpadls
3786 .remove(&(gpadl_id, offer_id))
3787 .expect("gpadl validated above");
3788
3789 self.check_disconnected();
3790 }
3791 }
3792 }
3793
3794 fn sender(&mut self) -> MessageSender<'_, N> {
3799 self.inner
3800 .pending_messages
3801 .sender(self.notifier, self.inner.state.is_paused())
3802 }
3803}
3804
3805fn revoke<N: Notifier>(
3806 mut sender: MessageSender<'_, N>,
3807 offer_id: OfferId,
3808 channel: &mut Channel,
3809 gpadls: &mut GpadlMap,
3810) -> bool {
3811 let info = match channel.state {
3812 ChannelState::Closed
3813 | ChannelState::Open { .. }
3814 | ChannelState::Opening { .. }
3815 | ChannelState::Closing { .. }
3816 | ChannelState::ClosingReopen { .. } => {
3817 channel.state = ChannelState::Revoked;
3818 Some(channel.info.as_ref().expect("assigned"))
3819 }
3820 ChannelState::Reoffered => {
3821 channel.state = ChannelState::Revoked;
3822 None
3823 }
3824 ChannelState::ClientReleased
3825 | ChannelState::OpeningClientRelease
3826 | ChannelState::ClosingClientRelease => None,
3827 ChannelState::Revoked => return true,
3829 };
3830 let retain = !channel.state.is_released();
3831
3832 gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
3834 if gpadl_offer_id != offer_id {
3835 return true;
3836 }
3837
3838 match gpadl.state {
3839 GpadlState::InProgress => true,
3840 GpadlState::Offered => {
3841 if let Some(info) = info {
3842 sender.send_gpadl_created(
3843 info.channel_id,
3844 gpadl_id,
3845 protocol::STATUS_UNSUCCESSFUL,
3846 );
3847 }
3848 false
3849 }
3850 GpadlState::OfferedTearingDown => false,
3851 GpadlState::Accepted => true,
3852 GpadlState::TearingDown => {
3853 if info.is_some() {
3854 sender.send_gpadl_torndown(gpadl_id);
3855 }
3856 false
3857 }
3858 }
3859 });
3860 if let Some(info) = info {
3861 sender.send_rescind(info);
3862 }
3863 if channel.restore_state != RestoreState::New {
3865 channel.restore_state = RestoreState::Restored;
3866 }
3867 retain
3868}
3869
3870struct PendingMessages(VecDeque<OutgoingMessage>);
3871
3872impl PendingMessages {
3873 fn sender<'a, N: Notifier>(
3875 &'a mut self,
3876 notifier: &'a mut N,
3877 is_paused: bool,
3878 ) -> MessageSender<'a, N> {
3879 MessageSender {
3880 notifier,
3881 pending_messages: self,
3882 is_paused,
3883 }
3884 }
3885}
3886
3887struct MessageSender<'a, N> {
3890 notifier: &'a mut N,
3891 pending_messages: &'a mut PendingMessages,
3892 is_paused: bool,
3893}
3894
3895impl<N: Notifier> MessageSender<'_, N> {
3896 fn send_message<
3898 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3899 >(
3900 &mut self,
3901 msg: &T,
3902 ) {
3903 let message = OutgoingMessage::new(msg);
3904
3905 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3906 if !self.pending_messages.0.is_empty()
3908 || self.is_paused
3909 || !self.notifier.send_message(&message, MessageTarget::Default)
3910 {
3911 tracing::trace!("message queued");
3912 self.pending_messages.0.push_back(message);
3914 }
3915 }
3916
3917 fn send_message_with_target<
3919 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3920 >(
3921 &mut self,
3922 msg: &T,
3923 target: MessageTarget,
3924 ) {
3925 if target == MessageTarget::Default {
3926 self.send_message(msg);
3927 } else {
3928 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3929 let message = OutgoingMessage::new(msg);
3932 if !self.notifier.send_message(&message, target) {
3933 tracelimit::warn_ratelimited!(?target, "failed to send message");
3934 }
3935 }
3936 }
3937
3938 fn send_offer(&mut self, channel: &mut Channel, connection_info: &ConnectionInfo) {
3940 let info = channel.info.as_ref().expect("assigned");
3941 let mut flags = channel.offer.flags;
3942 if !connection_info
3943 .version
3944 .feature_flags
3945 .confidential_channels()
3946 {
3947 flags.set_confidential_ring_buffer(false);
3948 flags.set_confidential_external_memory(false);
3949 }
3950
3951 let monitor_id = connection_info.monitor_page.and(info.monitor_id);
3958 let msg = protocol::OfferChannel {
3959 interface_id: channel.offer.interface_id,
3960 instance_id: channel.offer.instance_id,
3961 rsvd: [0; 4],
3962 flags,
3963 mmio_megabytes: channel.offer.mmio_megabytes,
3964 user_defined: channel.offer.user_defined,
3965 subchannel_index: channel.offer.subchannel_index,
3966 mmio_megabytes_optional: channel.offer.mmio_megabytes_optional,
3967 channel_id: info.channel_id,
3968 monitor_id: monitor_id.unwrap_or(MonitorId::INVALID).0,
3969 monitor_allocated: monitor_id.is_some().into(),
3970 is_dedicated: 1,
3973 connection_id: info.connection_id,
3974 };
3975 tracing::info!(
3976 channel_id = msg.channel_id.0,
3977 connection_id = msg.connection_id,
3978 key = %channel.offer.key(),
3979 "sending offer to guest"
3980 );
3981
3982 self.send_message(&msg);
3983 }
3984
3985 fn send_open_result(
3986 &mut self,
3987 channel_id: ChannelId,
3988 open_request: &OpenRequest,
3989 result: i32,
3990 target: MessageTarget,
3991 ) {
3992 self.send_message_with_target(
3993 &protocol::OpenResult {
3994 channel_id,
3995 open_id: open_request.open_id,
3996 status: result as u32,
3997 },
3998 target,
3999 );
4000 }
4001
4002 fn send_gpadl_created(&mut self, channel_id: ChannelId, gpadl_id: GpadlId, status: i32) {
4003 self.send_message(&protocol::GpadlCreated {
4004 channel_id,
4005 gpadl_id,
4006 status,
4007 });
4008 }
4009
4010 fn send_gpadl_torndown(&mut self, gpadl_id: GpadlId) {
4011 self.send_message(&protocol::GpadlTorndown { gpadl_id });
4012 }
4013
4014 fn send_rescind(&mut self, info: &OfferedInfo) {
4015 tracing::info!(
4016 channel_id = info.channel_id.0,
4017 "rescinding channel from guest"
4018 );
4019
4020 self.send_message(&protocol::RescindChannelOffer {
4021 channel_id: info.channel_id,
4022 });
4023 }
4024}
4025
4026struct VersionResponseData {
4028 version: VersionInfo,
4029 state: protocol::ConnectionState,
4030 monitor_pages: Option<MonitorPageGpas>,
4031}
4032
4033impl VersionResponseData {
4034 fn new(version: VersionInfo, state: protocol::ConnectionState) -> Self {
4036 VersionResponseData {
4037 version,
4038 state,
4039 monitor_pages: None,
4040 }
4041 }
4042
4043 fn with_monitor_pages(mut self, monitor_pages: Option<MonitorPageGpas>) -> Self {
4045 self.monitor_pages = monitor_pages;
4046 self
4047 }
4048}