1pub mod saved_state;
5#[cfg(test)]
6mod tests;
7
8use crate::Guid;
9use crate::SynicMessage;
10use crate::monitor::AssignedMonitors;
11use crate::protocol::Version;
12use hvdef::Vtl;
13use inspect::Inspect;
14pub use saved_state::RestoreError;
15pub use saved_state::SavedState;
16pub use saved_state::SavedStateData;
17use slab::Slab;
18use std::cmp::min;
19use std::collections::VecDeque;
20use std::collections::hash_map::Entry;
21use std::collections::hash_map::HashMap;
22use std::fmt::Display;
23use std::ops::Index;
24use std::ops::IndexMut;
25use std::task::Poll;
26use std::task::ready;
27use std::time::Duration;
28use thiserror::Error;
29use vmbus_channel::bus::ChannelType;
30use vmbus_channel::bus::GpadlRequest;
31use vmbus_channel::bus::OfferKey;
32use vmbus_channel::bus::OfferParams;
33use vmbus_channel::bus::OpenData;
34use vmbus_channel::bus::RestoredGpadl;
35use vmbus_core::HvsockConnectRequest;
36use vmbus_core::HvsockConnectResult;
37use vmbus_core::MaxVersionInfo;
38use vmbus_core::OutgoingMessage;
39use vmbus_core::VMBUS_SINT;
40use vmbus_core::VersionInfo;
41use vmbus_core::protocol;
42use vmbus_core::protocol::ChannelId;
43use vmbus_core::protocol::ConnectionId;
44use vmbus_core::protocol::FeatureFlags;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::Message;
47use vmbus_core::protocol::OfferFlags;
48use vmbus_core::protocol::UserDefinedData;
49use vmbus_ring::gparange;
50use vmcore::monitor::MonitorId;
51use vmcore::synic::MonitorInfo;
52use vmcore::synic::MonitorPageGpas;
53use zerocopy::FromZeros;
54use zerocopy::Immutable;
55use zerocopy::IntoBytes;
56use zerocopy::KnownLayout;
57
/// Errors from handling guest requests that target a channel or a GPADL.
///
/// The `#[error]` string on each variant is its `Display` text.
#[derive(Debug, Error)]
pub enum ChannelError {
    /// The channel ID is not assigned to any offer.
    #[error("unknown channel ID")]
    UnknownChannelId,
    /// The GPADL ID is not known.
    #[error("unknown GPADL ID")]
    UnknownGpadlId,
    /// Wraps a `protocol::ParseError` (convertible via `From`).
    #[error("parse error")]
    ParseError(#[from] protocol::ParseError),
    /// Completed GPADL range data failed `gparange::validate_gpa_ranges`.
    #[error("invalid gpa range")]
    InvalidGpaRange(#[source] gparange::Error),
    /// The GPADL ID is already in use.
    #[error("duplicate GPADL ID")]
    DuplicateGpadlId,
    /// Data was appended to a GPADL that is no longer in progress.
    #[error("GPADL is already complete")]
    GpadlAlreadyComplete,
    /// The GPADL's channel ID does not match the channel in the request.
    #[error("GPADL channel ID mismatch")]
    WrongGpadlChannelId,
    /// An open was requested for a channel that is already open.
    #[error("trying to open an open channel")]
    ChannelAlreadyOpen,
    /// A close was requested for a channel that is not open.
    #[error("trying to close a closed channel")]
    ChannelNotOpen,
    /// The GPADL's current state does not permit the operation.
    #[error("invalid GPADL state for operation")]
    InvalidGpadlState,
    /// The channel's current state does not permit the operation.
    #[error("invalid channel state for operation")]
    InvalidChannelState,
    /// The client has already released this channel ID.
    #[error("channel ID has already been released")]
    ChannelReleased,
    /// Channel offers were already sent on this connection.
    #[error("channel offers have already been sent")]
    OffersAlreadySent,
    /// The operation is not permitted on a reserved channel.
    #[error("invalid operation on reserved channel")]
    ChannelReserved,
    /// The operation requires a reserved channel.
    #[error("invalid operation on non-reserved channel")]
    ChannelNotReserved,
    /// An untrusted message arrived on a trusted connection.
    #[error("received untrusted message for trusted connection")]
    UntrustedMessage,
    /// A message other than a resume arrived while the connection is paused.
    #[error("received a non-resuming message while paused")]
    Paused,
    /// The requested target VP is not valid.
    #[error("invalid target VP")]
    InvalidTargetVp,
    /// Interrupts are disabled for this channel.
    #[error("interrupts are disabled for this channel")]
    InterruptsDisabled,
}
100
/// Errors from offering a channel.
#[derive(Debug, Error)]
pub enum OfferError {
    /// The channel ID cannot be used for this operation (for example, it lies
    /// outside the channel ID table; see `AssignedChannels::set`).
    #[error("the channel ID {} is not valid for this operation", (.0).0)]
    InvalidChannelId(ChannelId),
    /// The channel ID is already assigned to an offer.
    #[error("the channel ID {} is already in use", (.0).0)]
    ChannelIdInUse(ChannelId),
    /// An offer with this key already exists.
    #[error("offer {0} already exists")]
    AlreadyExists(OfferKey),
    /// The resources do not match those of the existing saved or revoked offer.
    #[error("specified resources do not match those of the existing saved or revoked offer")]
    IncompatibleResources,
    /// No more channels can be offered.
    #[error("too many channels have been offered")]
    TooManyChannels,
    /// Saved state recorded a different monitor ID: `.0` is expected, `.1` is actual.
    #[error("mismatched monitor ID from saved state; expected {0:?}, actual {1:?}")]
    MismatchedMonitorId(Option<MonitorId>, MonitorId),
}
116
/// Handle for an offered channel: the index of its slot in [`ChannelList`]'s slab.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OfferId(usize);
120
/// Maps a GPADL ID to the offer it belongs to, for GPADLs that are presumably still
/// being assembled (inferred from the name; confirm at the use sites).
type IncompleteGpadlMap = HashMap<GpadlId, OfferId>;

/// GPADLs keyed by `(GPADL ID, owning offer)`; the key type lets the same GPADL ID
/// appear under different offers.
type GpadlMap = HashMap<(GpadlId, OfferId), Gpadl>;
124
/// Connection and channel state for the VMBus server.
///
/// Operations that must notify the host pair this with a [`Notifier`] via
/// [`Server::with_notifier`].
pub struct Server {
    /// Current state of the guest connection.
    state: ConnectionState,
    /// Every known offer, indexed by [`OfferId`].
    channels: ChannelList,
    /// Channel ID to offer assignments.
    assigned_channels: AssignedChannels,
    /// Monitor IDs handed out to channels with `MnfUsage::Enabled`.
    assigned_monitors: AssignedMonitors,
    /// GPADLs keyed by GPADL ID and owning offer.
    gpadls: GpadlMap,
    /// GPADL ID to offer mapping for incomplete GPADLs.
    incomplete_gpadls: IncompleteGpadlMap,
    /// Connection ID supplied to `Server::new`. NOTE(review): its use is not visible here.
    child_connection_id: u32,
    /// Version limit for connections. NOTE(review): presumably set by
    /// `set_compatibility_version` when not delayed — confirm.
    max_version: Option<MaxVersionInfo>,
    /// Version limit recorded by `set_compatibility_version` with `delay == true`.
    delayed_max_version: Option<MaxVersionInfo>,
    /// Messages waiting to be sent. NOTE(review): `PendingMessages` is defined elsewhere.
    pending_messages: PendingMessages,
    /// Set through `set_require_server_allocated_mnf`.
    require_server_allocated_mnf: bool,
    /// Supplied to `Server::new`; presumably controls offer ordering — confirm.
    use_absolute_channel_order: bool,
}
145
/// A [`Server`] temporarily paired with a [`Notifier`], created by
/// [`Server::with_notifier`].
///
/// The server's channel invariants are checked (in debug builds) when this is
/// created and again when it is dropped.
pub struct ServerWithNotifier<'a, T> {
    inner: &'a mut Server,
    notifier: &'a mut T,
}
150
151impl<T> Drop for ServerWithNotifier<'_, T> {
152 fn drop(&mut self) {
153 self.inner.validate();
154 }
155}
156
/// Reports the connection state, the assigned-monitor bitmap, and every channel
/// (including notifier-provided data) and GPADL.
impl<T: Notifier> Inspect for ServerWithNotifier<'_, T> {
    fn inspect(&self, req: inspect::Request<'_>) {
        let mut resp = req.respond();
        // Derive a state label, the connection info (only while connecting or
        // connected), and the pending action (only reported while disconnecting).
        let (state, info, next_action) = match &self.inner.state {
            ConnectionState::Disconnected => ("disconnected", None, None),
            ConnectionState::Connecting { info, .. } => ("connecting", Some(info), None),
            ConnectionState::Connected(info) => (
                // Before offers go out, a connected session is reported as "negotiated".
                if info.offers_sent {
                    "connected"
                } else {
                    "negotiated"
                },
                Some(info),
                None,
            ),
            ConnectionState::Disconnecting { next_action, .. } => {
                ("disconnecting", None, Some(next_action))
            }
        };

        resp.field("connection_info", info);
        let next_action = next_action.map(|a| match a {
            ConnectionAction::None => "disconnect",
            ConnectionAction::Reset => "reset",
            ConnectionAction::SendUnloadComplete => "unload",
            ConnectionAction::Reconnect { .. } => "reconnect",
            ConnectionAction::SendFailedVersionResponse => "send_version_response",
        });
        resp.field("state", state)
            .field("next_action", next_action)
            .field(
                "assigned_monitors_bitmap",
                format_args!("{:x}", self.inner.assigned_monitors.bitmap()),
            )
            .child("channels", |req| {
                let mut resp = req.respond();
                self.inner
                    .channels
                    .inspect(self.notifier, self.inner.get_version(), &mut resp);
                // GPADLs are listed beneath their owning channel's inspect path.
                for ((gpadl_id, offer_id), gpadl) in &self.inner.gpadls {
                    let channel = &self.inner.channels[*offer_id];
                    resp.field(
                        &channel_inspect_path(
                            &channel.offer,
                            format_args!("/gpadls/{}", gpadl_id.0),
                        ),
                        gpadl,
                    );
                }
            });
    }
}
209
/// Monitor page GPAs for a connection, together with who supplied them.
#[derive(Debug, Copy, Clone, Inspect)]
struct MonitorPageGpaInfo {
    /// The monitor page addresses.
    gpas: MonitorPageGpas,
    /// `true` when built by `from_server_gpas` (server-allocated pages), `false`
    /// when built by `from_guest_gpas`.
    server_allocated: bool,
}
216
217impl MonitorPageGpaInfo {
218 fn from_guest_gpas(gpas: MonitorPageGpas) -> Self {
220 Self {
221 gpas,
222 server_allocated: false,
223 }
224 }
225
226 fn from_server_gpas(gpas: MonitorPageGpas) -> Self {
228 Self {
229 gpas,
230 server_allocated: true,
231 }
232 }
233}
234
/// Details of a guest connection, held while connecting and once connected.
#[derive(Debug, Copy, Clone, Inspect)]
struct ConnectionInfo {
    /// Protocol version and feature flags for this connection.
    version: VersionInfo,
    /// Whether the connection is trusted (see `ConnectionState::is_trusted`).
    trusted: bool,
    /// Whether channel offers have been sent. Inspect reports `"negotiated"` until
    /// this is set and `"connected"` afterwards.
    offers_sent: bool,
    /// Interrupt page GPA, if any.
    interrupt_page: Option<u64>,
    /// Monitor page GPAs, if any, and whether the server allocated them.
    monitor_page: Option<MonitorPageGpaInfo>,
    /// Target VP for connection messages. NOTE(review): inferred from the name.
    target_message_vp: u32,
    /// Presumably `true` while a connection modification is in progress — confirm.
    modifying: bool,
    /// Client ID; presumably taken from the `InitiateContact` request — confirm.
    client_id: Guid,
    /// Whether the connection is paused (read by `ConnectionState::is_paused`).
    paused: bool,
}
249
/// State of the guest connection.
#[derive(Debug)]
enum ConnectionState {
    /// No connection.
    Disconnected,
    /// Tearing down the connection; `next_action` runs afterwards.
    /// NOTE(review): `modify_sent` semantics are not visible in this chunk.
    Disconnecting {
        next_action: ConnectionAction,
        modify_sent: bool,
    },
    /// A connection is being established with `info`; `next_action` is pending.
    Connecting {
        info: ConnectionInfo,
        next_action: ConnectionAction,
    },
    /// Connected with `info`.
    Connected(ConnectionInfo),
}
264
265impl ConnectionState {
266 fn check_version(&self, min_version: Version) -> bool {
268 matches!(self, ConnectionState::Connected(info) if info.version.version >= min_version)
269 }
270
271 fn check_feature_flags(&self, flags: impl Fn(FeatureFlags) -> bool) -> bool {
274 matches!(self, ConnectionState::Connected(info) if flags(info.version.feature_flags))
275 }
276
277 fn get_version(&self) -> Option<VersionInfo> {
278 if let ConnectionState::Connected(info) = self {
279 Some(info.version)
280 } else {
281 None
282 }
283 }
284
285 fn get_connected_info(&self) -> Option<&ConnectionInfo> {
287 if let ConnectionState::Connected(info) = self {
288 Some(info)
289 } else {
290 None
291 }
292 }
293
294 fn is_trusted(&self) -> bool {
295 match self {
296 ConnectionState::Connected(info) => info.trusted,
297 ConnectionState::Connecting { info, .. } => info.trusted,
298 _ => false,
299 }
300 }
301
302 fn is_paused(&self) -> bool {
303 if let ConnectionState::Connected(info) = self {
304 info.paused
305 } else {
306 false
307 }
308 }
309}
310
/// Action to take once the current connection transition finishes. Inspect reports
/// these as "disconnect", "reset", "unload", "reconnect" and "send_version_response".
#[derive(Debug, Copy, Clone)]
enum ConnectionAction {
    /// Nothing further.
    None,
    /// Complete a reset.
    Reset,
    /// Send the unload-complete response.
    SendUnloadComplete,
    /// Reconnect using the saved `InitiateContact` request.
    Reconnect {
        initiate_contact: InitiateContactRequest,
    },
    /// Send a version response reporting failure.
    SendFailedVersionResponse,
}
321
/// Monitor page information carried by an [`InitiateContactRequest`].
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum MonitorPageRequest {
    /// No monitor pages were supplied.
    None,
    /// These monitor page GPAs were supplied.
    Some(MonitorPageGpas),
    /// The supplied monitor page information is unusable. NOTE(review): the validity
    /// rule is applied where this value is built (not visible here).
    Invalid,
}
328
/// Parameters of a request to establish a VMBus connection.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub struct InitiateContactRequest {
    /// Raw protocol version requested.
    pub version_requested: u32,
    /// VP that connection messages should target.
    pub target_message_vp: u32,
    /// Monitor page information supplied with the request.
    pub monitor_page: MonitorPageRequest,
    /// Requested SINT for messages.
    pub target_sint: u8,
    /// Requested VTL for messages.
    pub target_vtl: u8,
    /// Raw requested feature flags.
    pub feature_flags: u32,
    /// Interrupt page GPA, if supplied.
    pub interrupt_page: Option<u64>,
    /// Client ID supplied with the request.
    pub client_id: Guid,
    /// Whether the request came over a trusted path. NOTE(review): confirm at the producer.
    pub trusted: bool,
}
341
/// A guest request to open a channel (see `OpenParams::from_request`).
#[derive(Debug, Copy, Clone)]
pub struct OpenRequest {
    /// Guest-chosen open ID.
    pub open_id: u32,
    /// GPADL describing the ring buffer.
    pub ring_buffer_gpadl_id: GpadlId,
    /// Requested target VP, if any.
    pub target_vp: Option<u32>,
    /// Page offset of the second ring within the ring buffer GPADL.
    pub downstream_ring_buffer_page_offset: u32,
    /// Channel-specific user data.
    pub user_data: UserDefinedData,
    /// When set, overrides the default event flag and connection ID used for signaling.
    pub guest_specified_interrupt_info: Option<SignalInfo>,
    /// Open flags; unused bits are cleared by `OpenParams::from_request`.
    pub flags: protocol::OpenChannelFlags,
}
352
/// A three-way change to an optional setting.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Update<T: std::fmt::Debug + Copy + Clone> {
    /// Leave the setting as it is.
    Unchanged,
    /// Clear the setting.
    Reset,
    /// Set the setting to this value.
    Set(T),
}
359
360impl<T: std::fmt::Debug + Copy + Clone> From<Option<T>> for Update<T> {
361 fn from(value: Option<T>) -> Self {
362 match value {
363 None => Self::Reset,
364 Some(value) => Self::Set(value),
365 }
366 }
367}
368
/// A requested change to connection-level settings, passed to
/// `Notifier::modify_connection`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ModifyConnectionRequest {
    /// New version info, if it changes.
    pub version: Option<VersionInfo>,
    /// Monitor page change.
    pub monitor_page: Update<MonitorPageGpas>,
    /// Interrupt page change.
    pub interrupt_page: Update<u64>,
    /// New target VP for connection messages, if it changes.
    pub target_message_vp: Option<u32>,
    /// `true` by default. NOTE(review): what "relay" refers to is not visible here.
    pub notify_relay: bool,
}
377
378impl Default for ModifyConnectionRequest {
380 fn default() -> Self {
381 Self {
382 version: None,
383 monitor_page: Update::Unchanged,
384 interrupt_page: Update::Unchanged,
385 target_message_vp: None,
386 notify_relay: true,
387 }
388 }
389}
390
391impl From<protocol::ModifyConnection> for ModifyConnectionRequest {
392 fn from(value: protocol::ModifyConnection) -> Self {
393 let monitor_page = if value.parent_to_child_monitor_page_gpa != 0 {
394 Update::Set(MonitorPageGpas {
395 parent_to_child: value.parent_to_child_monitor_page_gpa,
396 child_to_parent: value.child_to_parent_monitor_page_gpa,
397 })
398 } else {
399 Update::Reset
400 };
401
402 Self {
403 monitor_page,
404 ..Default::default()
405 }
406 }
407}
408
/// Outcome of a connection modification. NOTE(review): variant semantics below are
/// inferred from names and payload types; confirm against the producers.
#[derive(Debug, Copy, Clone)]
pub enum ModifyConnectionResponse {
    /// Presumably reports the resulting connection state, the feature flags the
    /// host supports, and server-provided monitor page GPAs, if any.
    Supported(
        protocol::ConnectionState,
        FeatureFlags,
        Option<MonitorPageGpas>,
    ),
    /// The modification is not supported.
    Unsupported,
    /// The connection was modified, with the resulting connection state.
    Modified(protocol::ConnectionState),
}
428
/// Whether an open channel has a modify operation outstanding.
#[derive(Debug, Copy, Clone)]
pub enum ModifyState {
    /// No modify is in progress.
    NotModifying,
    /// A modify is in progress; `pending_target_vp` presumably holds a target VP
    /// requested meanwhile — confirm.
    Modifying { pending_target_vp: Option<u32> },
}
434
435impl ModifyState {
436 pub fn is_modifying(&self) -> bool {
437 matches!(self, ModifyState::Modifying { .. })
438 }
439}
440
/// Guest-specified signaling parameters. When present in an open request, these
/// replace the default event flag and connection ID (see `OpenParams::from_request`).
#[derive(Debug, Copy, Clone)]
pub struct SignalInfo {
    pub event_flag: u16,
    pub connection_id: u32,
}
446
/// Save/restore tracking for a channel, reported as `restore_state` in inspect.
/// NOTE(review): the variant meanings below are inferred from their names; confirm
/// against the restore code.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum RestoreState {
    /// Presumably not part of a restore.
    New,
    /// Presumably restored from saved state and awaiting confirmation.
    Restoring,
    /// Presumably present in saved state but not matched by an offer.
    Unmatched,
    /// Presumably restore completed.
    Restored,
}
461
/// Lifecycle state of a channel offer.
///
/// In the released states (`ClientReleased`, `ClosingClientRelease`,
/// `OpeningClientRelease`) the channel must not hold ID assignments (`Channel::info`);
/// `Server::validate` checks this in debug builds. `reserved_state` is `Some` for
/// reserved channels (see `ChannelState::is_reserved`).
#[derive(Debug, Clone)]
enum ChannelState {
    /// The client has released the channel.
    ClientReleased,

    /// Offered but not open.
    Closed,

    /// An open has been requested with `request` and is in progress.
    Opening {
        request: OpenRequest,
        reserved_state: Option<ReservedState>,
    },

    /// Open with `params`; `modify_state` tracks any outstanding modify.
    Open {
        params: OpenRequest,
        modify_state: ModifyState,
        reserved_state: Option<ReservedState>,
    },

    /// The channel opened with `params` is being closed.
    Closing {
        params: OpenRequest,
        reserved_state: Option<ReservedState>,
    },

    /// Closing the channel opened with `params`, with a new open `request` queued.
    /// NOTE(review): inferred from the name and fields.
    ClosingReopen {
        params: OpenRequest,
        request: OpenRequest,
    },

    /// The offer has been revoked.
    Revoked,

    /// Counts as revoked (see `is_revoked`); presumably revoked and then offered
    /// again — confirm.
    Reoffered,

    /// Closing, and the client has released the channel.
    ClosingClientRelease,

    /// Opening, and the client has released the channel.
    OpeningClientRelease,
}
515
516impl ChannelState {
517 fn is_released(&self) -> bool {
520 match self {
521 ChannelState::Closed
522 | ChannelState::Opening { .. }
523 | ChannelState::Open { .. }
524 | ChannelState::Closing { .. }
525 | ChannelState::ClosingReopen { .. }
526 | ChannelState::Revoked
527 | ChannelState::Reoffered => false,
528
529 ChannelState::ClientReleased
530 | ChannelState::ClosingClientRelease
531 | ChannelState::OpeningClientRelease => true,
532 }
533 }
534
535 fn is_revoked(&self) -> bool {
537 match self {
538 ChannelState::Revoked | ChannelState::Reoffered => true,
539
540 ChannelState::ClientReleased
541 | ChannelState::Closed
542 | ChannelState::Opening { .. }
543 | ChannelState::Open { .. }
544 | ChannelState::Closing { .. }
545 | ChannelState::ClosingReopen { .. }
546 | ChannelState::ClosingClientRelease
547 | ChannelState::OpeningClientRelease => false,
548 }
549 }
550
551 fn is_reserved(&self) -> bool {
552 match self {
553 ChannelState::Open {
555 reserved_state: Some(_),
556 ..
557 }
558 | ChannelState::Opening {
559 reserved_state: Some(_),
560 ..
561 }
562 | ChannelState::Closing {
563 reserved_state: Some(_),
564 ..
565 } => true,
566
567 ChannelState::Opening { .. }
568 | ChannelState::Open { .. }
569 | ChannelState::Closing { .. }
570 | ChannelState::ClientReleased
571 | ChannelState::Closed
572 | ChannelState::ClosingReopen { .. }
573 | ChannelState::Revoked
574 | ChannelState::Reoffered
575 | ChannelState::ClosingClientRelease
576 | ChannelState::OpeningClientRelease => false,
577 }
578 }
579}
580
581impl Display for ChannelState {
582 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583 let state = match self {
584 Self::ClientReleased => "ClientReleased",
585 Self::Closed => "Closed",
586 Self::Opening { .. } => "Opening",
587 Self::Open { .. } => "Open",
588 Self::Closing { .. } => "Closing",
589 Self::ClosingReopen { .. } => "ClosingReopen",
590 Self::Revoked => "Revoked",
591 Self::Reoffered => "Reoffered",
592 Self::ClosingClientRelease => "ClosingClientRelease",
593 Self::OpeningClientRelease => "OpeningClientRelease",
594 };
595 write!(f, "{}", state)
596 }
597}
598
/// How a channel uses monitored notification (MNF).
///
/// Variant order is kept as-is; the type derives `mesh::MeshPayload`.
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub enum MnfUsage {
    /// MNF is not used (the default).
    #[default]
    Disabled,
    /// MNF with a monitor ID assigned by this server (see `Channel::prepare_channel`)
    /// and the given interrupt latency.
    Enabled { latency: Duration },
    /// MNF with an externally supplied monitor ID, which this server neither
    /// allocates nor releases.
    Relayed { monitor_id: u8 },
}
611
612impl MnfUsage {
613 pub fn is_enabled(&self) -> bool {
614 matches!(self, Self::Enabled { .. })
615 }
616
617 pub fn is_relayed(&self) -> bool {
618 matches!(self, Self::Relayed { .. })
619 }
620
621 pub fn enabled_and_then<T>(&self, f: impl FnOnce(Duration) -> Option<T>) -> Option<T> {
622 if let Self::Enabled { latency } = self {
623 f(*latency)
624 } else {
625 None
626 }
627 }
628}
629
630impl From<Option<Duration>> for MnfUsage {
631 fn from(value: Option<Duration>) -> Self {
632 match value {
633 None => Self::Disabled,
634 Some(latency) => Self::Enabled { latency },
635 }
636 }
637}
638
/// Internal form of a channel offer's parameters (built from `OfferParams`).
///
/// Field order is kept as-is because the type derives `mesh::MeshPayload`.
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub struct OfferParamsInternal {
    /// Human-readable interface name.
    pub interface_name: String,
    /// Instance GUID; part of the offer key.
    pub instance_id: Guid,
    /// Interface GUID; part of the offer key.
    pub interface_id: Guid,
    pub mmio_megabytes: u16,
    pub mmio_megabytes_optional: u16,
    /// Subchannel index; part of the offer key. Index 0 is inspected at the
    /// instance path, others under `/subchannels/<index>`.
    pub subchannel_index: u16,
    /// How the channel uses MNF.
    pub use_mnf: MnfUsage,
    /// Optional ordering hint for the offer. NOTE(review): consumer not visible here.
    pub offer_order: Option<u64>,
    /// Offer flags derived from the channel type.
    pub flags: OfferFlags,
    /// User-defined data derived from the channel type.
    pub user_defined: UserDefinedData,
}
653
654impl OfferParamsInternal {
655 pub fn key(&self) -> OfferKey {
657 OfferKey {
658 interface_id: self.interface_id,
659 instance_id: self.instance_id,
660 subchannel_index: self.subchannel_index,
661 }
662 }
663}
664
665impl From<OfferParams> for OfferParamsInternal {
666 fn from(value: OfferParams) -> Self {
667 let mut user_defined = UserDefinedData::new_zeroed();
668
669 let mut flags = OfferFlags::new()
672 .with_confidential_ring_buffer(true)
673 .with_confidential_external_memory(value.allow_confidential_external_memory);
674
675 match value.channel_type {
676 ChannelType::Device { pipe_packets } => {
677 if pipe_packets {
678 flags.set_named_pipe_mode(true);
679 user_defined.as_pipe_params_mut().pipe_type = protocol::PipeType::MESSAGE;
680 }
681 }
682 ChannelType::Interface {
683 user_defined: interface_user_defined,
684 } => {
685 flags.set_enumerate_device_interface(true);
686 user_defined = interface_user_defined;
687 }
688 ChannelType::Pipe { message_mode } => {
689 flags.set_enumerate_device_interface(true);
690 flags.set_named_pipe_mode(true);
691 user_defined.as_pipe_params_mut().pipe_type = if message_mode {
692 protocol::PipeType::MESSAGE
693 } else {
694 protocol::PipeType::BYTE
695 };
696 }
697 ChannelType::HvSocket {
698 is_connect,
699 is_for_container,
700 silo_id,
701 } => {
702 flags.set_enumerate_device_interface(true);
703 flags.set_tlnpi_provider(true);
704 flags.set_named_pipe_mode(true);
705 *user_defined.as_hvsock_params_mut() = protocol::HvsockUserDefinedParameters::new(
706 is_connect,
707 is_for_container,
708 silo_id,
709 );
710 }
711 };
712
713 Self {
714 interface_name: value.interface_name,
715 instance_id: value.instance_id,
716 interface_id: value.interface_id,
717 mmio_megabytes: value.mmio_megabytes,
718 mmio_megabytes_optional: value.mmio_megabytes_optional,
719 subchannel_index: value.subchannel_index,
720 use_mnf: value.mnf_interrupt_latency.into(),
721 offer_order: value.offer_order,
722 user_defined,
723 flags,
724 }
725 }
726}
727
/// A VP and SINT pair that messages can be directed to.
#[derive(Debug, Copy, Clone, Inspect, PartialEq, Eq)]
pub struct ConnectionTarget {
    pub vp: u32,
    pub sint: u8,
}
733
/// Where `Notifier::send_message` should deliver a message.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MessageTarget {
    /// The connection's default target.
    Default,
    /// A reserved channel's own target (see `MessageTarget::for_offer`).
    ReservedChannel(OfferId, ConnectionTarget),
    /// An explicitly chosen target.
    Custom(ConnectionTarget),
}
740
741impl MessageTarget {
742 pub fn for_offer(offer_id: OfferId, reserved_state: &Option<ReservedState>) -> Self {
743 if let Some(state) = reserved_state {
744 Self::ReservedChannel(offer_id, state.target)
745 } else {
746 Self::Default
747 }
748 }
749}
750
/// Extra state kept for a reserved channel. Messages for the channel go to `target`
/// (see `MessageTarget::for_offer`).
#[derive(Debug, Copy, Clone)]
pub struct ReservedState {
    /// Version info associated with the reservation. NOTE(review): confirm usage.
    version: VersionInfo,
    /// Where messages for this channel are sent.
    target: ConnectionTarget,
}
756
/// A channel offer tracked by the server.
#[derive(Debug)]
struct Channel {
    /// Channel, connection and monitor IDs. Set exactly while the channel is not
    /// released (checked by `Server::validate` in debug builds).
    info: Option<OfferedInfo>,
    /// The offer's parameters.
    offer: OfferParamsInternal,
    /// Lifecycle state.
    state: ChannelState,
    /// Save/restore tracking state.
    restore_state: RestoreState,
}
765
/// IDs assigned to a channel by `Channel::prepare_channel`.
#[derive(Debug, Copy, Clone)]
struct OfferedInfo {
    /// Channel ID (1-based index into `AssignedChannels`).
    channel_id: ChannelId,
    /// Connection ID built from the channel ID, VTL and `VMBUS_SINT`.
    connection_id: u32,
    /// Monitor ID when MNF is used and one was available.
    monitor_id: Option<MonitorId>,
}
772
impl Channel {
    /// Writes this channel's state into `resp` for inspection.
    ///
    /// The target VP and guest-specified signal parameters are filled in only for
    /// `Open`. The reserved connection target is filled in for `Opening`, `Open` and
    /// `Closing`. Otherwise these fields are reported as absent.
    fn inspect_state(&self, resp: &mut inspect::Response<'_>) {
        let mut target_vp = None;
        let mut event_flag = None;
        let mut connection_id = None;
        let mut reserved_target = None;
        let state = match &self.state {
            ChannelState::ClientReleased => "client_released",
            ChannelState::Closed => "closed",
            ChannelState::Opening { reserved_state, .. } => {
                reserved_target = reserved_state.map(|state| state.target);
                "opening"
            }
            ChannelState::Open {
                params,
                reserved_state,
                ..
            } => {
                target_vp = Some(params.target_vp);
                if let Some(id) = params.guest_specified_interrupt_info {
                    event_flag = Some(id.event_flag);
                    connection_id = Some(id.connection_id);
                }
                reserved_target = reserved_state.map(|state| state.target);
                "open"
            }
            ChannelState::Closing { reserved_state, .. } => {
                reserved_target = reserved_state.map(|state| state.target);
                "closing"
            }
            ChannelState::ClosingReopen { .. } => "closing_reopen",
            ChannelState::Revoked => "revoked",
            ChannelState::Reoffered => "reoffered",
            ChannelState::ClosingClientRelease => "closing_client_release",
            ChannelState::OpeningClientRelease => "opening_client_release",
        };
        let restore_state = match self.restore_state {
            RestoreState::New => "new",
            RestoreState::Restoring => "restoring",
            RestoreState::Restored => "restored",
            RestoreState::Unmatched => "unmatched",
        };
        // Assigned IDs are only reported while the channel holds them (`info` set).
        if let Some(info) = &self.info {
            resp.field("channel_id", info.channel_id.0)
                .field("offered_connection_id", info.connection_id)
                .field("monitor_id", info.monitor_id.map(|id| id.0));
        }
        resp.field("state", state)
            .field("restore_state", restore_state)
            .field("interface_name", self.offer.interface_name.clone())
            .display("instance_id", &self.offer.instance_id)
            .display("interface_id", &self.offer.interface_id)
            .field("mmio_megabytes", self.offer.mmio_megabytes)
            .field("target_vp", target_vp)
            .field("guest_specified_event_flag", event_flag)
            .field("guest_specified_connection_id", connection_id)
            .field("reserved_connection_target", reserved_target)
            .binary("offer_flags", self.offer.flags.into_bits());
    }

    /// Monitor (MNF) parameters that this server handles for the channel.
    ///
    /// Returns `Some` only when MNF usage is `Enabled` (not `Relayed` or `Disabled`),
    /// the channel is not in a reserved state, and a monitor ID was assigned.
    fn handled_monitor_info(&self) -> Option<MonitorInfo> {
        self.offer.use_mnf.enabled_and_then(|latency| {
            if self.state.is_reserved() {
                None
            } else {
                self.info.and_then(|info| {
                    info.monitor_id.map(|monitor_id| MonitorInfo {
                        monitor_id,
                        latency,
                    })
                })
            }
        })
    }

    /// Assigns this offer a channel ID, a connection ID and, when MNF is used, a
    /// monitor ID, and stores them in `self.info`.
    ///
    /// The channel ID is the first free slot at or after the reserved offset (see
    /// `AssignedChannels::allocate`). With `MnfUsage::Enabled` the monitor ID comes
    /// from `assigned_monitors` and is `None` (with a rate-limited warning) if none
    /// are left. With `Relayed` the externally supplied ID is used as-is.
    ///
    /// # Panics
    /// If `self.info` is already set, or no channel ID is free.
    fn prepare_channel(
        &mut self,
        offer_id: OfferId,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
    ) {
        assert!(self.info.is_none());

        // Callers are expected to keep the offer count within the table's capacity.
        let entry = assigned_channels
            .allocate()
            .expect("there are enough channel IDs for everything in ChannelList");

        let channel_id = entry.id();
        entry.insert(offer_id);
        let connection_id = ConnectionId::new(channel_id.0, assigned_channels.vtl, VMBUS_SINT);

        let monitor_id = match self.offer.use_mnf {
            MnfUsage::Enabled { .. } => {
                let monitor_id = assigned_monitors.assign_monitor();
                if monitor_id.is_none() {
                    tracelimit::warn_ratelimited!("Out of monitor IDs.");
                }

                monitor_id
            }
            MnfUsage::Relayed { monitor_id } => Some(MonitorId(monitor_id)),
            MnfUsage::Disabled => None,
        };

        self.info = Some(OfferedInfo {
            channel_id,
            connection_id: connection_id.0,
            monitor_id,
        });
    }

    /// Undoes `prepare_channel`: frees the channel ID and, for `MnfUsage::Enabled`,
    /// returns the monitor ID to `assigned_monitors`. Does nothing when `self.info`
    /// is `None`.
    fn release_channel(
        &mut self,
        offer_id: OfferId,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
    ) {
        if let Some(info) = self.info.take() {
            assigned_channels.free(info.channel_id, offer_id);

            if let Some(monitor_id) = info.monitor_id {
                // Relayed monitor IDs were never taken from `assigned_monitors`
                // (see `prepare_channel`), so only server-assigned ones go back.
                if self.offer.use_mnf.is_enabled() {
                    assigned_monitors.release_monitor(monitor_id);
                }
            }
        }
    }
}
919
920#[derive(Debug)]
921struct AssignedChannels {
922 assignments: Vec<Option<OfferId>>,
923 vtl: Vtl,
924 reserved_offset: usize,
925 count_in_reserved_range: usize,
927}
928
929impl AssignedChannels {
930 fn new(vtl: Vtl, channel_id_offset: u16) -> Self {
931 Self {
932 assignments: vec![None; MAX_CHANNELS],
933 vtl,
934 reserved_offset: channel_id_offset as usize,
935 count_in_reserved_range: 0,
936 }
937 }
938
939 fn allowable_channel_count(&self) -> usize {
940 MAX_CHANNELS - self.reserved_offset + self.count_in_reserved_range
941 }
942
943 fn get(&self, channel_id: ChannelId) -> Option<OfferId> {
944 self.assignments
945 .get(Self::index(channel_id))
946 .copied()
947 .flatten()
948 }
949
950 fn set(&mut self, channel_id: ChannelId) -> Result<AssignmentEntry<'_>, OfferError> {
951 let index = Self::index(channel_id);
952 if self
953 .assignments
954 .get(index)
955 .ok_or(OfferError::InvalidChannelId(channel_id))?
956 .is_some()
957 {
958 return Err(OfferError::ChannelIdInUse(channel_id));
959 }
960 Ok(AssignmentEntry { list: self, index })
961 }
962
963 fn allocate(&mut self) -> Option<AssignmentEntry<'_>> {
964 let index = self.reserved_offset
965 + self.assignments[self.reserved_offset..]
966 .iter()
967 .position(|x| x.is_none())?;
968 Some(AssignmentEntry { list: self, index })
969 }
970
971 fn free(&mut self, channel_id: ChannelId, offer_id: OfferId) {
972 let index = Self::index(channel_id);
973 let slot = &mut self.assignments[index];
974 assert_eq!(slot.take(), Some(offer_id));
975 if index < self.reserved_offset {
976 self.count_in_reserved_range -= 1;
977 }
978 }
979
980 fn index(channel_id: ChannelId) -> usize {
981 channel_id.0.wrapping_sub(1) as usize
982 }
983}
984
/// A free slot in [`AssignedChannels`], returned by `set`/`allocate`. Consuming it
/// with `insert` records the assignment.
struct AssignmentEntry<'a> {
    list: &'a mut AssignedChannels,
    /// Slot index; the channel ID is `index + 1`.
    index: usize,
}
989
990impl AssignmentEntry<'_> {
991 pub fn id(&self) -> ChannelId {
992 ChannelId(self.index as u32 + 1)
993 }
994
995 pub fn insert(self, offer_id: OfferId) {
996 assert!(
997 self.list.assignments[self.index]
998 .replace(offer_id)
999 .is_none()
1000 );
1001
1002 if self.index < self.list.reserved_offset {
1003 self.list.count_in_reserved_range += 1;
1004 }
1005 }
1006}
1007
/// All channel offers, stored in a slab whose keys are wrapped as [`OfferId`]s.
struct ChannelList {
    channels: Slab<Channel>,
}
1011
1012fn channel_inspect_path(offer: &OfferParamsInternal, suffix: std::fmt::Arguments<'_>) -> String {
1013 if offer.subchannel_index == 0 {
1014 format!("{}{}", offer.instance_id, suffix)
1015 } else {
1016 format!(
1017 "{}/subchannels/{}{}",
1018 offer.instance_id, offer.subchannel_index, suffix
1019 )
1020 }
1021}
1022
1023impl ChannelList {
1024 fn inspect(
1025 &self,
1026 notifier: &impl Notifier,
1027 version: Option<VersionInfo>,
1028 resp: &mut inspect::Response<'_>,
1029 ) {
1030 for (offer_id, channel) in self.iter() {
1031 resp.child(
1032 &channel_inspect_path(&channel.offer, format_args!("")),
1033 |req| {
1034 let mut resp = req.respond();
1035 channel.inspect_state(&mut resp);
1036
1037 resp.merge(inspect::adhoc(|req| {
1041 if !matches!(channel.state, ChannelState::Revoked) {
1042 notifier.inspect(version, offer_id, req);
1043 }
1044 }));
1045 },
1046 );
1047 }
1048 }
1049}
1050
/// Size of the channel ID table; valid channel IDs are `1..=MAX_CHANNELS`.
pub const MAX_CHANNELS: usize = 2047;
1054
1055impl ChannelList {
1056 fn new() -> Self {
1057 Self {
1058 channels: Slab::new(),
1059 }
1060 }
1061
1062 fn len(&self) -> usize {
1064 self.channels.len()
1065 }
1066
1067 fn offer(&mut self, new_channel: Channel) -> OfferId {
1069 OfferId(self.channels.insert(new_channel))
1070 }
1071
1072 fn remove(&mut self, offer_id: OfferId) {
1074 let channel = self.channels.remove(offer_id.0);
1075 assert!(channel.info.is_none());
1076 }
1077
1078 fn get_by_channel_id_mut(
1080 &mut self,
1081 assigned_channels: &AssignedChannels,
1082 channel_id: ChannelId,
1083 ) -> Result<(OfferId, &mut Channel), ChannelError> {
1084 let offer_id = assigned_channels
1085 .get(channel_id)
1086 .ok_or(ChannelError::UnknownChannelId)?;
1087 let channel = &mut self[offer_id];
1088 if channel.state.is_released() {
1089 return Err(ChannelError::ChannelReleased);
1090 }
1091 assert_eq!(
1092 channel.info.as_ref().map(|info| info.channel_id),
1093 Some(channel_id)
1094 );
1095 Ok((offer_id, channel))
1096 }
1097
1098 fn get_by_channel_id(
1100 &self,
1101 assigned_channels: &AssignedChannels,
1102 channel_id: ChannelId,
1103 ) -> Result<(OfferId, &Channel), ChannelError> {
1104 let offer_id = assigned_channels
1105 .get(channel_id)
1106 .ok_or(ChannelError::UnknownChannelId)?;
1107 let channel = &self[offer_id];
1108 if channel.state.is_released() {
1109 return Err(ChannelError::ChannelReleased);
1110 }
1111 assert_eq!(
1112 channel.info.as_ref().map(|info| info.channel_id),
1113 Some(channel_id)
1114 );
1115 Ok((offer_id, channel))
1116 }
1117
1118 fn get_by_key_mut(&mut self, key: &OfferKey) -> Option<(OfferId, &mut Channel)> {
1121 for (offer_id, channel) in self.iter_mut() {
1122 if channel.offer.instance_id == key.instance_id
1123 && channel.offer.interface_id == key.interface_id
1124 && channel.offer.subchannel_index == key.subchannel_index
1125 {
1126 return Some((offer_id, channel));
1127 }
1128 }
1129 None
1130 }
1131
1132 fn iter(&self) -> impl Iterator<Item = (OfferId, &Channel)> {
1134 self.channels
1135 .iter()
1136 .map(|(id, channel)| (OfferId(id), channel))
1137 }
1138
1139 fn iter_mut(&mut self) -> impl Iterator<Item = (OfferId, &mut Channel)> {
1141 self.channels
1142 .iter_mut()
1143 .map(|(id, channel)| (OfferId(id), channel))
1144 }
1145
1146 fn retain<F>(&mut self, mut f: F)
1148 where
1149 F: FnMut(OfferId, &mut Channel) -> bool,
1150 {
1151 self.channels.retain(|id, channel| {
1152 let retain = f(OfferId(id), channel);
1153 if !retain {
1154 assert!(channel.info.is_none());
1155 }
1156 retain
1157 })
1158 }
1159}
1160
1161impl Index<OfferId> for ChannelList {
1162 type Output = Channel;
1163
1164 fn index(&self, offer_id: OfferId) -> &Self::Output {
1165 &self.channels[offer_id.0]
1166 }
1167}
1168
1169impl IndexMut<OfferId> for ChannelList {
1170 fn index_mut(&mut self, offer_id: OfferId) -> &mut Self::Output {
1171 &mut self.channels[offer_id.0]
1172 }
1173}
1174
/// A GPADL: its range data and its progress.
#[derive(Debug, Inspect)]
struct Gpadl {
    /// Range count, passed to `gparange::validate_gpa_ranges` when the data is complete.
    count: u16,
    /// Range data. Its capacity is the expected total length (see `Gpadl::new`).
    #[inspect(skip)]
    buf: Vec<u64>,
    /// Current state.
    state: GpadlState,
}
1183
/// Progress of a GPADL.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Inspect)]
enum GpadlState {
    /// Still receiving range data; `Gpadl::append` accepts data only in this state.
    InProgress,
    /// All data received and validated (set by `Gpadl::append`).
    Offered,
    /// Presumably torn down before being accepted — confirm.
    OfferedTearingDown,
    /// Presumably accepted by the notifier — confirm.
    Accepted,
    /// Presumably being torn down after acceptance — confirm.
    TearingDown,
}
1198
1199impl Gpadl {
1200 fn new(count: u16, len: usize) -> Self {
1203 Self {
1204 state: GpadlState::InProgress,
1205 count,
1206 buf: Vec::with_capacity(len),
1207 }
1208 }
1209
1210 fn append(&mut self, data: &[u8]) -> Result<bool, ChannelError> {
1212 if self.state == GpadlState::InProgress {
1213 let buf = &mut self.buf;
1214 let len = min(data.len() & !7, (buf.capacity() - buf.len()) * 8);
1219 let data = &data[..len];
1220 let start = buf.len();
1221 buf.resize(buf.len() + data.len() / 8, 0);
1222 buf[start..].as_mut_bytes().copy_from_slice(data);
1223 Ok(if buf.len() == buf.capacity() {
1224 gparange::validate_gpa_ranges(self.count as usize, buf)
1225 .map_err(ChannelError::InvalidGpaRange)?;
1226 self.state = GpadlState::Offered;
1227 true
1228 } else {
1229 false
1230 })
1231 } else {
1232 Err(ChannelError::GpadlAlreadyComplete)
1233 }
1234 }
1235}
1236
/// Parameters for opening a channel, built by `OpenParams::from_request` and sent in
/// `Action::Open`.
#[derive(Debug, Copy, Clone)]
pub struct OpenParams {
    /// Ring buffer and signaling data for the channel.
    pub open_data: OpenData,
    /// Connection ID used for signaling (guest-specified or offer-assigned).
    pub connection_id: u32,
    /// Event flag used for signaling (guest-specified or derived from the channel ID).
    pub event_flag: u16,
    /// MNF parameters, if this server handles monitoring for the channel.
    pub monitor_info: Option<MonitorInfo>,
    /// Open flags with the unused bits cleared.
    pub flags: protocol::OpenChannelFlags,
    /// Message target for reserved channels.
    pub reserved_target: Option<ConnectionTarget>,
    /// The channel being opened.
    pub channel_id: ChannelId,
}
1248
1249impl OpenParams {
1250 fn from_request(
1251 info: &OfferedInfo,
1252 request: &OpenRequest,
1253 monitor_info: Option<MonitorInfo>,
1254 reserved_target: Option<ConnectionTarget>,
1255 ) -> Self {
1256 let (event_flag, connection_id) = if let Some(id) = request.guest_specified_interrupt_info {
1259 (id.event_flag, id.connection_id)
1260 } else {
1261 (info.channel_id.0 as u16, info.connection_id)
1262 };
1263
1264 Self {
1265 open_data: OpenData {
1266 target_vp: request.target_vp,
1267 ring_offset: request.downstream_ring_buffer_page_offset,
1268 ring_gpadl_id: request.ring_buffer_gpadl_id,
1269 user_data: request.user_data,
1270 event_flag,
1271 connection_id,
1272 },
1273 connection_id,
1274 event_flag,
1275 monitor_info,
1276 flags: request.flags.with_unused(0),
1277 reserved_target,
1278 channel_id: info.channel_id,
1279 }
1280 }
1281}
1282
/// Work the server asks its [`Notifier`] to carry out for a channel (see
/// `Notifier::notify`).
#[derive(Debug)]
pub enum Action {
    /// Open the channel with these parameters under this version.
    Open(OpenParams, VersionInfo),
    /// Close the channel.
    Close,
    /// A GPADL: ID, then presumably its range count and range data (matching
    /// `Gpadl`'s `count`/`buf`) — confirm at the construction site.
    Gpadl(GpadlId, u16, Vec<u64>),
    /// Tear down a GPADL. `post_restore` presumably marks teardowns issued after a
    /// restore — confirm.
    TeardownGpadl {
        gpadl_id: GpadlId,
        post_restore: bool,
    },
    /// Modify the channel, presumably retargeting it to `target_vp` — confirm.
    Modify {
        target_vp: u32,
    },
}
1297
/// Protocol versions this server supports.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::Win7,
    Version::Win8,
    Version::Win8_1,
    Version::Win10,
    Version::Win10Rs3_0,
    Version::Win10Rs3_1,
    Version::Win10Rs4,
    Version::Win10Rs5,
    Version::Iron,
    Version::Copper,
];
1312
/// Feature flags this server supports.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true)
    .with_server_specified_monitor_pages(true);
1322
/// Callbacks through which the server acts on, and reports to, its host.
pub trait Notifier: Send {
    /// Requests that `action` be carried out for the channel `offer_id`.
    fn notify(&mut self, offer_id: OfferId, action: Action);

    /// Forwards an `InitiateContact` request that this server does not handle itself.
    /// NOTE(review): the forwarding conditions are not visible in this chunk.
    fn forward_unhandled(&mut self, request: InitiateContactRequest);

    /// Applies a connection modification; an `Err` reports failure.
    fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()>;

    /// Adds notifier-specific inspect data for a channel. `version` is the
    /// connection's version info, if any. The default implementation adds nothing.
    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
        let _ = (version, offer_id, req);
    }

    /// Sends `message` to `target`.
    ///
    /// NOTE(review): the `#[must_use]` return value presumably reports whether the
    /// message was sent; confirm against implementations.
    #[must_use]
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool;

    /// Delivers an hvsocket connection request.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest);

    /// Called when a reset has completed.
    fn reset_complete(&mut self);

    /// Called when an unload has completed.
    fn unload_complete(&mut self);
}
1357
1358impl Server {
1359 pub fn new(
1361 vtl: Vtl,
1362 child_connection_id: u32,
1363 channel_id_offset: u16,
1364 use_absolute_channel_order: bool,
1365 ) -> Self {
1366 Server {
1367 state: ConnectionState::Disconnected,
1368 channels: ChannelList::new(),
1369 assigned_channels: AssignedChannels::new(vtl, channel_id_offset),
1370 assigned_monitors: AssignedMonitors::new(),
1371 gpadls: Default::default(),
1372 incomplete_gpadls: Default::default(),
1373 child_connection_id,
1374 max_version: None,
1375 delayed_max_version: None,
1376 pending_messages: PendingMessages(VecDeque::new()),
1377 require_server_allocated_mnf: false,
1378 use_absolute_channel_order,
1379 }
1380 }
1381
1382 pub fn with_notifier<'a, T: Notifier>(
1384 &'a mut self,
1385 notifier: &'a mut T,
1386 ) -> ServerWithNotifier<'a, T> {
1387 self.validate();
1388 ServerWithNotifier {
1389 inner: self,
1390 notifier,
1391 }
1392 }
1393
1394 pub fn set_require_server_allocated_mnf(&mut self, require: bool) {
1397 self.require_server_allocated_mnf = require;
1398 }
1399
1400 fn validate(&self) {
1401 #[cfg(debug_assertions)]
1402 for (_, channel) in self.channels.iter() {
1403 let should_have_info = !channel.state.is_released();
1404 if channel.info.is_some() != should_have_info {
1405 panic!("channel invariant violation: {channel:?}");
1406 }
1407 }
1408 }
1409
1410 pub fn set_compatibility_version(&mut self, version: MaxVersionInfo, delay: bool) {
1412 if delay {
1413 self.delayed_max_version = Some(version)
1414 } else {
1415 tracing::info!(?version, "Limiting VmBus connections to version");
1416 self.max_version = Some(version);
1417 }
1418 }
1419
1420 pub fn channel_gpadls(&self, offer_id: OfferId) -> Vec<RestoredGpadl> {
1421 self.gpadls
1422 .iter()
1423 .filter_map(|(&(gpadl_id, gpadl_offer_id), gpadl)| {
1424 if offer_id != gpadl_offer_id {
1425 return None;
1426 }
1427 let accepted = match gpadl.state {
1428 GpadlState::Offered | GpadlState::OfferedTearingDown => false,
1429 GpadlState::Accepted => true,
1430 GpadlState::InProgress | GpadlState::TearingDown => return None,
1431 };
1432 Some(RestoredGpadl {
1433 request: GpadlRequest {
1434 id: gpadl_id,
1435 count: gpadl.count,
1436 buf: gpadl.buf.clone(),
1437 },
1438 accepted,
1439 })
1440 })
1441 .collect()
1442 }
1443
1444 pub fn get_version(&self) -> Option<VersionInfo> {
1445 self.state.get_version()
1446 }
1447
1448 pub fn get_restore_open_params(&self, offer_id: OfferId) -> Result<OpenParams, RestoreError> {
1449 let channel = &self.channels[offer_id];
1450
1451 match channel.restore_state {
1453 RestoreState::New => {
1454 return Err(RestoreError::MissingChannel(channel.offer.key()));
1458 }
1459 RestoreState::Restoring => {}
1460 RestoreState::Unmatched => unreachable!(),
1461 RestoreState::Restored => {
1462 return Err(RestoreError::AlreadyRestored(channel.offer.key()));
1463 }
1464 }
1465
1466 let info = channel
1467 .info
1468 .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;
1469
1470 let (request, reserved_state) = match channel.state {
1471 ChannelState::Closed => {
1472 return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
1473 }
1474 ChannelState::Closing { params, .. } | ChannelState::ClosingReopen { params, .. } => {
1475 (params, None)
1476 }
1477 ChannelState::Opening {
1478 request,
1479 reserved_state,
1480 } => (request, reserved_state),
1481 ChannelState::Open {
1482 params,
1483 reserved_state,
1484 ..
1485 } => (params, reserved_state),
1486 ChannelState::ClientReleased | ChannelState::Reoffered => {
1487 return Err(RestoreError::MissingChannel(channel.offer.key()));
1488 }
1489 ChannelState::Revoked
1490 | ChannelState::ClosingClientRelease
1491 | ChannelState::OpeningClientRelease => unreachable!(),
1492 };
1493
1494 Ok(OpenParams::from_request(
1495 &info,
1496 &request,
1497 channel.handled_monitor_info(),
1498 reserved_state.map(|state| state.target),
1499 ))
1500 }
1501
1502 pub fn has_pending_messages(&self) -> bool {
1504 !self.pending_messages.0.is_empty() && !self.state.is_paused()
1505 }
1506
1507 pub fn poll_flush_pending_messages(
1509 &mut self,
1510 mut send: impl FnMut(&OutgoingMessage) -> Poll<()>,
1511 ) -> Poll<()> {
1512 if !self.state.is_paused() {
1513 while let Some(message) = self.pending_messages.0.front() {
1514 ready!(send(message));
1515 self.pending_messages.0.pop_front();
1516 }
1517 }
1518
1519 Poll::Ready(())
1520 }
1521}
1522
1523impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
    /// Restores a channel that was matched against saved state.
    ///
    /// `open` reports whether the caller considers the channel open; the saved channel state
    /// must agree. Work that was in flight when the state was saved is resumed. A pending close
    /// or open is re-issued to the handler, or a deferred open result is sent to the guest.
    /// On success the channel is marked `Restored`.
    ///
    /// # Errors
    ///
    /// - `MissingChannel` if `open` is set for a channel absent from the saved state, the saved
    ///   channel info is missing, or the saved state says the guest released the channel.
    /// - `AlreadyRestored` if called twice for the same channel.
    /// - `DuplicateMonitorId` if the channel's monitor ID is already claimed.
    /// - `MismatchedOpenState` if the saved state disagrees with `open`.
    ///
    /// # Panics
    ///
    /// Panics on restore/channel states that cannot occur for a channel being restored.
    pub fn restore_channel(&mut self, offer_id: OfferId, open: bool) -> Result<(), RestoreError> {
        let channel = &mut self.inner.channels[offer_id];

        match channel.restore_state {
            // The channel was not in the saved state.
            RestoreState::New => {
                if open {
                    // The caller says the channel is open, but the saved state has no record
                    // of it.
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                } else {
                    // A closed channel that was never saved needs no restoring.
                    return Ok(());
                }
            }
            RestoreState::Restoring => {}
            // `offer_channel` moves matched channels to `Restoring`; unmatched ones have no
            // offer ID for a caller to pass here.
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        // Reclaim the saved monitor ID so it cannot be handed to another channel.
        if let Some(monitor_info) = channel.handled_monitor_info() {
            if !self
                .inner
                .assigned_monitors
                .claim_monitor(monitor_info.monitor_id)
            {
                return Err(RestoreError::DuplicateMonitorId(monitor_info.monitor_id.0));
            }
        }

        if open {
            match channel.state {
                ChannelState::Closed => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                // A close was in progress; ask the handler to close the channel again.
                ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                    self.notifier.notify(offer_id, Action::Close);
                }
                // The handler reports the channel open, so finish the pending open: report
                // success to the guest and move to `Open`.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused())
                        .send_open_result(
                            info.channel_id,
                            &request,
                            protocol::STATUS_SUCCESS,
                            MessageTarget::for_offer(offer_id, &reserved_state),
                        );
                    channel.state = ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    };
                }
                ChannelState::Open { .. } => {}
                ChannelState::ClientReleased | ChannelState::Reoffered => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            };
        } else {
            match channel.state {
                ChannelState::Closed => {}
                ChannelState::Reoffered => {}
                // The handler reports the channel closed, so the pending close is done.
                ChannelState::Closing { .. } => {
                    channel.state = ChannelState::Closed;
                }
                // The close part of a close-then-reopen is done; issue the reopen now.
                ChannelState::ClosingReopen { request, .. } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                None,
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                    channel.state = ChannelState::Opening {
                        request,
                        reserved_state: None,
                    };
                }
                // The open had not reached the handler; re-issue it.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                reserved_state.map(|state| state.target),
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                }
                ChannelState::Open { .. } => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                ChannelState::ClientReleased => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            }
        }

        channel.restore_state = RestoreState::Restored;
        Ok(())
    }
1662
    /// Finishes a restore by reconciling channels that `restore_channel` did not handle.
    ///
    /// - Channels that were not in the saved state (`New`) are offered to the guest now, if the
    ///   guest has already received offers and does not hold them.
    /// - Channels matched to a host offer but never restored (`Restoring`) are revoked and
    ///   marked `Reoffered`.
    /// - Saved channels with no matching host offer (`Unmatched`) are revoked.
    ///
    /// Pending GPADL work from the saved state is then replayed to the handlers. Finally, the
    /// method checks whether a pending disconnect can proceed.
    pub fn revoke_unclaimed_channels(&mut self) {
        for (offer_id, channel) in self.inner.channels.iter_mut() {
            match channel.restore_state {
                RestoreState::Restored => {
                    // Already handled by `restore_channel`.
                }
                RestoreState::New => {
                    if let ConnectionState::Connected(info) = &self.inner.state {
                        // Offer the channel only if the guest already has the other offers
                        // and this one is not currently known to it.
                        if info.offers_sent && matches!(channel.state, ChannelState::ClientReleased)
                        {
                            channel.prepare_channel(
                                offer_id,
                                &mut self.inner.assigned_channels,
                                &mut self.inner.assigned_monitors,
                            );
                            channel.state = ChannelState::Closed;
                            self.inner
                                .pending_messages
                                .sender(self.notifier, self.inner.state.is_paused())
                                .send_offer(channel, info);
                        }
                    }
                }
                RestoreState::Restoring => {
                    // Matched to a host offer but never restored: revoke it from the guest and
                    // mark it to be offered again later.
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                    channel.state = ChannelState::Reoffered;
                }
                RestoreState::Unmatched => {
                    // In the saved state but not offered by the host: revoke it.
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                }
            }
        }

        // Replay GPADL operations that were pending when the state was saved.
        for (&(gpadl_id, offer_id), gpadl) in self.inner.gpadls.iter_mut() {
            match gpadl.state {
                GpadlState::InProgress | GpadlState::Accepted => {}
                GpadlState::Offered => {
                    self.notifier.notify(
                        offer_id,
                        Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
                    );
                }
                GpadlState::TearingDown => {
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: true,
                        },
                    );
                }
                GpadlState::OfferedTearingDown => unreachable!(),
            }
        }

        self.check_disconnected();
    }
1748
1749 pub fn reset(&mut self) {
1754 assert!(!self.is_resetting());
1755 if self.request_disconnect(ConnectionAction::Reset) {
1756 self.complete_reset();
1757 }
1758 }
1759
1760 fn complete_reset(&mut self) {
1761 for (_, channel) in self.inner.channels.iter_mut() {
1763 channel.restore_state = RestoreState::New;
1764 }
1765 self.inner.pending_messages.0.clear();
1766 self.notifier.reset_complete();
1767 }
1768
1769 pub fn offer_channel(&mut self, offer: OfferParamsInternal) -> Result<OfferId, OfferError> {
1771 if let Some((offer_id, channel)) = self.inner.channels.get_by_key_mut(&offer.key()) {
1773 if channel.restore_state != RestoreState::Unmatched
1777 && !matches!(channel.state, ChannelState::Revoked)
1778 {
1779 return Err(OfferError::AlreadyExists(offer.key()));
1780 }
1781
1782 let info = channel.info.expect("assigned");
1783 if channel.restore_state == RestoreState::Unmatched {
1784 tracing::debug!(
1785 offer_id = offer_id.0,
1786 key = %channel.offer.key(),
1787 "matched channel"
1788 );
1789
1790 assert!(!matches!(channel.state, ChannelState::Revoked));
1791 channel.restore_state = RestoreState::Restoring;
1795
1796 if let MnfUsage::Relayed { monitor_id } = offer.use_mnf {
1799 if info.monitor_id != Some(MonitorId(monitor_id)) {
1800 return Err(OfferError::MismatchedMonitorId(
1801 info.monitor_id,
1802 MonitorId(monitor_id),
1803 ));
1804 }
1805 }
1806 } else {
1807 channel.state = ChannelState::Reoffered;
1811 tracing::info!(?offer_id, key = %channel.offer.key(), "channel marked for reoffer");
1812 }
1813
1814 channel.offer = offer;
1815 return Ok(offer_id);
1816 }
1817
1818 let mut connected_info = None;
1819 let state = match &self.inner.state {
1820 ConnectionState::Connected(info) => {
1821 if info.offers_sent {
1822 connected_info = Some(info);
1823 ChannelState::Closed
1824 } else {
1825 ChannelState::ClientReleased
1826 }
1827 }
1828 ConnectionState::Connecting { .. }
1829 | ConnectionState::Disconnecting { .. }
1830 | ConnectionState::Disconnected => ChannelState::ClientReleased,
1831 };
1832
1833 if self.inner.channels.len() >= self.inner.assigned_channels.allowable_channel_count() {
1835 return Err(OfferError::TooManyChannels);
1836 }
1837
1838 let key = offer.key();
1839 let confidential_ring_buffer = offer.flags.confidential_ring_buffer();
1840 let confidential_external_memory = offer.flags.confidential_external_memory();
1841 let channel = Channel {
1842 info: None,
1843 offer,
1844 state,
1845 restore_state: RestoreState::New,
1846 };
1847
1848 let offer_id = self.inner.channels.offer(channel);
1849 if let Some(info) = connected_info {
1850 let channel = &mut self.inner.channels[offer_id];
1851 channel.prepare_channel(
1852 offer_id,
1853 &mut self.inner.assigned_channels,
1854 &mut self.inner.assigned_monitors,
1855 );
1856
1857 self.inner
1858 .pending_messages
1859 .sender(self.notifier, self.inner.state.is_paused())
1860 .send_offer(channel, info);
1861 }
1862
1863 tracing::info!(?offer_id, %key, confidential_ring_buffer, confidential_external_memory, "new channel");
1864 Ok(offer_id)
1865 }
1866
1867 pub fn revoke_channel(&mut self, offer_id: OfferId) {
1869 let channel = &mut self.inner.channels[offer_id];
1870 let retain = revoke(
1871 self.inner
1872 .pending_messages
1873 .sender(self.notifier, self.inner.state.is_paused()),
1874 offer_id,
1875 channel,
1876 &mut self.inner.gpadls,
1877 );
1878 if !retain {
1879 self.inner.channels.remove(offer_id);
1880 }
1881
1882 self.check_disconnected();
1883 }
1884
    /// Handles a handler's completion of an `Action::Open` for `offer_id`.
    ///
    /// `result` is a status code; non-negative means success. For a normal open, the result
    /// is forwarded to the guest and the channel becomes `Open` or `Closed`. If the guest
    /// released the channel while the open was in flight, a successful open is immediately
    /// followed by a close request, and a failed one leaves the channel released. Completions
    /// in any other state are logged and ignored.
    pub fn open_complete(&mut self, offer_id: OfferId, result: i32) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::debug!(offer_id = offer_id.0, key = %channel.offer.key(), result, "open complete");

        match channel.state {
            ChannelState::Opening {
                request,
                reserved_state,
            } => {
                let channel_id = channel.info.expect("assigned").channel_id;
                if result >= 0 {
                    tracelimit::info_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "opened channel"
                    );
                } else {
                    tracelimit::error_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "failed to open channel"
                    );
                }

                // Report the outcome (success or failure) to the guest.
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_open_result(
                        channel_id,
                        &request,
                        result,
                        MessageTarget::for_offer(offer_id, &reserved_state),
                    );
                channel.state = if result >= 0 {
                    ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    }
                } else {
                    ChannelState::Closed
                };
            }
            ChannelState::OpeningClientRelease => {
                tracing::info!(
                    offer_id = offer_id.0,
                    key = %channel.offer.key(),
                    result,
                    "opened channel (client released)"
                );

                if result >= 0 {
                    // The guest no longer wants the channel; close what was just opened.
                    channel.state = ChannelState::ClosingClientRelease;
                    self.notifier.notify(offer_id, Action::Close);
                } else {
                    channel.state = ChannelState::ClientReleased;
                    // This may have been the last channel a pending disconnect was waiting on.
                    self.check_disconnected();
                }
            }

            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::ClosingClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid open complete")
            }
        }
    }
1963
1964 fn are_channels_reset(&self, include_reserved: bool) -> bool {
1967 self.inner.gpadls.keys().all(|(_, offer_id)| {
1968 !include_reserved && self.inner.channels[*offer_id].state.is_reserved()
1969 }) && self.inner.channels.iter().all(|(_, channel)| {
1970 matches!(channel.state, ChannelState::ClientReleased)
1971 || (!include_reserved && channel.state.is_reserved())
1972 })
1973 }
1974
1975 fn check_disconnected(&mut self) {
1979 match self.inner.state {
1980 ConnectionState::Disconnecting {
1981 next_action,
1982 modify_sent: false,
1983 } => {
1984 if self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)) {
1985 self.notify_disconnect(next_action);
1986 }
1987 }
1988 ConnectionState::Disconnecting {
1989 modify_sent: true, ..
1990 }
1991 | ConnectionState::Disconnected
1992 | ConnectionState::Connected { .. }
1993 | ConnectionState::Connecting { .. } => (),
1994 }
1995 }
1996
1997 fn notify_disconnect(&mut self, next_action: ConnectionAction) {
1999 debug_assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
2001 self.inner.state = ConnectionState::Disconnecting {
2002 next_action,
2003 modify_sent: true,
2004 };
2005
2006 self.notifier
2008 .modify_connection(ModifyConnectionRequest {
2009 monitor_page: Update::Reset,
2010 interrupt_page: Update::Reset,
2011 ..Default::default()
2012 })
2013 .expect("resetting state should not fail");
2014 }
2015
2016 fn is_resetting(&self) -> bool {
2019 matches!(
2020 &self.inner.state,
2021 ConnectionState::Connecting {
2022 next_action: ConnectionAction::Reset,
2023 ..
2024 } | ConnectionState::Disconnecting {
2025 next_action: ConnectionAction::Reset,
2026 ..
2027 }
2028 )
2029 }
2030
    /// Handles a handler's completion of an `Action::Close` for `offer_id`.
    ///
    /// - A closing reserved channel answers the guest with `CloseReservedChannelResponse`
    ///   while connected. Otherwise it is released and may be removed.
    /// - A close started because the guest released the channel finishes the release.
    /// - A close-then-reopen proceeds to the reopen.
    ///
    /// Completions in any other state are logged and ignored.
    pub fn close_complete(&mut self, offer_id: OfferId) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::info!(offer_id = offer_id.0, key = %channel.offer.key(), "closed channel");
        match channel.state {
            ChannelState::Closing {
                reserved_state: Some(reserved_state),
                ..
            } => {
                channel.state = ChannelState::Closed;
                if matches!(self.inner.state, ConnectionState::Connected { .. }) {
                    let channel_id = channel.info.expect("assigned").channel_id;
                    self.send_close_reserved_channel_response(
                        channel_id,
                        offer_id,
                        reserved_state.target,
                    );
                } else {
                    // Not connected, so there is nobody to answer. Release the channel instead,
                    // and drop it entirely if `client_release_channel` says it is no longer
                    // needed.
                    if Self::client_release_channel(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                        &mut self.inner.incomplete_gpadls,
                        &mut self.inner.assigned_channels,
                        &mut self.inner.assigned_monitors,
                        None,
                    ) {
                        self.inner.channels.remove(offer_id);
                    }
                }
            }
            ChannelState::Closing { .. } => {
                channel.state = ChannelState::Closed;
            }
            ChannelState::ClosingClientRelease => {
                channel.state = ChannelState::ClientReleased;
                // This may have been the last channel a pending disconnect was waiting on.
                self.check_disconnected();
            }
            ChannelState::ClosingReopen { request, .. } => {
                channel.state = ChannelState::Closed;
                self.open_channel(offer_id, &request, None);
            }

            ChannelState::Closed
            | ChannelState::ClientReleased
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::OpeningClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid close complete")
            }
        }
    }
2090
2091 fn send_close_reserved_channel_response(
2092 &mut self,
2093 channel_id: ChannelId,
2094 offer_id: OfferId,
2095 target: ConnectionTarget,
2096 ) {
2097 self.sender().send_message_with_target(
2098 &protocol::CloseReservedChannelResponse { channel_id },
2099 MessageTarget::ReservedChannel(offer_id, target),
2100 );
2101 }
2102
2103 fn handle_initiate_contact(
2106 &mut self,
2107 input: &protocol::InitiateContact2,
2108 message: &SynicMessage,
2109 includes_client_id: bool,
2110 ) -> Result<(), ChannelError> {
2111 let target_info =
2112 protocol::TargetInfo::from(input.initiate_contact.interrupt_page_or_target_info);
2113
2114 let target_sint = if message.multiclient
2115 && input.initiate_contact.version_requested >= Version::Win10Rs3_1 as u32
2116 {
2117 target_info.sint()
2118 } else {
2119 VMBUS_SINT
2120 };
2121
2122 let target_vtl = if message.multiclient
2123 && input.initiate_contact.version_requested >= Version::Win10Rs4 as u32
2124 {
2125 target_info.vtl()
2126 } else {
2127 0
2128 };
2129
2130 let feature_flags = if input.initiate_contact.version_requested >= Version::Copper as u32 {
2131 target_info.feature_flags()
2132 } else {
2133 0
2134 };
2135
2136 let target_message_vp =
2141 if input.initiate_contact.version_requested >= Version::Win8_1 as u32 {
2142 input.initiate_contact.target_message_vp
2143 } else {
2144 0
2145 };
2146
2147 let interrupt_page = (input.initiate_contact.version_requested < Version::Win8 as u32
2154 && input.initiate_contact.interrupt_page_or_target_info != 0)
2155 .then_some(input.initiate_contact.interrupt_page_or_target_info);
2156
2157 let monitor_page = if (input.initiate_contact.parent_to_child_monitor_page_gpa == 0)
2160 != (input.initiate_contact.child_to_parent_monitor_page_gpa == 0)
2161 {
2162 MonitorPageRequest::Invalid
2163 } else if input.initiate_contact.parent_to_child_monitor_page_gpa != 0 {
2164 MonitorPageRequest::Some(MonitorPageGpas {
2165 parent_to_child: input.initiate_contact.parent_to_child_monitor_page_gpa,
2166 child_to_parent: input.initiate_contact.child_to_parent_monitor_page_gpa,
2167 })
2168 } else {
2169 MonitorPageRequest::None
2170 };
2171
2172 let client_id = if FeatureFlags::from(feature_flags).client_id() {
2175 if includes_client_id {
2176 input.client_id
2177 } else {
2178 return Err(ChannelError::ParseError(
2179 protocol::ParseError::MessageTooSmall(Some(
2180 protocol::MessageType::INITIATE_CONTACT,
2181 )),
2182 ));
2183 }
2184 } else {
2185 Guid::ZERO
2186 };
2187
2188 let request = InitiateContactRequest {
2189 version_requested: input.initiate_contact.version_requested,
2190 target_message_vp,
2191 monitor_page,
2192 target_sint,
2193 target_vtl,
2194 feature_flags,
2195 interrupt_page,
2196 client_id,
2197 trusted: message.trusted,
2198 };
2199 self.initiate_contact(request);
2200 Ok(())
2201 }
2202
    /// Processes a guest request to (re)connect with the given version and parameters.
    ///
    /// - Requests for another VTL are forwarded to the notifier.
    /// - Requests for an unsupported SINT are refused.
    /// - Otherwise any existing connection is torn down first. If that cannot finish
    ///   synchronously, this request is queued as the disconnect's next action and processed
    ///   later.
    ///
    /// Once disconnected, the version is negotiated, the connection moves to `Connecting`, and
    /// the notifier is asked to apply the new connection state. The version response to the
    /// guest is sent when that request completes (`complete_initiate_contact`), or immediately
    /// on failure.
    pub fn initiate_contact(&mut self, request: InitiateContactRequest) {
        let vtl = self.inner.assigned_channels.vtl as u8;
        // Not addressed to this server's VTL; let another component handle it.
        if request.target_vtl != vtl {
            self.notifier.forward_unhandled(request);
            return;
        }

        if request.target_sint != VMBUS_SINT {
            tracelimit::warn_ratelimited!(
                target_vtl = request.target_vtl,
                target_sint = request.target_sint,
                version = request.version_requested,
                "unsupported multiclient request",
            );

            // Refuse, replying on the VP/SINT the request asked for.
            self.send_version_response_with_target(
                None,
                MessageTarget::Custom(ConnectionTarget {
                    vp: request.target_message_vp,
                    sint: request.target_sint,
                }),
            );

            return;
        }

        // Tear down any existing connection. If this cannot complete now, the request is
        // stored as the next action and re-run once the disconnect finishes.
        if !self.request_disconnect(ConnectionAction::Reconnect {
            initiate_contact: request,
        }) {
            return;
        }

        let Some(version) = self.check_version_supported(&request) else {
            tracelimit::warn_ratelimited!(
                vtl,
                version = request.version_requested,
                client_id = ?request.client_id,
                "Guest requested unsupported version"
            );

            // `None` tells the guest the version is not supported.
            self.send_version_response(None);
            return;
        };

        tracelimit::info_ratelimited!(
            vtl,
            ?version,
            client_id = ?request.client_id,
            trusted = request.trusted,
            "Guest negotiated version"
        );

        let monitor_page = match request.monitor_page {
            MonitorPageRequest::Some(mp) => {
                if self.inner.require_server_allocated_mnf {
                    // Guest-supplied pages are ignored in this mode. Without the
                    // server-specified monitor pages feature, the guest ends up with no MNF.
                    if !version.feature_flags.server_specified_monitor_pages() {
                        tracelimit::warn_ratelimited!(
                            "guest-supplied monitor pages not supported; MNF will be disabled"
                        );
                    }

                    None
                } else {
                    Some(mp)
                }
            }
            MonitorPageRequest::None => None,
            // Only one of the two monitor page GPAs was provided.
            MonitorPageRequest::Invalid => {
                self.send_version_response(Some(VersionResponseData::new(
                    version,
                    protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
                )));

                return;
            }
        };

        self.inner.state = ConnectionState::Connecting {
            info: ConnectionInfo {
                version,
                trusted: request.trusted,
                interrupt_page: request.interrupt_page,
                monitor_page: monitor_page.map(MonitorPageGpaInfo::from_guest_gpas),
                target_message_vp: request.target_message_vp,
                modifying: false,
                offers_sent: false,
                client_id: request.client_id,
                paused: false,
            },
            next_action: ConnectionAction::None,
        };

        // Apply the new connection state. Success is reported asynchronously; a synchronous
        // error fails the connection right away.
        if let Err(err) = self.notifier.modify_connection(ModifyConnectionRequest {
            version: Some(version),
            monitor_page: monitor_page.into(),
            interrupt_page: request.interrupt_page.into(),
            target_message_vp: Some(request.target_message_vp),
            notify_relay: true,
        }) {
            tracelimit::error_ratelimited!(?err, "server failed to change state");
            self.inner.state = ConnectionState::Disconnected;
            self.send_version_response(Some(VersionResponseData::new(
                version,
                protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
            )));
        }
    }
2320
2321 pub(crate) fn complete_initiate_contact(&mut self, response: ModifyConnectionResponse) {
2322 let ConnectionState::Connecting {
2323 mut info,
2324 next_action,
2325 } = self.inner.state
2326 else {
2327 panic!("Invalid state for completing InitiateContact.");
2328 };
2329
2330 const LOCAL_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
2334 .with_client_id(true)
2335 .with_confidential_channels(true);
2336
2337 let (relay_feature_flags, server_specified_monitor_page) = match response {
2338 ModifyConnectionResponse::Supported(
2340 protocol::ConnectionState::SUCCESSFUL,
2341 feature_flags,
2342 server_specified_monitor_page,
2343 ) => (feature_flags, server_specified_monitor_page),
2344 ModifyConnectionResponse::Supported(
2347 connection_state,
2348 feature_flags,
2349 server_specified_monitor_page,
2350 ) => {
2351 tracelimit::error_ratelimited!(
2352 ?connection_state,
2353 "initiate contact failed because relay request failed"
2354 );
2355
2356 info.version.feature_flags &= (feature_flags | LOCAL_FEATURE_FLAGS)
2359 .with_server_specified_monitor_pages(server_specified_monitor_page.is_some());
2360
2361 self.send_version_response(Some(VersionResponseData::new(
2362 info.version,
2363 connection_state,
2364 )));
2365 self.inner.state = ConnectionState::Disconnected;
2366 return;
2367 }
2368 ModifyConnectionResponse::Unsupported => {
2371 self.send_version_response(None);
2372 self.inner.state = ConnectionState::Disconnected;
2373 return;
2374 }
2375 ModifyConnectionResponse::Modified(_) => {
2376 panic!("Invalid response for completing InitiateContact.");
2377 }
2378 };
2379
2380 assert!(
2382 info.version.feature_flags.server_specified_monitor_pages()
2383 || server_specified_monitor_page.is_none()
2384 );
2385
2386 info.version.feature_flags &= relay_feature_flags | LOCAL_FEATURE_FLAGS;
2389
2390 if let Some(gpas) = server_specified_monitor_page {
2394 info.monitor_page = Some(MonitorPageGpaInfo::from_server_gpas(gpas));
2395 info.version
2396 .feature_flags
2397 .set_server_specified_monitor_pages(true);
2398 } else {
2399 info.version
2400 .feature_flags
2401 .set_server_specified_monitor_pages(false);
2402 }
2403
2404 let version = info.version;
2405 self.inner.state = ConnectionState::Connected(info);
2406
2407 self.send_version_response(Some(
2408 VersionResponseData::new(version, protocol::ConnectionState::SUCCESSFUL)
2409 .with_monitor_pages(server_specified_monitor_page),
2410 ));
2411 if !matches!(next_action, ConnectionAction::None) && self.request_disconnect(next_action) {
2412 self.do_next_action(next_action);
2413 }
2414 }
2415
2416 fn check_version_supported(&self, request: &InitiateContactRequest) -> Option<VersionInfo> {
2418 let version = SUPPORTED_VERSIONS
2419 .iter()
2420 .find(|v| request.version_requested == **v as u32)
2421 .copied()?;
2422
2423 if let Some(max_version) = self.inner.max_version {
2425 if version as u32 > max_version.version {
2426 return None;
2427 }
2428 }
2429
2430 let supported_flags = if version >= Version::Copper {
2431 let max_supported_flags =
2433 SUPPORTED_FEATURE_FLAGS.with_confidential_channels(request.trusted);
2434
2435 if let Some(max_version) = self.inner.max_version {
2437 max_supported_flags & max_version.feature_flags
2438 } else {
2439 max_supported_flags
2440 }
2441 } else {
2442 FeatureFlags::new()
2443 };
2444
2445 let feature_flags = supported_flags & request.feature_flags.into();
2446
2447 assert!(version >= Version::Copper || feature_flags == FeatureFlags::new());
2448 if feature_flags.into_bits() != request.feature_flags {
2449 tracelimit::warn_ratelimited!(
2450 supported = feature_flags.into_bits(),
2451 requested = request.feature_flags,
2452 "Guest requested unsupported feature flags."
2453 );
2454 }
2455
2456 Some(VersionInfo {
2457 version,
2458 feature_flags,
2459 })
2460 }
2461
2462 fn send_version_response(&mut self, data: Option<VersionResponseData>) {
2463 self.send_version_response_with_target(data, MessageTarget::Default);
2464 }
2465
    /// Builds and sends the version response appropriate for the negotiated protocol level.
    ///
    /// `None` (or a failure on a pre-Win8 version) produces a zeroed response with
    /// `version_supported == 0`. Otherwise the response carries the connection state. For
    /// Win10 RS3.1+ it also carries the child connection ID in place of the version. Copper+
    /// responses use the larger message that includes feature flags, and the largest one when
    /// server-specified monitor pages are returned.
    fn send_version_response_with_target(
        &mut self,
        data: Option<VersionResponseData>,
        target: MessageTarget,
    ) {
        // Which (nested) message layout to send.
        enum VersionResponseType {
            PreCopper,
            Copper,
            CopperWithServerMnf,
        }

        // The smaller response messages are nested prefixes of the largest one, so build the
        // largest and send the appropriate sub-struct.
        let mut response_copper_with_mnf = protocol::VersionResponse3::new_zeroed();
        let response_copper = &mut response_copper_with_mnf.version_response2;
        let response = &mut response_copper.version_response;
        let mut response_type = VersionResponseType::PreCopper;
        if let Some(data) = data {
            // Failure states can only be reported to Win8+ guests; older guests just see
            // `version_supported == 0`.
            if data.state == protocol::ConnectionState::SUCCESSFUL
                || data.version.version >= Version::Win8
            {
                response.version_supported = 1;
                response.connection_state = data.state;
                response.selected_version_or_connection_id =
                    if data.version.version >= Version::Win10Rs3_1 {
                        self.inner.child_connection_id
                    } else {
                        data.version.version as u32
                    };

                if data.version.version >= Version::Copper {
                    response_copper.supported_features = data.version.feature_flags.into();
                    response_type = VersionResponseType::Copper;
                    if let Some(monitor_page) = data.monitor_pages {
                        assert!(data.version.feature_flags.server_specified_monitor_pages());
                        response_copper_with_mnf.child_to_parent_monitor_page_gpa =
                            monitor_page.child_to_parent;
                        response_copper_with_mnf.parent_to_child_monitor_page_gpa =
                            monitor_page.parent_to_child;
                        response_type = VersionResponseType::CopperWithServerMnf;
                    }
                }
            }
        }

        match response_type {
            VersionResponseType::PreCopper => {
                self.sender().send_message_with_target(response, target)
            }
            VersionResponseType::Copper => self
                .sender()
                .send_message_with_target(response_copper, target),
            VersionResponseType::CopperWithServerMnf => self
                .sender()
                .send_message_with_target(&response_copper_with_mnf, target),
        }
    }
2524
    /// Releases the guest's channels and starts moving the connection to `Disconnected`.
    ///
    /// Reserved channels survive unless `new_action` is a VM reset. `new_action` is what
    /// should happen after the disconnect. If a connect or disconnect is already in flight,
    /// `new_action` replaces any previously queued action.
    ///
    /// Returns `true` when the connection is fully `Disconnected` on return, in which case the
    /// caller must carry out the action itself. Otherwise it runs once the disconnect
    /// completes.
    ///
    /// # Panics
    ///
    /// Panics if a reset is already pending, or if a non-reset disconnect from `Disconnected`
    /// finds channels that are not reset.
    fn request_disconnect(&mut self, new_action: ConnectionAction) -> bool {
        assert!(!self.is_resetting());

        // Release every channel the guest holds. The closure keeps a channel in the list when
        // it is a reserved channel being preserved, or when `client_release_channel` says it
        // must be retained.
        let gpadls = &mut self.inner.gpadls;
        let vm_reset = matches!(new_action, ConnectionAction::Reset);
        self.inner.channels.retain(|offer_id, channel| {
            (!vm_reset && channel.state.is_reserved())
                || !Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    gpadls,
                    &mut self.inner.incomplete_gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    None,
                )
        });

        match &mut self.inner.state {
            ConnectionState::Disconnected => {
                if vm_reset {
                    // Even without a guest connection, a reset must wait for channels
                    // (including reserved ones) to finish tearing down.
                    if !self.are_channels_reset(true) {
                        self.inner.state = ConnectionState::Disconnecting {
                            next_action: ConnectionAction::Reset,
                            modify_sent: false,
                        };
                    }
                } else {
                    assert!(self.are_channels_reset(false));
                }
            }

            ConnectionState::Connected { .. } => {
                if self.are_channels_reset(vm_reset) {
                    self.notify_disconnect(new_action);
                } else {
                    // Wait for channels to finish closing; `check_disconnected` picks this up.
                    self.inner.state = ConnectionState::Disconnecting {
                        next_action: new_action,
                        modify_sent: false,
                    };
                }
            }

            ConnectionState::Connecting { next_action, .. }
            | ConnectionState::Disconnecting { next_action, .. } => {
                *next_action = new_action;
            }
        }

        matches!(self.inner.state, ConnectionState::Disconnected)
    }
2587
2588 pub(crate) fn complete_disconnect(&mut self) {
2589 if let ConnectionState::Disconnecting {
2590 next_action,
2591 modify_sent,
2592 } = std::mem::replace(&mut self.inner.state, ConnectionState::Disconnected)
2593 {
2594 assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
2595 if !modify_sent {
2596 tracelimit::warn_ratelimited!("unexpected modify response");
2597 }
2598
2599 self.inner.state = ConnectionState::Disconnected;
2600 self.do_next_action(next_action);
2601 } else {
2602 unreachable!("not ready for disconnect");
2603 }
2604 }
2605
2606 fn do_next_action(&mut self, action: ConnectionAction) {
2607 match action {
2608 ConnectionAction::None => {}
2609 ConnectionAction::Reset => {
2610 self.complete_reset();
2611 }
2612 ConnectionAction::SendUnloadComplete => {
2613 self.complete_unload();
2614 }
2615 ConnectionAction::Reconnect { initiate_contact } => {
2616 self.initiate_contact(initiate_contact);
2617 }
2618 ConnectionAction::SendFailedVersionResponse => {
2619 self.send_version_response(None);
2622 }
2623 }
2624 }
2625
2626 fn handle_unload(&mut self) {
2628 tracing::debug!(
2629 vtl = self.inner.assigned_channels.vtl as u8,
2630 state = ?self.inner.state,
2631 "VmBus received unload request from guest",
2632 );
2633
2634 if self.request_disconnect(ConnectionAction::SendUnloadComplete) {
2635 self.complete_unload();
2636 }
2637 }
2638
2639 fn complete_unload(&mut self) {
2640 self.notifier.unload_complete();
2641 if let Some(version) = self.inner.delayed_max_version.take() {
2642 self.inner.set_compatibility_version(version, false);
2643 }
2644
2645 self.sender().send_message(&protocol::UnloadComplete {});
2646 tracelimit::info_ratelimited!("Vmbus disconnected");
2647 }
2648
    /// Handles a `RequestOffers` message: sends an offer for every
    /// non-reserved channel, followed by `AllOffersDelivered`.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::OffersAlreadySent`] if offers were already sent
    /// on this connection.
    fn handle_request_offers(&mut self) -> Result<(), ChannelError> {
        let ConnectionState::Connected(info) = &mut self.inner.state else {
            unreachable!(
                "in unexpected state {:?}, should be prevented by Message::parse()",
                self.inner.state
            );
        };

        if info.offers_sent {
            return Err(ChannelError::OffersAlreadySent);
        }

        info.offers_sent = true;

        // Reserved channels are not offered again.
        let mut sorted_channels: Vec<_> = self
            .inner
            .channels
            .iter_mut()
            .filter(|(_, channel)| !channel.state.is_reserved())
            .collect();

        // Channels without an explicit `offer_order` sort after those with one
        // (`u64::MAX`). With absolute ordering, `offer_order` is the primary
        // key; otherwise channels are grouped by interface ID first.
        if self.inner.use_absolute_channel_order {
            sorted_channels.sort_unstable_by_key(|(_, channel)| {
                (
                    channel.offer.offer_order.unwrap_or(u64::MAX),
                    channel.offer.interface_id,
                    channel.offer.instance_id,
                )
            });
        } else {
            sorted_channels.sort_unstable_by_key(|(_, channel)| {
                (
                    channel.offer.interface_id,
                    channel.offer.offer_order.unwrap_or(u64::MAX),
                    channel.offer.instance_id,
                )
            });
        }

        for (offer_id, channel) in sorted_channels {
            assert!(matches!(channel.state, ChannelState::ClientReleased));

            // Assign channel/monitor IDs before the offer is sent.
            channel.prepare_channel(
                offer_id,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
            );

            channel.state = ChannelState::Closed;
            self.inner
                .pending_messages
                .sender(self.notifier, info.paused)
                .send_offer(channel, info);
        }
        self.sender().send_message(&protocol::AllOffersDelivered {});

        Ok(())
    }
2710
2711 #[must_use]
2714 fn gpadl_updated(
2715 mut sender: MessageSender<'_, N>,
2716 offer_id: OfferId,
2717 channel: &Channel,
2718 gpadl_id: GpadlId,
2719 gpadl: &Gpadl,
2720 ) -> bool {
2721 if channel.state.is_revoked() {
2722 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
2723 sender.send_gpadl_created(channel_id, gpadl_id, protocol::STATUS_UNSUCCESSFUL);
2724 false
2725 } else {
2726 sender.notifier.notify(
2728 offer_id,
2729 Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
2730 );
2731 true
2732 }
2733 }
2734
    /// Processes a `GpadlHeader` message. It validates the target channel,
    /// builds a new GPADL from the header's range data, and stores it in the
    /// GPADL table.
    ///
    /// If the header already contains the complete range, the GPADL is handed
    /// off through [`Self::gpadl_updated`]. Otherwise its ID is recorded in
    /// `incomplete_gpadls` so that later `GpadlBody` messages can find it.
    ///
    /// # Errors
    ///
    /// Fails for an unknown channel ID, a reserved channel, a range that
    /// `Gpadl::append` rejects, or a GPADL ID that is already in use.
    fn handle_gpadl_header_core(
        &mut self,
        input: &protocol::GpadlHeader,
        range: &[u8],
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        // GPADLs cannot be created on reserved channels. Note that body
        // messages are looked up by GPADL ID only (see `handle_gpadl_body`).
        if channel.state.is_reserved() {
            return Err(ChannelError::ChannelReserved);
        }

        // NOTE(review): presumably `len` is the byte length of the range data,
        // so `/ 8` yields a count of u64 entries; confirm against `Gpadl::new`.
        let mut gpadl = Gpadl::new(input.count, input.len as usize / 8);
        let done = gpadl.append(range)?;

        let gpadl = match self.inner.gpadls.entry((input.gpadl_id, offer_id)) {
            Entry::Vacant(entry) => entry.insert(gpadl),
            Entry::Occupied(_) => return Err(ChannelError::DuplicateGpadlId),
        };

        if !done {
            // More body messages are expected. They carry only the GPADL ID, so
            // the ID must be unique among in-progress GPADLs across all
            // channels.
            match self.inner.incomplete_gpadls.entry(input.gpadl_id) {
                Entry::Vacant(entry) => {
                    entry.insert(offer_id);
                }
                Entry::Occupied(_) => {
                    // Undo the insertion above before failing.
                    self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    tracelimit::error_ratelimited!(
                        channel_id = ?input.channel_id,
                        key = %channel.offer.key(),
                        gpadl_id = ?input.gpadl_id,
                        "duplicate in-progress gpadl ID",
                    );
                    return Err(ChannelError::DuplicateGpadlId);
                }
            }
        }

        // A complete GPADL on a revoked channel is rejected by
        // `gpadl_updated`, in which case it is dropped from the table.
        if done
            && !Self::gpadl_updated(
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused()),
                offer_id,
                channel,
                input.gpadl_id,
                gpadl,
            )
        {
            self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
        }
        Ok(())
    }
2801
2802 fn handle_gpadl_header(&mut self, input: &protocol::GpadlHeader, range: &[u8]) {
2804 if let Err(err) = self.handle_gpadl_header_core(input, range) {
2805 tracelimit::warn_ratelimited!(
2806 err = &err as &dyn std::error::Error,
2807 channel_id = ?input.channel_id,
2808 key = %self.inner.channels.get_by_channel_id(&self.inner.assigned_channels, input.channel_id).map(|(_, c)| c.offer.key()).unwrap_or_default(),
2809 gpadl_id = ?input.gpadl_id,
2810 "error handling gpadl header"
2811 );
2812
2813 self.sender().send_gpadl_created(
2815 input.channel_id,
2816 input.gpadl_id,
2817 protocol::STATUS_UNSUCCESSFUL,
2818 );
2819 }
2820 }
2821
    /// Handles a `GpadlBody` message, which appends more range data to a GPADL
    /// started by an earlier `GpadlHeader`.
    ///
    /// The body carries no channel ID, so the owning offer is found through
    /// `incomplete_gpadls`. If appending the data fails, the GPADL is
    /// discarded and the guest gets a failed `GpadlCreated`; that case still
    /// returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::UnknownGpadlId`] if the GPADL ID is not an
    /// in-progress GPADL.
    fn handle_gpadl_body(
        &mut self,
        input: &protocol::GpadlBody,
        range: &[u8],
    ) -> Result<(), ChannelError> {
        let &offer_id = self
            .inner
            .incomplete_gpadls
            .get(&input.gpadl_id)
            .ok_or(ChannelError::UnknownGpadlId)?;
        let gpadl = self
            .inner
            .gpadls
            .get_mut(&(input.gpadl_id, offer_id))
            .ok_or(ChannelError::UnknownGpadlId)?;
        let channel = &mut self.inner.channels[offer_id];

        match gpadl.append(range) {
            Ok(done) => {
                if done {
                    // The GPADL is now complete: stop tracking it as in-progress
                    // and hand it off, dropping it if the channel was revoked.
                    self.inner.incomplete_gpadls.remove(&input.gpadl_id);
                    if !Self::gpadl_updated(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        input.gpadl_id,
                        gpadl,
                    ) {
                        self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    }
                }
            }
            Err(err) => {
                // Malformed body: forget the GPADL entirely and report failure.
                self.inner.incomplete_gpadls.remove(&input.gpadl_id);
                self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                let channel_id = channel.info.as_ref().expect("assigned").channel_id;
                tracelimit::warn_ratelimited!(
                    err = &err as &dyn std::error::Error,
                    channel_id = channel_id.0,
                    key = %channel.offer.key(),
                    gpadl_id = input.gpadl_id.0,
                    "error handling gpadl body"
                );
                self.sender().send_gpadl_created(
                    channel_id,
                    input.gpadl_id,
                    protocol::STATUS_UNSUCCESSFUL,
                );
            }
        }

        Ok(())
    }
2885
    /// Handles a `GpadlTeardown` message from the guest.
    ///
    /// Only an `Accepted` GPADL can be torn down. On a revoked channel the
    /// GPADL is dropped and `GpadlTorndown` is sent immediately. Otherwise
    /// the GPADL moves to `TearingDown` and the channel's handler is notified
    /// with `Action::TeardownGpadl`. The guest is answered later, from
    /// `gpadl_teardown_complete`.
    ///
    /// # Errors
    ///
    /// Fails for an unknown channel or GPADL ID, a GPADL in any state other
    /// than `Accepted`, a channel ID mismatch, or a reserved channel.
    fn handle_gpadl_teardown(
        &mut self,
        input: &protocol::GpadlTeardown,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        tracing::debug!(
            channel_id = input.channel_id.0,
            key = %channel.offer.key(),
            gpadl_id = input.gpadl_id.0,
            "Received GPADL teardown request"
        );

        let gpadl = self
            .inner
            .gpadls
            .get_mut(&(input.gpadl_id, offer_id))
            .ok_or(ChannelError::UnknownGpadlId)?;

        match gpadl.state {
            GpadlState::InProgress
            | GpadlState::Offered
            | GpadlState::OfferedTearingDown
            | GpadlState::TearingDown => {
                return Err(ChannelError::InvalidGpadlState);
            }
            GpadlState::Accepted => {
                if channel.info.as_ref().map(|info| info.channel_id) != Some(input.channel_id) {
                    return Err(ChannelError::WrongGpadlChannelId);
                }

                if channel.state.is_reserved() {
                    return Err(ChannelError::ChannelReserved);
                }

                if channel.state.is_revoked() {
                    tracing::trace!(
                        channel_id = input.channel_id.0,
                        key = %channel.offer.key(),
                        gpadl_id = input.gpadl_id.0,
                        "Gpadl teardown for revoked channel"
                    );

                    // The handler is not notified for a revoked channel; the
                    // teardown completes immediately.
                    self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    self.sender().send_gpadl_torndown(input.gpadl_id);
                } else {
                    gpadl.state = GpadlState::TearingDown;
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id: input.gpadl_id,
                            post_restore: false,
                        },
                    );
                }
            }
        }
        Ok(())
    }
2952
2953 fn open_channel(
2956 &mut self,
2957 offer_id: OfferId,
2958 input: &OpenRequest,
2959 reserved_state: Option<ReservedState>,
2960 ) {
2961 let channel = &mut self.inner.channels[offer_id];
2962 assert!(matches!(channel.state, ChannelState::Closed));
2963
2964 channel.state = ChannelState::Opening {
2965 request: *input,
2966 reserved_state,
2967 };
2968
2969 let info = channel.info.as_ref().expect("assigned");
2972 self.notifier.notify(
2973 offer_id,
2974 Action::Open(
2975 OpenParams::from_request(
2976 info,
2977 input,
2978 channel.handled_monitor_info(),
2979 reserved_state.map(|state| state.target),
2980 ),
2981 self.inner.state.get_version().expect("must be connected"),
2982 ),
2983 );
2984 }
2985
    /// Handles an `OpenChannel`/`OpenChannel2` message from the guest.
    ///
    /// Guest-specified signal parameters and interrupt-redirection flags are
    /// honored only when the negotiated feature flags allow them. A channel
    /// that is still `Closing` is reopened once the close finishes
    /// (`ClosingReopen`). Opens on revoked/reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Fails for an unknown channel ID, or if the channel is already open or
    /// opening.
    fn handle_open_channel(&mut self, input: &protocol::OpenChannel2) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.open_channel.channel_id)?;

        let guest_specified_interrupt_info = self
            .inner
            .state
            .check_feature_flags(|ff| ff.guest_specified_signal_parameters())
            .then_some(SignalInfo {
                event_flag: input.event_flag,
                connection_id: input.connection_id,
            });

        // Flags are dropped unless interrupt redirection was negotiated.
        let flags = if self
            .inner
            .state
            .check_feature_flags(|ff| ff.channel_interrupt_redirection())
        {
            input.flags
        } else {
            Default::default()
        };

        let request = OpenRequest {
            open_id: input.open_channel.open_id,
            ring_buffer_gpadl_id: input.open_channel.ring_buffer_gpadl_id,
            target_vp: protocol::vp_index_if_enabled(input.open_channel.target_vp),
            downstream_ring_buffer_page_offset: input
                .open_channel
                .downstream_ring_buffer_page_offset,
            user_data: input.open_channel.user_data,
            guest_specified_interrupt_info,
            flags,
        };

        match channel.state {
            ChannelState::Closed => self.open_channel(offer_id, &request, None),
            // The previous close has not completed yet; remember the request
            // so the channel can be reopened afterwards.
            ChannelState::Closing { params, .. } => {
                channel.state = ChannelState::ClosingReopen { params, request }
            }
            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Open { .. }
            | ChannelState::Opening { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelAlreadyOpen),

            // NOTE(review): presumably unreachable because released channels
            // give up their channel ID (`release_channel`), so the lookup
            // above cannot find them — confirm.
            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }
        Ok(())
    }
3044
    /// Handles a `CloseChannel` message from the guest.
    ///
    /// An open, non-reserved channel moves to `Closing` and its handler is
    /// notified with `Action::Close`. Closes on revoked/reoffered channels
    /// are ignored.
    ///
    /// # Errors
    ///
    /// Fails for an unknown channel ID, a reserved channel (those are closed
    /// with `CloseReservedChannel`), or a channel that is not open.
    fn handle_close_channel(&mut self, input: &protocol::CloseChannel) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        match channel.state {
            ChannelState::Open {
                params,
                modify_state,
                reserved_state: None,
            } => {
                // Allowed, but any in-flight modify is abandoned by the close.
                if modify_state.is_modifying() {
                    tracelimit::warn_ratelimited!(
                        key = %channel.offer.key(),
                        ?modify_state,
                        "Client is closing the channel with a modify in progress"
                    )
                }

                channel.state = ChannelState::Closing {
                    params,
                    reserved_state: None,
                };
                self.notifier.notify(offer_id, Action::Close);
            }

            ChannelState::Open {
                reserved_state: Some(_),
                ..
            } => return Err(ChannelError::ChannelReserved),

            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }

        Ok(())
    }
3092
    /// Handles an `OpenReservedChannel` message from the guest.
    ///
    /// Opens a `Closed` channel as reserved, recording the protocol `version`
    /// and the connection target given in the message. Requests for
    /// revoked/reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Fails for an unknown channel ID, a channel that is already open or
    /// opening, or a channel that is still closing.
    fn handle_open_reserved_channel(
        &mut self,
        input: &protocol::OpenReservedChannel,
        version: VersionInfo,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        // NOTE(review): `as u8` truncates the SINT silently; presumably the
        // guest value always fits — confirm.
        let target = ConnectionTarget {
            vp: input.target_vp,
            sint: input.target_sint as u8,
        };

        let reserved_state = Some(ReservedState { version, target });

        // The open request carries no target VP, open ID, user data, or flags;
        // the connection target travels in `reserved_state` instead.
        let request = OpenRequest {
            ring_buffer_gpadl_id: input.ring_buffer_gpadl,
            target_vp: None,
            downstream_ring_buffer_page_offset: input.downstream_page_offset,
            open_id: 0,
            user_data: UserDefinedData::new_zeroed(),
            guest_specified_interrupt_info: None,
            flags: Default::default(),
        };

        match channel.state {
            ChannelState::Closed => self.open_channel(offer_id, &request, reserved_state),
            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Open { .. } | ChannelState::Opening { .. } => {
                return Err(ChannelError::ChannelAlreadyOpen);
            }

            ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                return Err(ChannelError::InvalidChannelState);
            }

            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }
        Ok(())
    }
3141
    /// Handles a `CloseReservedChannel` message from the guest.
    ///
    /// An open reserved channel moves to `Closing`. Its reserved connection
    /// target is updated to the VP/SINT given in the message, and its handler
    /// is notified with `Action::Close`. Requests for revoked/reoffered
    /// channels are ignored.
    ///
    /// # Errors
    ///
    /// Fails for an unknown channel ID, an open channel that is not reserved,
    /// or a channel that is not open.
    fn handle_close_reserved_channel(
        &mut self,
        input: &protocol::CloseReservedChannel,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        match channel.state {
            ChannelState::Open {
                params,
                reserved_state: Some(mut resvd),
                ..
            } => {
                // Use the target supplied with the close request from now on.
                resvd.target.vp = input.target_vp;
                resvd.target.sint = input.target_sint as u8;
                channel.state = ChannelState::Closing {
                    params,
                    reserved_state: Some(resvd),
                };
                self.notifier.notify(offer_id, Action::Close);
            }

            ChannelState::Open {
                reserved_state: None,
                ..
            } => return Err(ChannelError::ChannelNotReserved),

            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }

        Ok(())
    }
3187
    /// Releases `channel` on behalf of the client (guest). This happens when
    /// the guest releases the channel ID or the connection is torn down.
    ///
    /// The channel's GPADLs are cleaned up according to their state. The
    /// channel moves to one of the client-released states, and its channel
    /// and monitor IDs are released. The exception is a reoffered channel
    /// when `info` is `Some`: it is offered again instead and keeps its
    /// assignment.
    ///
    /// Returns `true` if the caller should remove the channel from the channel
    /// table (only for a channel that was `Revoked`).
    #[must_use]
    fn client_release_channel(
        mut sender: MessageSender<'_, N>,
        offer_id: OfferId,
        channel: &mut Channel,
        gpadls: &mut GpadlMap,
        incomplete_gpadls: &mut IncompleteGpadlMap,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
        info: Option<&ConnectionInfo>,
    ) -> bool {
        tracelimit::info_ratelimited!(?offer_id, key = %channel.offer.key(), "client released channel");
        gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
            // Only this channel's GPADLs are affected.
            if gpadl_offer_id != offer_id {
                return true;
            }
            match gpadl.state {
                // Partially described: drop it and stop expecting body messages.
                GpadlState::InProgress => {
                    incomplete_gpadls.remove(&gpadl_id);
                    false
                }
                // Waiting on the handler's create result; mark it so it is torn
                // down once creation completes.
                GpadlState::Offered => {
                    gpadl.state = GpadlState::OfferedTearingDown;
                    true
                }
                GpadlState::Accepted => {
                    if channel.state.is_revoked() {
                        // Revoked channel: drop without notifying the handler.
                        false
                    } else {
                        gpadl.state = GpadlState::TearingDown;
                        sender.notifier.notify(
                            offer_id,
                            Action::TeardownGpadl {
                                gpadl_id,
                                post_restore: false,
                            },
                        );
                        true
                    }
                }
                GpadlState::OfferedTearingDown | GpadlState::TearingDown => true,
            }
        });

        let remove = match &mut channel.state {
            ChannelState::Closed => {
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Reoffered => {
                if let Some(info) = info {
                    // Still connected: offer the channel again right away and
                    // skip releasing its channel/monitor IDs below.
                    channel.state = ChannelState::Closed;
                    channel.restore_state = RestoreState::New;
                    sender.send_offer(channel, info);
                    return false;
                }
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Revoked => {
                channel.state = ChannelState::ClientReleased;
                true
            }
            ChannelState::Opening { .. } => {
                channel.state = ChannelState::OpeningClientRelease;
                false
            }
            ChannelState::Open { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                sender.notifier.notify(offer_id, Action::Close);
                false
            }
            ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                false
            }

            ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease
            | ChannelState::ClientReleased => false,
        };

        assert!(channel.state.is_released());

        channel.release_channel(offer_id, assigned_channels, assigned_monitors);
        remove
    }
3281
    /// Handles a `RelIdReleased` message, in which the guest gives up a
    /// channel ID.
    ///
    /// Only allowed when the channel is not open or opening. The channel is
    /// released through [`Self::client_release_channel`] and removed from the
    /// table if that asks for it. The disconnect state machine is then
    /// re-checked.
    ///
    /// # Errors
    ///
    /// Fails for an unknown channel ID or a channel that is open, opening, or
    /// waiting to reopen.
    fn handle_rel_id_released(
        &mut self,
        input: &protocol::RelIdReleased,
    ) -> Result<(), ChannelError> {
        let channel_id = input.channel_id;
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, channel_id)?;

        match channel.state {
            ChannelState::Closed
            | ChannelState::Revoked
            | ChannelState::Closing { .. }
            | ChannelState::Reoffered => {
                if Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    &mut self.inner.gpadls,
                    &mut self.inner.incomplete_gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    self.inner.state.get_connected_info(),
                ) {
                    self.inner.channels.remove(offer_id);
                }

                // Releasing this channel may have been the last thing a pending
                // disconnect was waiting for.
                self.check_disconnected();
            }

            ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::InvalidChannelState),

            ChannelState::ClientReleased
            | ChannelState::OpeningClientRelease
            | ChannelState::ClosingClientRelease => unreachable!(),
        }
        Ok(())
    }
3326
3327 fn handle_tl_connect_request(&mut self, request: protocol::TlConnectRequest2) {
3330 let version = self
3331 .inner
3332 .state
3333 .get_version()
3334 .expect("must be connected")
3335 .version;
3336
3337 let hosted_silo_unaware = version < Version::Win10Rs5;
3338 self.notifier
3339 .notify_hvsock(&HvsockConnectRequest::from_message(
3340 request,
3341 hosted_silo_unaware,
3342 ));
3343 }
3344
3345 pub fn send_tl_connect_result(&mut self, result: HvsockConnectResult) {
3347 if !result.success && self.inner.state.check_version(Version::Win10Rs3_0) {
3351 self.sender().send_message(&protocol::TlConnectResult {
3355 service_id: result.service_id,
3356 endpoint_id: result.endpoint_id,
3357 status: protocol::STATUS_CONNECTION_REFUSED,
3358 })
3359 }
3360 }
3361
3362 fn handle_modify_channel(
3365 &mut self,
3366 request: &protocol::ModifyChannel,
3367 ) -> Result<(), ChannelError> {
3368 let result = self.modify_channel(request);
3369 if result.is_err() {
3370 self.send_modify_channel_response(request.channel_id, protocol::STATUS_UNSUCCESSFUL);
3371 }
3372
3373 result
3374 }
3375
    /// Validates a `ModifyChannel` request and, if no modify is in flight,
    /// asks the channel's handler to retarget the channel's interrupt to
    /// `request.target_vp`.
    ///
    /// If a modify is already in flight, the behavior depends on the protocol
    /// version. Guests negotiating `Iron` or later get only a warning. Older
    /// guests have the new target stored as `pending_target_vp`, which
    /// `modify_channel_complete` replays.
    ///
    /// # Errors
    ///
    /// Fails if the target VP is `VP_INDEX_DISABLE_INTERRUPT`, the channel ID
    /// is unknown, the channel is not open (or is reserved), or the channel
    /// was opened with interrupts disabled.
    fn modify_channel(&mut self, request: &protocol::ModifyChannel) -> Result<(), ChannelError> {
        // Disabling interrupts through ModifyChannel is not supported.
        if request.target_vp == protocol::VP_INDEX_DISABLE_INTERRUPT {
            return Err(ChannelError::InvalidTargetVp);
        }

        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, request.channel_id)?;

        let (open_request, modify_state) = match &mut channel.state {
            ChannelState::Open {
                params,
                modify_state,
                reserved_state: None,
            } => (params, modify_state),
            _ => return Err(ChannelError::InvalidChannelState),
        };

        if open_request.target_vp.is_none() {
            return Err(ChannelError::InterruptsDisabled);
        }

        if let ModifyState::Modifying { pending_target_vp } = modify_state {
            if self.inner.state.check_version(Version::Iron) {
                // This request is only logged and is otherwise dropped.
                tracelimit::warn_ratelimited!(
                    key = %channel.offer.key(),
                    "Client sent new ModifyChannel before receiving ModifyChannelResponse."
                );
            } else {
                // Older protocol: keep only the latest target, to be applied
                // when the current modify completes.
                *pending_target_vp = Some(request.target_vp);
            }
        } else {
            self.notifier.notify(
                offer_id,
                Action::Modify {
                    target_vp: request.target_vp,
                },
            );

            // Record the new target now, before the handler confirms it.
            open_request.target_vp = Some(request.target_vp);
            *modify_state = ModifyState::Modifying {
                pending_target_vp: None,
            };
        }

        Ok(())
    }
3431
    /// Called when the channel's handler finishes a modify request started by
    /// `modify_channel`.
    ///
    /// If the channel is open, non-reserved, and modifying, it returns to
    /// `NotModifying` and the guest is sent a `ModifyChannelResponse` with
    /// `status` (for protocol versions that support one). A target VP queued
    /// by an older guest while the modify was in flight is then submitted as
    /// a new request. In any other channel state the call does nothing.
    pub fn modify_channel_complete(&mut self, offer_id: OfferId, status: i32) {
        let channel = &mut self.inner.channels[offer_id];

        if let ChannelState::Open {
            params,
            modify_state: ModifyState::Modifying { pending_target_vp },
            reserved_state: None,
        } = channel.state
        {
            channel.state = ChannelState::Open {
                params,
                modify_state: ModifyState::NotModifying,
                reserved_state: None,
            };

            // Captured before `self` is borrowed again below.
            let channel_id = channel.info.as_ref().expect("assigned").channel_id;
            let key = channel.offer.key();
            self.send_modify_channel_response(channel_id, status);

            if let Some(target_vp) = pending_target_vp {
                // Replay the queued request as if the guest had just sent it.
                let request = protocol::ModifyChannel {
                    channel_id,
                    target_vp,
                };

                if let Err(error) = self.handle_modify_channel(&request) {
                    tracelimit::warn_ratelimited!(?error, %key, "Pending ModifyChannel request failed.")
                }
            }
        }
    }
3471
3472 fn send_modify_channel_response(&mut self, channel_id: ChannelId, status: i32) {
3473 if self.inner.state.check_version(Version::Iron) {
3474 self.sender()
3475 .send_message(&protocol::ModifyChannelResponse { channel_id, status });
3476 }
3477 }
3478
3479 fn handle_modify_connection(&mut self, request: protocol::ModifyConnection) {
3480 if let Err(err) = self.modify_connection(request) {
3481 tracelimit::error_ratelimited!(?err, "modifying connection failed");
3482 self.complete_modify_connection(ModifyConnectionResponse::Modified(
3483 protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
3484 ));
3485 }
3486 }
3487
    /// Validates a `ModifyConnection` request and forwards it to the notifier.
    ///
    /// On success the connection is marked `modifying` and its monitor-page
    /// info is replaced with the guest-supplied pages (or `None` if both GPAs
    /// are zero). Completion is reported later through
    /// `complete_modify_connection`.
    ///
    /// # Errors
    ///
    /// Fails if the connection is not `Connected` or is already modifying, if
    /// the current monitor pages were allocated by the server, if only one of
    /// the two monitor page GPAs is non-zero, or if the notifier fails. In
    /// that last case `modifying` has already been set to `true` and
    /// `monitor_page` already updated.
    fn modify_connection(&mut self, request: protocol::ModifyConnection) -> anyhow::Result<()> {
        let ConnectionState::Connected(info) = &mut self.inner.state else {
            anyhow::bail!(
                "Invalid state for ModifyConnection request: {:?}",
                self.inner.state
            );
        };

        if info.modifying {
            anyhow::bail!(
                "Duplicate ModifyConnection request, state: {:?}",
                self.inner.state
            );
        }

        if matches!(
            info.monitor_page,
            Some(MonitorPageGpaInfo {
                server_allocated: true,
                ..
            })
        ) {
            anyhow::bail!("Cannot modify server-allocated monitor pages");
        }

        // The two GPAs must be either both zero or both non-zero.
        if (request.child_to_parent_monitor_page_gpa == 0)
            != (request.parent_to_child_monitor_page_gpa == 0)
        {
            anyhow::bail!("Guest must specify either both or no monitor pages, {request:?}");
        }

        let monitor_page = (request.child_to_parent_monitor_page_gpa != 0).then_some(
            MonitorPageGpaInfo::from_guest_gpas(MonitorPageGpas {
                child_to_parent: request.child_to_parent_monitor_page_gpa,
                parent_to_child: request.parent_to_child_monitor_page_gpa,
            }),
        );

        info.modifying = true;
        info.monitor_page = monitor_page;
        tracing::debug!("modifying connection parameters.");
        self.notifier.modify_connection(request.into())?;

        Ok(())
    }
3533
    /// Called when a connection modify request completes.
    ///
    /// What it does depends on the current state:
    /// - `Connecting`: the response finishes the contact handshake.
    /// - `Disconnecting`: the response finishes the disconnect.
    /// - `Connected`: the response must be `Modified`. It completes a
    ///   guest-initiated `ModifyConnection`, and the guest is sent a
    ///   `ModifyConnectionResponse`.
    ///
    /// # Panics
    ///
    /// Panics in any other state, on a non-`Modified` response while
    /// connected, or if no modify is in progress while connected.
    pub fn complete_modify_connection(&mut self, response: ModifyConnectionResponse) {
        tracing::debug!(?response, "modifying connection parameters complete");

        match &mut self.inner.state {
            ConnectionState::Connecting { .. } => self.complete_initiate_contact(response),
            ConnectionState::Disconnecting { .. } => self.complete_disconnect(),
            ConnectionState::Connected(info) => {
                let ModifyConnectionResponse::Modified(connection_state) = response else {
                    panic!(
                        "Relay should not return {:?} for a modify request with no version.",
                        response
                    );
                };

                if !info.modifying {
                    panic!(
                        "ModifyConnection response while not modifying, state: {:?}",
                        self.inner.state
                    );
                }

                info.modifying = false;
                self.sender()
                    .send_message(&protocol::ModifyConnectionResponse { connection_state });
            }
            _ => panic!(
                "Invalid state for ModifyConnection response: {:?}",
                self.inner.state
            ),
        }
    }
3568
3569 fn handle_pause(&mut self) {
3570 tracelimit::info_ratelimited!("pausing sending messages");
3571 self.sender().send_message(&protocol::PauseResponse {});
3572 let ConnectionState::Connected(info) = &mut self.inner.state else {
3573 unreachable!(
3574 "in unexpected state {:?}, should be prevented by Message::parse()",
3575 self.inner.state
3576 );
3577 };
3578 info.paused = true;
3579 }
3580
    /// Parses one SynIC message from the guest and dispatches it to the
    /// matching handler.
    ///
    /// # Errors
    ///
    /// Returns the parse error if the message is invalid for the negotiated
    /// version. Returns [`ChannelError::UntrustedMessage`] if an untrusted
    /// message arrives on a trusted connection. Returns
    /// [`ChannelError::Paused`] if a message other than
    /// `Resume`/`Unload`/`InitiateContact*` arrives while paused. Otherwise
    /// returns any error from the individual handler.
    ///
    /// # Panics
    ///
    /// Panics if called while a reset is in progress.
    pub fn handle_synic_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
        assert!(!self.is_resetting());

        let version = self.inner.state.get_version();
        let msg = Message::parse(&message.data, version)?;
        tracing::trace!(?msg, message.trusted, "received vmbus message");
        if self.inner.state.is_trusted() && !message.trusted {
            tracelimit::warn_ratelimited!(?msg, "Received untrusted message");
            return Err(ChannelError::UntrustedMessage);
        }

        // While paused, only messages that resume or restart the connection
        // are accepted; any accepted message ends the pause.
        match &mut self.inner.state {
            ConnectionState::Connected(info) if info.paused => {
                if !matches!(
                    msg,
                    Message::Resume(..)
                        | Message::Unload(..)
                        | Message::InitiateContact { .. }
                        | Message::InitiateContact2 { .. }
                ) {
                    tracelimit::warn_ratelimited!(?msg, "Received message while paused");
                    return Err(ChannelError::Paused);
                }
                tracelimit::info_ratelimited!("resuming sending messages");
                info.paused = false;
            }
            _ => {}
        }

        match msg {
            Message::InitiateContact2(input, ..) => {
                self.handle_initiate_contact(&input, &message, true)?
            }
            Message::InitiateContact(input, ..) => {
                self.handle_initiate_contact(&input.into(), &message, false)?
            }
            Message::Unload(..) => self.handle_unload(),
            Message::RequestOffers(..) => self.handle_request_offers()?,
            Message::GpadlHeader(input, range) => self.handle_gpadl_header(&input, range),
            Message::GpadlBody(input, range) => self.handle_gpadl_body(&input, range)?,
            Message::GpadlTeardown(input, ..) => self.handle_gpadl_teardown(&input)?,
            Message::OpenChannel(input, ..) => self.handle_open_channel(&input.into())?,
            Message::OpenChannel2(input, ..) => self.handle_open_channel(&input)?,
            Message::CloseChannel(input, ..) => self.handle_close_channel(&input)?,
            Message::RelIdReleased(input, ..) => self.handle_rel_id_released(&input)?,
            Message::TlConnectRequest(input, ..) => self.handle_tl_connect_request(input.into()),
            Message::TlConnectRequest2(input, ..) => self.handle_tl_connect_request(input),
            Message::ModifyChannel(input, ..) => self.handle_modify_channel(&input)?,
            Message::ModifyConnection(input, ..) => self.handle_modify_connection(input),
            Message::OpenReservedChannel(input, ..) => self.handle_open_reserved_channel(
                &input,
                version.expect("version validated by Message::parse"),
            )?,
            Message::CloseReservedChannel(input, ..) => {
                self.handle_close_reserved_channel(&input)?
            }
            Message::Pause(protocol::Pause, ..) => self.handle_pause(),
            // The pause (if any) was already cleared above.
            Message::Resume(protocol::Resume, ..) => {}
            // Server-to-client messages; the parser is expected to reject these
            // on the server side.
            Message::OfferChannel(..)
            | Message::RescindChannelOffer(..)
            | Message::AllOffersDelivered(..)
            | Message::OpenResult(..)
            | Message::GpadlCreated(..)
            | Message::GpadlTorndown(..)
            | Message::VersionResponse(..)
            | Message::VersionResponse2(..)
            | Message::VersionResponse3(..)
            | Message::UnloadComplete(..)
            | Message::CloseReservedChannelResponse(..)
            | Message::TlConnectResult(..)
            | Message::ModifyChannelResponse(..)
            | Message::ModifyConnectionResponse(..)
            | Message::PauseResponse(..) => {
                unreachable!("Server received client message {:?}", msg);
            }
        }
        Ok(())
    }
3666
    /// Called when the channel's handler has finished creating a GPADL.
    ///
    /// The handling depends on the GPADL's state:
    /// - `Offered`: the guest is sent `GpadlCreated` with `status`. A
    ///   non-negative status makes the GPADL `Accepted`; otherwise it is
    ///   removed.
    /// - `OfferedTearingDown`: the client released the channel meanwhile. On
    ///   success the GPADL is torn down again via `Action::TeardownGpadl`; on
    ///   failure it is simply removed.
    ///
    /// Unknown GPADLs and GPADLs in other states are logged and ignored. After
    /// a removal, the disconnect state machine is re-checked.
    pub fn gpadl_create_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId, status: i32) {
        let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
            tracelimit::error_ratelimited!(
                ?offer_id,
                key = %self.inner.channels[offer_id].offer.key(),
                ?gpadl_id,
                "invalid gpadl ID for channel"
            );
            return;
        };
        // `retain` == false means the GPADL is dropped from the table below.
        let retain = match gpadl.state {
            GpadlState::InProgress | GpadlState::TearingDown | GpadlState::Accepted => {
                tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
                return;
            }
            GpadlState::Offered => {
                let channel_id = self.inner.channels[offer_id]
                    .info
                    .as_ref()
                    .expect("assigned")
                    .channel_id;
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_gpadl_created(channel_id, gpadl_id, status);
                if status >= 0 {
                    gpadl.state = GpadlState::Accepted;
                    true
                } else {
                    false
                }
            }
            GpadlState::OfferedTearingDown => {
                // No `GpadlCreated` is sent in this case.
                if status >= 0 {
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: false,
                        },
                    );
                    gpadl.state = GpadlState::TearingDown;
                    true
                } else {
                    false
                }
            }
        };
        if !retain {
            self.inner
                .gpadls
                .remove(&(gpadl_id, offer_id))
                .expect("gpadl validated above");

            self.check_disconnected();
        }
    }
3726
    /// Called when the channel's handler has finished tearing down a GPADL.
    ///
    /// A GPADL in the `TearingDown` state is removed, and the disconnect
    /// state machine is re-checked. The guest is sent `GpadlTorndown` only if
    /// the channel has not been released by the client. Unknown GPADLs and
    /// GPADLs in any other state are logged and ignored.
    pub fn gpadl_teardown_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
        let channel = &mut self.inner.channels[offer_id];
        let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
            tracelimit::error_ratelimited!(
                ?offer_id,
                key = %channel.offer.key(),
                ?gpadl_id,
                "invalid gpadl ID for channel"
            );
            return;
        };
        tracing::debug!(
            offer_id = offer_id.0,
            key = %channel.offer.key(),
            gpadl_id = gpadl_id.0,
            "Gpadl teardown complete"
        );
        match gpadl.state {
            GpadlState::InProgress
            | GpadlState::Offered
            | GpadlState::OfferedTearingDown
            | GpadlState::Accepted => {
                tracelimit::error_ratelimited!(?offer_id, key = %channel.offer.key(), ?gpadl_id, ?gpadl, "invalid gpadl state");
            }
            GpadlState::TearingDown => {
                // A released channel gets no teardown notification.
                if !channel.state.is_released() {
                    self.sender().send_gpadl_torndown(gpadl_id);
                }
                self.inner
                    .gpadls
                    .remove(&(gpadl_id, offer_id))
                    .expect("gpadl validated above");

                self.check_disconnected();
            }
        }
    }
3765
3766 fn sender(&mut self) -> MessageSender<'_, N> {
3771 self.inner
3772 .pending_messages
3773 .sender(self.notifier, self.inner.state.is_paused())
3774 }
3775}
3776
/// Revokes the channel `channel` (identified by `offer_id`) and cleans up its
/// GPADLs in `gpadls`.
///
/// State handling:
/// - `Closed`, `Open`, `Opening`, `Closing`, `ClosingReopen`: the channel
///   moves to `Revoked`. Its assigned offer info is used to send messages.
/// - `Reoffered`: the channel moves to `Revoked`, and no messages are sent
///   for it.
/// - `ClientReleased`, `OpeningClientRelease`, `ClosingClientRelease`: the
///   state is left unchanged, and no messages are sent.
/// - `Revoked`: nothing is done, and `true` is returned immediately.
///
/// When messages are sent, all GPADL-related messages are sent through
/// `sender` before the `RescindChannelOffer` message.
///
/// Returns `!channel.state.is_released()`, evaluated after the state change.
/// NOTE(review): `false` presumably tells the caller it may drop the channel
/// entry. Confirm against callers.
///
/// # Panics
///
/// Panics (`expect("assigned")`) if the channel is in one of the first group
/// of states but `channel.info` is `None`.
fn revoke<N: Notifier>(
    mut sender: MessageSender<'_, N>,
    offer_id: OfferId,
    channel: &mut Channel,
    gpadls: &mut GpadlMap,
) -> bool {
    // `info` is `Some` only when messages about this channel should be sent.
    let info = match channel.state {
        ChannelState::Closed
        | ChannelState::Open { .. }
        | ChannelState::Opening { .. }
        | ChannelState::Closing { .. }
        | ChannelState::ClosingReopen { .. } => {
            channel.state = ChannelState::Revoked;
            Some(channel.info.as_ref().expect("assigned"))
        }
        ChannelState::Reoffered => {
            channel.state = ChannelState::Revoked;
            None
        }
        ChannelState::ClientReleased
        | ChannelState::OpeningClientRelease
        | ChannelState::ClosingClientRelease => None,
        ChannelState::Revoked => return true,
    };
    // Evaluated after the transition above, so released states report
    // released and newly revoked channels do not.
    let retain = !channel.state.is_released();

    gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
        // GPADLs belonging to other offers are untouched.
        if gpadl_offer_id != offer_id {
            return true;
        }

        match gpadl.state {
            // Kept as-is.
            GpadlState::InProgress => true,
            GpadlState::Offered => {
                // Fail the pending creation (if messages are being sent for
                // this channel) and drop the GPADL.
                if let Some(info) = info {
                    sender.send_gpadl_created(
                        info.channel_id,
                        gpadl_id,
                        protocol::STATUS_UNSUCCESSFUL,
                    );
                }
                false
            }
            // Dropped without sending any message.
            GpadlState::OfferedTearingDown => false,
            // Kept as-is.
            GpadlState::Accepted => true,
            GpadlState::TearingDown => {
                // Report the teardown as finished (if messages are being sent
                // for this channel) and drop the GPADL.
                if info.is_some() {
                    sender.send_gpadl_torndown(gpadl_id);
                }
                false
            }
        }
    });
    // The rescind is sent only after all GPADL messages above.
    if let Some(info) = info {
        sender.send_rescind(info);
    }
    // Any channel not in `RestoreState::New` is marked `Restored`.
    // NOTE(review): presumably so a revoked channel is no longer treated as
    // awaiting restore. Confirm.
    if channel.restore_state != RestoreState::New {
        channel.restore_state = RestoreState::Restored;
    }
    retain
}
3841
3842struct PendingMessages(VecDeque<OutgoingMessage>);
3843
3844impl PendingMessages {
3845 fn sender<'a, N: Notifier>(
3847 &'a mut self,
3848 notifier: &'a mut N,
3849 is_paused: bool,
3850 ) -> MessageSender<'a, N> {
3851 MessageSender {
3852 notifier,
3853 pending_messages: self,
3854 is_paused,
3855 }
3856 }
3857}
3858
/// Sends messages through a [`Notifier`], falling back to a
/// [`PendingMessages`] queue when a message cannot be delivered immediately.
struct MessageSender<'a, N> {
    /// Used to deliver messages.
    notifier: &'a mut N,
    /// Messages awaiting delivery. New default-target messages are queued
    /// behind these rather than sent ahead of them.
    pending_messages: &'a mut PendingMessages,
    /// When `true`, default-target messages are queued instead of sent.
    is_paused: bool,
}
3866
3867impl<N: Notifier> MessageSender<'_, N> {
3868 fn send_message<
3870 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3871 >(
3872 &mut self,
3873 msg: &T,
3874 ) {
3875 let message = OutgoingMessage::new(msg);
3876
3877 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3878 if !self.pending_messages.0.is_empty()
3880 || self.is_paused
3881 || !self.notifier.send_message(&message, MessageTarget::Default)
3882 {
3883 tracing::trace!("message queued");
3884 self.pending_messages.0.push_back(message);
3886 }
3887 }
3888
3889 fn send_message_with_target<
3891 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3892 >(
3893 &mut self,
3894 msg: &T,
3895 target: MessageTarget,
3896 ) {
3897 if target == MessageTarget::Default {
3898 self.send_message(msg);
3899 } else {
3900 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3901 let message = OutgoingMessage::new(msg);
3904 if !self.notifier.send_message(&message, target) {
3905 tracelimit::warn_ratelimited!(?target, "failed to send message");
3906 }
3907 }
3908 }
3909
3910 fn send_offer(&mut self, channel: &mut Channel, connection_info: &ConnectionInfo) {
3912 let info = channel.info.as_ref().expect("assigned");
3913 let mut flags = channel.offer.flags;
3914 if !connection_info
3915 .version
3916 .feature_flags
3917 .confidential_channels()
3918 {
3919 flags.set_confidential_ring_buffer(false);
3920 flags.set_confidential_external_memory(false);
3921 }
3922
3923 let monitor_id = connection_info.monitor_page.and(info.monitor_id);
3930 let msg = protocol::OfferChannel {
3931 interface_id: channel.offer.interface_id,
3932 instance_id: channel.offer.instance_id,
3933 rsvd: [0; 4],
3934 flags,
3935 mmio_megabytes: channel.offer.mmio_megabytes,
3936 user_defined: channel.offer.user_defined,
3937 subchannel_index: channel.offer.subchannel_index,
3938 mmio_megabytes_optional: channel.offer.mmio_megabytes_optional,
3939 channel_id: info.channel_id,
3940 monitor_id: monitor_id.unwrap_or(MonitorId::INVALID).0,
3941 monitor_allocated: monitor_id.is_some().into(),
3942 is_dedicated: 1,
3945 connection_id: info.connection_id,
3946 };
3947 tracing::info!(
3948 channel_id = msg.channel_id.0,
3949 connection_id = msg.connection_id,
3950 key = %channel.offer.key(),
3951 "sending offer to guest"
3952 );
3953
3954 self.send_message(&msg);
3955 }
3956
3957 fn send_open_result(
3958 &mut self,
3959 channel_id: ChannelId,
3960 open_request: &OpenRequest,
3961 result: i32,
3962 target: MessageTarget,
3963 ) {
3964 self.send_message_with_target(
3965 &protocol::OpenResult {
3966 channel_id,
3967 open_id: open_request.open_id,
3968 status: result as u32,
3969 },
3970 target,
3971 );
3972 }
3973
3974 fn send_gpadl_created(&mut self, channel_id: ChannelId, gpadl_id: GpadlId, status: i32) {
3975 self.send_message(&protocol::GpadlCreated {
3976 channel_id,
3977 gpadl_id,
3978 status,
3979 });
3980 }
3981
3982 fn send_gpadl_torndown(&mut self, gpadl_id: GpadlId) {
3983 self.send_message(&protocol::GpadlTorndown { gpadl_id });
3984 }
3985
3986 fn send_rescind(&mut self, info: &OfferedInfo) {
3987 tracing::info!(
3988 channel_id = info.channel_id.0,
3989 "rescinding channel from guest"
3990 );
3991
3992 self.send_message(&protocol::RescindChannelOffer {
3993 channel_id: info.channel_id,
3994 });
3995 }
3996}
3997
3998struct VersionResponseData {
4000 version: VersionInfo,
4001 state: protocol::ConnectionState,
4002 monitor_pages: Option<MonitorPageGpas>,
4003}
4004
4005impl VersionResponseData {
4006 fn new(version: VersionInfo, state: protocol::ConnectionState) -> Self {
4008 VersionResponseData {
4009 version,
4010 state,
4011 monitor_pages: None,
4012 }
4013 }
4014
4015 fn with_monitor_pages(mut self, monitor_pages: Option<MonitorPageGpas>) -> Self {
4017 self.monitor_pages = monitor_pages;
4018 self
4019 }
4020}