1pub mod saved_state;
5#[cfg(test)]
6mod tests;
7
8use crate::Guid;
9use crate::SynicMessage;
10use crate::monitor::AssignedMonitors;
11use crate::protocol::Version;
12use hvdef::Vtl;
13use inspect::Inspect;
14pub use saved_state::RestoreError;
15pub use saved_state::SavedState;
16pub use saved_state::SavedStateData;
17use slab::Slab;
18use std::cmp::min;
19use std::collections::VecDeque;
20use std::collections::hash_map::Entry;
21use std::collections::hash_map::HashMap;
22use std::fmt::Display;
23use std::ops::Index;
24use std::ops::IndexMut;
25use std::task::Poll;
26use std::task::ready;
27use std::time::Duration;
28use thiserror::Error;
29use vmbus_channel::bus::ChannelType;
30use vmbus_channel::bus::GpadlRequest;
31use vmbus_channel::bus::OfferKey;
32use vmbus_channel::bus::OfferParams;
33use vmbus_channel::bus::OpenData;
34use vmbus_channel::bus::RestoredGpadl;
35use vmbus_core::HvsockConnectRequest;
36use vmbus_core::HvsockConnectResult;
37use vmbus_core::MaxVersionInfo;
38use vmbus_core::OutgoingMessage;
39use vmbus_core::VMBUS_SINT;
40use vmbus_core::VersionInfo;
41use vmbus_core::protocol;
42use vmbus_core::protocol::ChannelId;
43use vmbus_core::protocol::ConnectionId;
44use vmbus_core::protocol::FeatureFlags;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::Message;
47use vmbus_core::protocol::OfferFlags;
48use vmbus_core::protocol::UserDefinedData;
49use vmbus_ring::gparange;
50use vmcore::monitor::MonitorId;
51use vmcore::synic::MonitorInfo;
52use vmcore::synic::MonitorPageGpas;
53use zerocopy::FromZeros;
54use zerocopy::Immutable;
55use zerocopy::IntoBytes;
56use zerocopy::KnownLayout;
57
/// Errors that can occur while handling channel, GPADL, and connection
/// requests.
///
/// NOTE(review): the code that produces these errors is outside this view;
/// the meaning of each variant is taken from its error message.
#[derive(Debug, Error)]
pub enum ChannelError {
    #[error("unknown channel ID")]
    UnknownChannelId,
    #[error("unknown GPADL ID")]
    UnknownGpadlId,
    // `#[from]` generates a conversion from `protocol::ParseError`, so parse
    // failures can be propagated with `?`.
    #[error("parse error")]
    ParseError(#[from] protocol::ParseError),
    // `#[source]` exposes the wrapped range error as this error's source.
    #[error("invalid gpa range")]
    InvalidGpaRange(#[source] gparange::Error),
    #[error("duplicate GPADL ID")]
    DuplicateGpadlId,
    #[error("GPADL is already complete")]
    GpadlAlreadyComplete,
    #[error("GPADL channel ID mismatch")]
    WrongGpadlChannelId,
    #[error("trying to open an open channel")]
    ChannelAlreadyOpen,
    #[error("trying to close a closed channel")]
    ChannelNotOpen,
    #[error("invalid GPADL state for operation")]
    InvalidGpadlState,
    #[error("invalid channel state for operation")]
    InvalidChannelState,
    #[error("channel ID has already been released")]
    ChannelReleased,
    #[error("channel offers have already been sent")]
    OffersAlreadySent,
    #[error("invalid operation on reserved channel")]
    ChannelReserved,
    #[error("invalid operation on non-reserved channel")]
    ChannelNotReserved,
    #[error("received untrusted message for trusted connection")]
    UntrustedMessage,
    #[error("received a non-resuming message while paused")]
    Paused,
}
96
/// Errors that can occur while offering a channel.
#[derive(Debug, Error)]
pub enum OfferError {
    // `(.0).0` formats the raw value wrapped by the `ChannelId`.
    #[error("the channel ID {} is not valid for this operation", (.0).0)]
    InvalidChannelId(ChannelId),
    #[error("the channel ID {} is already in use", (.0).0)]
    ChannelIdInUse(ChannelId),
    #[error("offer {0} already exists")]
    AlreadyExists(OfferKey),
    #[error("specified resources do not match those of the existing saved or revoked offer")]
    IncompatibleResources,
    #[error("too many channels have been offered")]
    TooManyChannels,
    // The first field is the expected monitor ID, the second the actual one.
    #[error("mismatched monitor ID from saved state; expected {0:?}, actual {1:?}")]
    MismatchedMonitorId(Option<MonitorId>, MonitorId),
}
112
/// Identifies an offered channel within the server.
///
/// Wraps the key under which the channel is stored in the `ChannelList` slab.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OfferId(usize);
116
/// Maps a GPADL ID to the offer it belongs to.
///
/// NOTE(review): going by the name, this presumably tracks GPADLs that are
/// still being assembled; the code that populates it is outside this view.
type IncompleteGpadlMap = HashMap<GpadlId, OfferId>;
118
/// GPADLs keyed by the GPADL ID together with the owning offer.
type GpadlMap = HashMap<(GpadlId, OfferId), Gpadl>;
120
/// The state of the vmbus server: the connection state, the offered channels,
/// and the IDs, monitors, and GPADLs assigned to them.
pub struct Server {
    /// Current state of the connection.
    state: ConnectionState,
    /// All channels known to the server, indexed by `OfferId`.
    channels: ChannelList,
    /// Mapping from channel IDs to the offers that own them.
    assigned_channels: AssignedChannels,
    /// Monitor IDs handed out to channels.
    assigned_monitors: AssignedMonitors,
    /// GPADLs keyed by GPADL ID and owning offer.
    gpadls: GpadlMap,
    /// NOTE(review): presumably GPADLs still being assembled — confirm.
    incomplete_gpadls: IncompleteGpadlMap,
    child_connection_id: u32,
    /// Version limit that is applied immediately; see
    /// `set_compatibility_version`.
    max_version: Option<MaxVersionInfo>,
    /// Version limit stored when `set_compatibility_version` is called with
    /// `delay` set.
    delayed_max_version: Option<MaxVersionInfo>,
    pending_messages: PendingMessages,
    /// Set through `set_require_server_allocated_mnf`.
    require_server_allocated_mnf: bool,
    /// Supplied to `Server::new`.
    use_absolute_channel_order: bool,
}
141
/// A mutable borrow of a [`Server`] paired with the notifier it should use.
///
/// Created by `Server::with_notifier`; the server's invariants are validated
/// again when this value is dropped.
pub struct ServerWithNotifier<'a, T> {
    inner: &'a mut Server,
    notifier: &'a mut T,
}
146
impl<T> Drop for ServerWithNotifier<'_, T> {
    /// Re-checks the server invariants once the borrow is released.
    fn drop(&mut self) {
        self.inner.validate();
    }
}
152
153impl<T: Notifier> Inspect for ServerWithNotifier<'_, T> {
154 fn inspect(&self, req: inspect::Request<'_>) {
155 let mut resp = req.respond();
156 let (state, info, next_action) = match &self.inner.state {
157 ConnectionState::Disconnected => ("disconnected", None, None),
158 ConnectionState::Connecting { info, .. } => ("connecting", Some(info), None),
159 ConnectionState::Connected(info) => (
160 if info.offers_sent {
161 "connected"
162 } else {
163 "negotiated"
164 },
165 Some(info),
166 None,
167 ),
168 ConnectionState::Disconnecting { next_action, .. } => {
169 ("disconnecting", None, Some(next_action))
170 }
171 };
172
173 resp.field("connection_info", info);
174 let next_action = next_action.map(|a| match a {
175 ConnectionAction::None => "disconnect",
176 ConnectionAction::Reset => "reset",
177 ConnectionAction::SendUnloadComplete => "unload",
178 ConnectionAction::Reconnect { .. } => "reconnect",
179 ConnectionAction::SendFailedVersionResponse => "send_version_response",
180 });
181 resp.field("state", state)
182 .field("next_action", next_action)
183 .field(
184 "assigned_monitors_bitmap",
185 format_args!("{:x}", self.inner.assigned_monitors.bitmap()),
186 )
187 .child("channels", |req| {
188 let mut resp = req.respond();
189 self.inner
190 .channels
191 .inspect(self.notifier, self.inner.get_version(), &mut resp);
192 for ((gpadl_id, offer_id), gpadl) in &self.inner.gpadls {
193 let channel = &self.inner.channels[*offer_id];
194 resp.field(
195 &channel_inspect_path(
196 &channel.offer,
197 format_args!("/gpadls/{}", gpadl_id.0),
198 ),
199 gpadl,
200 );
201 }
202 });
203 }
204}
205
/// Monitor page addresses together with a flag recording who provided them.
#[derive(Debug, Copy, Clone, Inspect)]
struct MonitorPageGpaInfo {
    gpas: MonitorPageGpas,
    /// `true` when built with `from_server_gpas`, `false` when built with
    /// `from_guest_gpas`.
    server_allocated: bool,
}
212
impl MonitorPageGpaInfo {
    /// Wraps `gpas`, marking them as not server-allocated.
    fn from_guest_gpas(gpas: MonitorPageGpas) -> Self {
        Self {
            gpas,
            server_allocated: false,
        }
    }

    /// Wraps `gpas`, marking them as server-allocated.
    fn from_server_gpas(gpas: MonitorPageGpas) -> Self {
        Self {
            gpas,
            server_allocated: true,
        }
    }
}
230
/// Information about a connection that is being established or is
/// established (carried by the `Connecting` and `Connected` states).
#[derive(Debug, Copy, Clone, Inspect)]
struct ConnectionInfo {
    version: VersionInfo,
    trusted: bool,
    /// Reported by inspect as "connected" when set and "negotiated" otherwise.
    offers_sent: bool,
    interrupt_page: Option<u64>,
    monitor_page: Option<MonitorPageGpaInfo>,
    target_message_vp: u32,
    // NOTE(review): presumably set while a connection modification is in
    // flight — confirm against the code that toggles it.
    modifying: bool,
    client_id: Guid,
    paused: bool,
}
245
/// The state of the connection tracked by the server.
#[derive(Debug)]
enum ConnectionState {
    Disconnected,
    /// A disconnect is in progress; `next_action` records what to do
    /// afterwards.
    Disconnecting {
        next_action: ConnectionAction,
        modify_sent: bool,
    },
    /// A connection is being established.
    Connecting {
        info: ConnectionInfo,
        next_action: ConnectionAction,
    },
    Connected(ConnectionInfo),
}
260
261impl ConnectionState {
262 fn check_version(&self, min_version: Version) -> bool {
264 matches!(self, ConnectionState::Connected(info) if info.version.version >= min_version)
265 }
266
267 fn check_feature_flags(&self, flags: impl Fn(FeatureFlags) -> bool) -> bool {
270 matches!(self, ConnectionState::Connected(info) if flags(info.version.feature_flags))
271 }
272
273 fn get_version(&self) -> Option<VersionInfo> {
274 if let ConnectionState::Connected(info) = self {
275 Some(info.version)
276 } else {
277 None
278 }
279 }
280
281 fn get_connected_info(&self) -> Option<&ConnectionInfo> {
283 if let ConnectionState::Connected(info) = self {
284 Some(info)
285 } else {
286 None
287 }
288 }
289
290 fn is_trusted(&self) -> bool {
291 match self {
292 ConnectionState::Connected(info) => info.trusted,
293 ConnectionState::Connecting { info, .. } => info.trusted,
294 _ => false,
295 }
296 }
297
298 fn is_paused(&self) -> bool {
299 if let ConnectionState::Connected(info) = self {
300 info.paused
301 } else {
302 false
303 }
304 }
305}
306
/// Follow-up action recorded in the `Disconnecting` and `Connecting`
/// connection states.
///
/// NOTE(review): the code that carries these actions out is outside this
/// view; inspect reports `None` as "disconnect".
#[derive(Debug, Copy, Clone)]
enum ConnectionAction {
    None,
    Reset,
    SendUnloadComplete,
    Reconnect {
        initiate_contact: InitiateContactRequest,
    },
    SendFailedVersionResponse,
}
317
/// The monitor page portion of an [`InitiateContactRequest`].
///
/// NOTE(review): `Invalid` presumably marks a request whose monitor page
/// values could not be accepted — confirm against the parsing code.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum MonitorPageRequest {
    None,
    Some(MonitorPageGpas),
    Invalid,
}
324
/// The parameters of an initiate-contact request.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub struct InitiateContactRequest {
    pub version_requested: u32,
    pub target_message_vp: u32,
    pub monitor_page: MonitorPageRequest,
    pub target_sint: u8,
    pub target_vtl: u8,
    pub feature_flags: u32,
    pub interrupt_page: Option<u64>,
    pub client_id: Guid,
    pub trusted: bool,
}
337
/// The parameters of a request to open a channel.
#[derive(Debug, Copy, Clone)]
pub struct OpenRequest {
    pub open_id: u32,
    pub ring_buffer_gpadl_id: GpadlId,
    pub target_vp: u32,
    pub downstream_ring_buffer_page_offset: u32,
    pub user_data: UserDefinedData,
    /// When present, `OpenParams::from_request` uses this event flag and
    /// connection ID instead of the ones derived from the offer.
    pub guest_specified_interrupt_info: Option<SignalInfo>,
    pub flags: protocol::OpenChannelFlags,
}
348
/// A requested change to an optional setting: leave it alone, clear it, or
/// set it to a new value.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Update<T: std::fmt::Debug + Copy + Clone> {
    Unchanged,
    Reset,
    Set(T),
}

impl<T: std::fmt::Debug + Copy + Clone> From<Option<T>> for Update<T> {
    /// Converts `Some(value)` into `Set(value)` and `None` into `Reset`; this
    /// conversion never produces `Unchanged`.
    fn from(value: Option<T>) -> Self {
        value.map_or(Self::Reset, Self::Set)
    }
}
364
/// A request to change properties of the connection.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ModifyConnectionRequest {
    pub version: Option<VersionInfo>,
    pub monitor_page: Update<MonitorPageGpas>,
    pub interrupt_page: Update<u64>,
    pub target_message_vp: Option<u32>,
    /// Defaults to `true`.
    pub notify_relay: bool,
}
373
impl Default for ModifyConnectionRequest {
    /// Returns a request that changes nothing and has `notify_relay` enabled.
    fn default() -> Self {
        Self {
            version: None,
            monitor_page: Update::Unchanged,
            interrupt_page: Update::Unchanged,
            target_message_vp: None,
            notify_relay: true,
        }
    }
}
386
387impl From<protocol::ModifyConnection> for ModifyConnectionRequest {
388 fn from(value: protocol::ModifyConnection) -> Self {
389 let monitor_page = if value.parent_to_child_monitor_page_gpa != 0 {
390 Update::Set(MonitorPageGpas {
391 parent_to_child: value.parent_to_child_monitor_page_gpa,
392 child_to_parent: value.child_to_parent_monitor_page_gpa,
393 })
394 } else {
395 Update::Reset
396 };
397
398 Self {
399 monitor_page,
400 ..Default::default()
401 }
402 }
403}
404
/// The response to a [`ModifyConnectionRequest`].
///
/// NOTE(review): the code that produces and consumes these variants is
/// outside this view; the exact meaning of each payload should be confirmed
/// there.
#[derive(Debug, Copy, Clone)]
pub enum ModifyConnectionResponse {
    Supported(
        protocol::ConnectionState,
        FeatureFlags,
        Option<MonitorPageGpas>,
    ),
    Unsupported,
    Modified(protocol::ConnectionState),
}
424
/// Whether a modify operation is in progress for an open channel.
#[derive(Debug, Copy, Clone)]
pub enum ModifyState {
    NotModifying,
    Modifying { pending_target_vp: Option<u32> },
}

impl ModifyState {
    /// Returns true for `Modifying`, regardless of the pending target VP.
    pub fn is_modifying(&self) -> bool {
        match self {
            Self::Modifying { .. } => true,
            Self::NotModifying => false,
        }
    }
}
436
/// An event flag and connection ID pair used for signaling a channel.
#[derive(Debug, Copy, Clone)]
pub struct SignalInfo {
    pub event_flag: u16,
    pub connection_id: u32,
}
442
/// Tracks a channel's progress through restore.
///
/// NOTE(review): the transitions between these states happen outside this
/// view; confirm the precise meaning of each variant there.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum RestoreState {
    New,
    Restoring,
    Unmatched,
    Restored,
}
457
/// The state of a single channel.
///
/// `is_released`, `is_revoked`, and `is_reserved` classify these variants.
/// NOTE(review): the transitions between states happen outside this view.
#[derive(Debug, Clone)]
enum ChannelState {
    /// Counted as released by `is_released`.
    ClientReleased,

    Closed,

    Opening {
        request: OpenRequest,
        /// `Some` when the channel is reserved (see `is_reserved`).
        reserved_state: Option<ReservedState>,
    },

    Open {
        params: OpenRequest,
        modify_state: ModifyState,
        /// `Some` when the channel is reserved (see `is_reserved`).
        reserved_state: Option<ReservedState>,
    },

    Closing {
        params: OpenRequest,
        /// `Some` when the channel is reserved (see `is_reserved`).
        reserved_state: Option<ReservedState>,
    },

    // NOTE(review): presumably a close in progress with a new open request
    // already queued (`request`) — confirm.
    ClosingReopen {
        params: OpenRequest,
        request: OpenRequest,
    },

    /// Counted as revoked by `is_revoked`.
    Revoked,

    /// Counted as revoked by `is_revoked`.
    Reoffered,

    /// Counted as released by `is_released`.
    ClosingClientRelease,

    /// Counted as released by `is_released`.
    OpeningClientRelease,
}
511
512impl ChannelState {
513 fn is_released(&self) -> bool {
516 match self {
517 ChannelState::Closed
518 | ChannelState::Opening { .. }
519 | ChannelState::Open { .. }
520 | ChannelState::Closing { .. }
521 | ChannelState::ClosingReopen { .. }
522 | ChannelState::Revoked
523 | ChannelState::Reoffered => false,
524
525 ChannelState::ClientReleased
526 | ChannelState::ClosingClientRelease
527 | ChannelState::OpeningClientRelease => true,
528 }
529 }
530
531 fn is_revoked(&self) -> bool {
533 match self {
534 ChannelState::Revoked | ChannelState::Reoffered => true,
535
536 ChannelState::ClientReleased
537 | ChannelState::Closed
538 | ChannelState::Opening { .. }
539 | ChannelState::Open { .. }
540 | ChannelState::Closing { .. }
541 | ChannelState::ClosingReopen { .. }
542 | ChannelState::ClosingClientRelease
543 | ChannelState::OpeningClientRelease => false,
544 }
545 }
546
547 fn is_reserved(&self) -> bool {
548 match self {
549 ChannelState::Open {
551 reserved_state: Some(_),
552 ..
553 }
554 | ChannelState::Opening {
555 reserved_state: Some(_),
556 ..
557 }
558 | ChannelState::Closing {
559 reserved_state: Some(_),
560 ..
561 } => true,
562
563 ChannelState::Opening { .. }
564 | ChannelState::Open { .. }
565 | ChannelState::Closing { .. }
566 | ChannelState::ClientReleased
567 | ChannelState::Closed
568 | ChannelState::ClosingReopen { .. }
569 | ChannelState::Revoked
570 | ChannelState::Reoffered
571 | ChannelState::ClosingClientRelease
572 | ChannelState::OpeningClientRelease => false,
573 }
574 }
575}
576
577impl Display for ChannelState {
578 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
579 let state = match self {
580 Self::ClientReleased => "ClientReleased",
581 Self::Closed => "Closed",
582 Self::Opening { .. } => "Opening",
583 Self::Open { .. } => "Open",
584 Self::Closing { .. } => "Closing",
585 Self::ClosingReopen { .. } => "ClosingReopen",
586 Self::Revoked => "Revoked",
587 Self::Reoffered => "Reoffered",
588 Self::ClosingClientRelease => "ClosingClientRelease",
589 Self::OpeningClientRelease => "OpeningClientRelease",
590 };
591 write!(f, "{}", state)
592 }
593}
594
/// How a channel uses MNF (monitored notifications).
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub enum MnfUsage {
    /// No monitor ID is given to the channel.
    #[default]
    Disabled,
    /// A monitor ID is assigned by the server when the channel is prepared.
    Enabled { latency: Duration },
    /// The channel uses the given monitor ID as-is; the server neither
    /// assigns nor releases it.
    Relayed { monitor_id: u8 },
}
607
608impl MnfUsage {
609 pub fn is_enabled(&self) -> bool {
610 matches!(self, Self::Enabled { .. })
611 }
612
613 pub fn is_relayed(&self) -> bool {
614 matches!(self, Self::Relayed { .. })
615 }
616
617 pub fn enabled_and_then<T>(&self, f: impl FnOnce(Duration) -> Option<T>) -> Option<T> {
618 if let Self::Enabled { latency } = self {
619 f(*latency)
620 } else {
621 None
622 }
623 }
624}
625
626impl From<Option<Duration>> for MnfUsage {
627 fn from(value: Option<Duration>) -> Self {
628 match value {
629 None => Self::Disabled,
630 Some(latency) => Self::Enabled { latency },
631 }
632 }
633}
634
/// The server's internal form of a channel offer, built from `OfferParams`.
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub struct OfferParamsInternal {
    pub interface_name: String,
    /// Part of the offer key.
    pub instance_id: Guid,
    /// Part of the offer key.
    pub interface_id: Guid,
    pub mmio_megabytes: u16,
    pub mmio_megabytes_optional: u16,
    /// Part of the offer key.
    pub subchannel_index: u16,
    pub use_mnf: MnfUsage,
    pub offer_order: Option<u64>,
    pub flags: OfferFlags,
    pub user_defined: UserDefinedData,
}
649
650impl OfferParamsInternal {
651 pub fn key(&self) -> OfferKey {
653 OfferKey {
654 interface_id: self.interface_id,
655 instance_id: self.instance_id,
656 subchannel_index: self.subchannel_index,
657 }
658 }
659}
660
661impl From<OfferParams> for OfferParamsInternal {
662 fn from(value: OfferParams) -> Self {
663 let mut user_defined = UserDefinedData::new_zeroed();
664
665 let mut flags = OfferFlags::new()
668 .with_confidential_ring_buffer(true)
669 .with_confidential_external_memory(value.allow_confidential_external_memory);
670
671 match value.channel_type {
672 ChannelType::Device { pipe_packets } => {
673 if pipe_packets {
674 flags.set_named_pipe_mode(true);
675 user_defined.as_pipe_params_mut().pipe_type = protocol::PipeType::MESSAGE;
676 }
677 }
678 ChannelType::Interface {
679 user_defined: interface_user_defined,
680 } => {
681 flags.set_enumerate_device_interface(true);
682 user_defined = interface_user_defined;
683 }
684 ChannelType::Pipe { message_mode } => {
685 flags.set_enumerate_device_interface(true);
686 flags.set_named_pipe_mode(true);
687 user_defined.as_pipe_params_mut().pipe_type = if message_mode {
688 protocol::PipeType::MESSAGE
689 } else {
690 protocol::PipeType::BYTE
691 };
692 }
693 ChannelType::HvSocket {
694 is_connect,
695 is_for_container,
696 silo_id,
697 } => {
698 flags.set_enumerate_device_interface(true);
699 flags.set_tlnpi_provider(true);
700 flags.set_named_pipe_mode(true);
701 *user_defined.as_hvsock_params_mut() = protocol::HvsockUserDefinedParameters::new(
702 is_connect,
703 is_for_container,
704 silo_id,
705 );
706 }
707 };
708
709 Self {
710 interface_name: value.interface_name,
711 instance_id: value.instance_id,
712 interface_id: value.interface_id,
713 mmio_megabytes: value.mmio_megabytes,
714 mmio_megabytes_optional: value.mmio_megabytes_optional,
715 subchannel_index: value.subchannel_index,
716 use_mnf: value.mnf_interrupt_latency.into(),
717 offer_order: value.offer_order,
718 user_defined,
719 flags,
720 }
721 }
722}
723
/// A virtual processor and SINT pair identifying where messages are sent.
#[derive(Debug, Copy, Clone, Inspect, PartialEq, Eq)]
pub struct ConnectionTarget {
    pub vp: u32,
    pub sint: u8,
}
729
/// Where an outgoing message should be delivered.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MessageTarget {
    Default,
    /// The target stored in a reserved channel's state, along with that
    /// channel's offer ID.
    ReservedChannel(OfferId, ConnectionTarget),
    Custom(ConnectionTarget),
}
736
737impl MessageTarget {
738 pub fn for_offer(offer_id: OfferId, reserved_state: &Option<ReservedState>) -> Self {
739 if let Some(state) = reserved_state {
740 Self::ReservedChannel(offer_id, state.target)
741 } else {
742 Self::Default
743 }
744 }
745}
746
/// State carried by a reserved channel: a version and the target that
/// messages for the channel are sent to.
#[derive(Debug, Copy, Clone)]
pub struct ReservedState {
    version: VersionInfo,
    target: ConnectionTarget,
}
752
/// A channel known to the server.
#[derive(Debug)]
struct Channel {
    /// The IDs assigned by `prepare_channel`; cleared by `release_channel`.
    info: Option<OfferedInfo>,
    offer: OfferParamsInternal,
    state: ChannelState,
    restore_state: RestoreState,
}
761
/// The channel ID, connection ID, and optional monitor ID assigned to a
/// channel.
#[derive(Debug, Copy, Clone)]
struct OfferedInfo {
    channel_id: ChannelId,
    connection_id: u32,
    monitor_id: Option<MonitorId>,
}
768
769impl Channel {
770 fn inspect_state(&self, resp: &mut inspect::Response<'_>) {
771 let mut target_vp = None;
772 let mut event_flag = None;
773 let mut connection_id = None;
774 let mut reserved_target = None;
775 let state = match &self.state {
776 ChannelState::ClientReleased => "client_released",
777 ChannelState::Closed => "closed",
778 ChannelState::Opening { reserved_state, .. } => {
779 reserved_target = reserved_state.map(|state| state.target);
780 "opening"
781 }
782 ChannelState::Open {
783 params,
784 reserved_state,
785 ..
786 } => {
787 target_vp = Some(params.target_vp);
788 if let Some(id) = params.guest_specified_interrupt_info {
789 event_flag = Some(id.event_flag);
790 connection_id = Some(id.connection_id);
791 }
792 reserved_target = reserved_state.map(|state| state.target);
793 "open"
794 }
795 ChannelState::Closing { reserved_state, .. } => {
796 reserved_target = reserved_state.map(|state| state.target);
797 "closing"
798 }
799 ChannelState::ClosingReopen { .. } => "closing_reopen",
800 ChannelState::Revoked => "revoked",
801 ChannelState::Reoffered => "reoffered",
802 ChannelState::ClosingClientRelease => "closing_client_release",
803 ChannelState::OpeningClientRelease => "opening_client_release",
804 };
805 let restore_state = match self.restore_state {
806 RestoreState::New => "new",
807 RestoreState::Restoring => "restoring",
808 RestoreState::Restored => "restored",
809 RestoreState::Unmatched => "unmatched",
810 };
811 if let Some(info) = &self.info {
812 resp.field("channel_id", info.channel_id.0)
813 .field("offered_connection_id", info.connection_id)
814 .field("monitor_id", info.monitor_id.map(|id| id.0));
815 }
816 resp.field("state", state)
817 .field("restore_state", restore_state)
818 .field("interface_name", self.offer.interface_name.clone())
819 .display("instance_id", &self.offer.instance_id)
820 .display("interface_id", &self.offer.interface_id)
821 .field("mmio_megabytes", self.offer.mmio_megabytes)
822 .field("target_vp", target_vp)
823 .field("guest_specified_event_flag", event_flag)
824 .field("guest_specified_connection_id", connection_id)
825 .field("reserved_connection_target", reserved_target)
826 .binary("offer_flags", self.offer.flags.into_bits());
827 }
828
829 fn handled_monitor_info(&self) -> Option<MonitorInfo> {
839 self.offer.use_mnf.enabled_and_then(|latency| {
840 if self.state.is_reserved() {
841 None
842 } else {
843 self.info.and_then(|info| {
844 info.monitor_id.map(|monitor_id| MonitorInfo {
845 monitor_id,
846 latency,
847 })
848 })
849 }
850 })
851 }
852
853 fn prepare_channel(
856 &mut self,
857 offer_id: OfferId,
858 assigned_channels: &mut AssignedChannels,
859 assigned_monitors: &mut AssignedMonitors,
860 ) {
861 assert!(self.info.is_none());
862
863 let entry = assigned_channels
865 .allocate()
866 .expect("there are enough channel IDs for everything in ChannelList");
867
868 let channel_id = entry.id();
869 entry.insert(offer_id);
870 let connection_id = ConnectionId::new(channel_id.0, assigned_channels.vtl, VMBUS_SINT);
871
872 let monitor_id = match self.offer.use_mnf {
877 MnfUsage::Enabled { .. } => {
878 let monitor_id = assigned_monitors.assign_monitor();
879 if monitor_id.is_none() {
880 tracelimit::warn_ratelimited!("Out of monitor IDs.");
881 }
882
883 monitor_id
884 }
885 MnfUsage::Relayed { monitor_id } => Some(MonitorId(monitor_id)),
886 MnfUsage::Disabled => None,
887 };
888
889 self.info = Some(OfferedInfo {
890 channel_id,
891 connection_id: connection_id.0,
892 monitor_id,
893 });
894 }
895
896 fn release_channel(
898 &mut self,
899 offer_id: OfferId,
900 assigned_channels: &mut AssignedChannels,
901 assigned_monitors: &mut AssignedMonitors,
902 ) {
903 if let Some(info) = self.info.take() {
904 assigned_channels.free(info.channel_id, offer_id);
905
906 if let Some(monitor_id) = info.monitor_id {
908 if self.offer.use_mnf.is_enabled() {
909 assigned_monitors.release_monitor(monitor_id);
910 }
911 }
912 }
913 }
914}
915
/// Tracks which channel IDs are assigned to which offers.
#[derive(Debug)]
struct AssignedChannels {
    /// One slot per channel ID; channel ID `n` lives at index `n - 1`.
    assignments: Vec<Option<OfferId>>,
    vtl: Vtl,
    /// Index at which `allocate` starts searching; slots below it are only
    /// filled through `set`.
    reserved_offset: usize,
    /// Number of assigned slots whose index is below `reserved_offset`.
    count_in_reserved_range: usize,
}
924
925impl AssignedChannels {
926 fn new(vtl: Vtl, channel_id_offset: u16) -> Self {
927 Self {
928 assignments: vec![None; MAX_CHANNELS],
929 vtl,
930 reserved_offset: channel_id_offset as usize,
931 count_in_reserved_range: 0,
932 }
933 }
934
935 fn allowable_channel_count(&self) -> usize {
936 MAX_CHANNELS - self.reserved_offset + self.count_in_reserved_range
937 }
938
939 fn get(&self, channel_id: ChannelId) -> Option<OfferId> {
940 self.assignments
941 .get(Self::index(channel_id))
942 .copied()
943 .flatten()
944 }
945
946 fn set(&mut self, channel_id: ChannelId) -> Result<AssignmentEntry<'_>, OfferError> {
947 let index = Self::index(channel_id);
948 if self
949 .assignments
950 .get(index)
951 .ok_or(OfferError::InvalidChannelId(channel_id))?
952 .is_some()
953 {
954 return Err(OfferError::ChannelIdInUse(channel_id));
955 }
956 Ok(AssignmentEntry { list: self, index })
957 }
958
959 fn allocate(&mut self) -> Option<AssignmentEntry<'_>> {
960 let index = self.reserved_offset
961 + self.assignments[self.reserved_offset..]
962 .iter()
963 .position(|x| x.is_none())?;
964 Some(AssignmentEntry { list: self, index })
965 }
966
967 fn free(&mut self, channel_id: ChannelId, offer_id: OfferId) {
968 let index = Self::index(channel_id);
969 let slot = &mut self.assignments[index];
970 assert_eq!(slot.take(), Some(offer_id));
971 if index < self.reserved_offset {
972 self.count_in_reserved_range -= 1;
973 }
974 }
975
976 fn index(channel_id: ChannelId) -> usize {
977 channel_id.0.wrapping_sub(1) as usize
978 }
979}
980
/// A slot in [`AssignedChannels`] selected by `set` or `allocate`, to be
/// filled with `insert`.
struct AssignmentEntry<'a> {
    list: &'a mut AssignedChannels,
    index: usize,
}
985
impl AssignmentEntry<'_> {
    /// Returns the channel ID for this slot (the slot index plus one).
    pub fn id(&self) -> ChannelId {
        ChannelId(self.index as u32 + 1)
    }

    /// Assigns the slot to `offer_id`.
    ///
    /// # Panics
    ///
    /// Panics if the slot is already assigned.
    pub fn insert(self, offer_id: OfferId) {
        assert!(
            self.list.assignments[self.index]
                .replace(offer_id)
                .is_none()
        );

        // Keep the count of assignments below the reserved offset current.
        if self.index < self.list.reserved_offset {
            self.list.count_in_reserved_range += 1;
        }
    }
}
1003
/// The set of channels known to the server, indexed by [`OfferId`].
struct ChannelList {
    channels: Slab<Channel>,
}
1007
1008fn channel_inspect_path(offer: &OfferParamsInternal, suffix: std::fmt::Arguments<'_>) -> String {
1009 if offer.subchannel_index == 0 {
1010 format!("{}{}", offer.instance_id, suffix)
1011 } else {
1012 format!(
1013 "{}/subchannels/{}{}",
1014 offer.instance_id, offer.subchannel_index, suffix
1015 )
1016 }
1017}
1018
1019impl ChannelList {
1020 fn inspect(
1021 &self,
1022 notifier: &impl Notifier,
1023 version: Option<VersionInfo>,
1024 resp: &mut inspect::Response<'_>,
1025 ) {
1026 for (offer_id, channel) in self.iter() {
1027 resp.child(
1028 &channel_inspect_path(&channel.offer, format_args!("")),
1029 |req| {
1030 let mut resp = req.respond();
1031 channel.inspect_state(&mut resp);
1032
1033 resp.merge(inspect::adhoc(|req| {
1037 if !matches!(channel.state, ChannelState::Revoked) {
1038 notifier.inspect(version, offer_id, req);
1039 }
1040 }));
1041 },
1042 );
1043 }
1044 }
1045}
1046
/// The maximum number of channels; also the number of channel ID slots in
/// `AssignedChannels`.
pub const MAX_CHANNELS: usize = 2047;
1050
1051impl ChannelList {
1052 fn new() -> Self {
1053 Self {
1054 channels: Slab::new(),
1055 }
1056 }
1057
1058 fn len(&self) -> usize {
1060 self.channels.len()
1061 }
1062
1063 fn offer(&mut self, new_channel: Channel) -> OfferId {
1065 OfferId(self.channels.insert(new_channel))
1066 }
1067
1068 fn remove(&mut self, offer_id: OfferId) {
1070 let channel = self.channels.remove(offer_id.0);
1071 assert!(channel.info.is_none());
1072 }
1073
1074 fn get_by_channel_id_mut(
1076 &mut self,
1077 assigned_channels: &AssignedChannels,
1078 channel_id: ChannelId,
1079 ) -> Result<(OfferId, &mut Channel), ChannelError> {
1080 let offer_id = assigned_channels
1081 .get(channel_id)
1082 .ok_or(ChannelError::UnknownChannelId)?;
1083 let channel = &mut self[offer_id];
1084 if channel.state.is_released() {
1085 return Err(ChannelError::ChannelReleased);
1086 }
1087 assert_eq!(
1088 channel.info.as_ref().map(|info| info.channel_id),
1089 Some(channel_id)
1090 );
1091 Ok((offer_id, channel))
1092 }
1093
1094 fn get_by_channel_id(
1096 &self,
1097 assigned_channels: &AssignedChannels,
1098 channel_id: ChannelId,
1099 ) -> Result<(OfferId, &Channel), ChannelError> {
1100 let offer_id = assigned_channels
1101 .get(channel_id)
1102 .ok_or(ChannelError::UnknownChannelId)?;
1103 let channel = &self[offer_id];
1104 if channel.state.is_released() {
1105 return Err(ChannelError::ChannelReleased);
1106 }
1107 assert_eq!(
1108 channel.info.as_ref().map(|info| info.channel_id),
1109 Some(channel_id)
1110 );
1111 Ok((offer_id, channel))
1112 }
1113
1114 fn get_by_key_mut(&mut self, key: &OfferKey) -> Option<(OfferId, &mut Channel)> {
1117 for (offer_id, channel) in self.iter_mut() {
1118 if channel.offer.instance_id == key.instance_id
1119 && channel.offer.interface_id == key.interface_id
1120 && channel.offer.subchannel_index == key.subchannel_index
1121 {
1122 return Some((offer_id, channel));
1123 }
1124 }
1125 None
1126 }
1127
1128 fn iter(&self) -> impl Iterator<Item = (OfferId, &Channel)> {
1130 self.channels
1131 .iter()
1132 .map(|(id, channel)| (OfferId(id), channel))
1133 }
1134
1135 fn iter_mut(&mut self) -> impl Iterator<Item = (OfferId, &mut Channel)> {
1137 self.channels
1138 .iter_mut()
1139 .map(|(id, channel)| (OfferId(id), channel))
1140 }
1141
1142 fn retain<F>(&mut self, mut f: F)
1144 where
1145 F: FnMut(OfferId, &mut Channel) -> bool,
1146 {
1147 self.channels.retain(|id, channel| {
1148 let retain = f(OfferId(id), channel);
1149 if !retain {
1150 assert!(channel.info.is_none());
1151 }
1152 retain
1153 })
1154 }
1155}
1156
impl Index<OfferId> for ChannelList {
    type Output = Channel;

    /// Returns the channel stored under `offer_id`, using the slab's own
    /// indexing.
    fn index(&self, offer_id: OfferId) -> &Self::Output {
        &self.channels[offer_id.0]
    }
}
1164
impl IndexMut<OfferId> for ChannelList {
    /// Returns the channel stored under `offer_id` for mutation, using the
    /// slab's own indexing.
    fn index_mut(&mut self, offer_id: OfferId) -> &mut Self::Output {
        &mut self.channels[offer_id.0]
    }
}
1170
/// A GPADL and the data received for it so far.
#[derive(Debug, Inspect)]
struct Gpadl {
    /// Number of ranges, passed to `gparange::validate_gpa_ranges` once all
    /// data has arrived.
    count: u16,
    /// The received data as 64-bit words. The buffer's capacity is used as
    /// the expected final length.
    #[inspect(skip)]
    buf: Vec<u64>,
    state: GpadlState,
}
1179
/// The state of a GPADL.
///
/// NOTE(review): only the `InProgress` -> `Offered` transition is visible
/// here (in `Gpadl::append`); confirm the remaining transitions elsewhere.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Inspect)]
enum GpadlState {
    /// Data is still being appended.
    InProgress,
    /// All data has been received and validated.
    Offered,
    OfferedTearingDown,
    Accepted,
    TearingDown,
}
1194
1195impl Gpadl {
1196 fn new(count: u16, len: usize) -> Self {
1199 Self {
1200 state: GpadlState::InProgress,
1201 count,
1202 buf: Vec::with_capacity(len),
1203 }
1204 }
1205
1206 fn append(&mut self, data: &[u8]) -> Result<bool, ChannelError> {
1208 if self.state == GpadlState::InProgress {
1209 let buf = &mut self.buf;
1210 let len = min(data.len() & !7, (buf.capacity() - buf.len()) * 8);
1215 let data = &data[..len];
1216 let start = buf.len();
1217 buf.resize(buf.len() + data.len() / 8, 0);
1218 buf[start..].as_mut_bytes().copy_from_slice(data);
1219 Ok(if buf.len() == buf.capacity() {
1220 gparange::validate_gpa_ranges(self.count as usize, buf)
1221 .map_err(ChannelError::InvalidGpaRange)?;
1222 self.state = GpadlState::Offered;
1223 true
1224 } else {
1225 false
1226 })
1227 } else {
1228 Err(ChannelError::GpadlAlreadyComplete)
1229 }
1230 }
1231}
1232
/// The parameters passed along with an `Action::Open`, built by
/// `OpenParams::from_request`.
#[derive(Debug, Copy, Clone)]
pub struct OpenParams {
    pub open_data: OpenData,
    /// Same value as `open_data.connection_id`.
    pub connection_id: u32,
    /// Same value as `open_data.event_flag`.
    pub event_flag: u16,
    pub monitor_info: Option<MonitorInfo>,
    pub flags: protocol::OpenChannelFlags,
    pub reserved_target: Option<ConnectionTarget>,
    pub channel_id: ChannelId,
}
1244
1245impl OpenParams {
1246 fn from_request(
1247 info: &OfferedInfo,
1248 request: &OpenRequest,
1249 monitor_info: Option<MonitorInfo>,
1250 reserved_target: Option<ConnectionTarget>,
1251 ) -> Self {
1252 let (event_flag, connection_id) = if let Some(id) = request.guest_specified_interrupt_info {
1255 (id.event_flag, id.connection_id)
1256 } else {
1257 (info.channel_id.0 as u16, info.connection_id)
1258 };
1259
1260 Self {
1261 open_data: OpenData {
1262 target_vp: request.target_vp,
1263 ring_offset: request.downstream_ring_buffer_page_offset,
1264 ring_gpadl_id: request.ring_buffer_gpadl_id,
1265 user_data: request.user_data,
1266 event_flag,
1267 connection_id,
1268 },
1269 connection_id,
1270 event_flag,
1271 monitor_info,
1272 flags: request.flags.with_unused(0),
1273 reserved_target,
1274 channel_id: info.channel_id,
1275 }
1276 }
1277}
1278
/// An action on a channel, passed to `Notifier::notify` together with the
/// channel's offer ID.
///
/// NOTE(review): the code that issues these actions is outside this view.
#[derive(Debug)]
pub enum Action {
    Open(OpenParams, VersionInfo),
    Close,
    // NOTE(review): presumably the GPADL ID, its range count, and its
    // buffer, mirroring the `Gpadl` fields — confirm.
    Gpadl(GpadlId, u16, Vec<u64>),
    TeardownGpadl {
        gpadl_id: GpadlId,
        post_restore: bool,
    },
    Modify {
        target_vp: u32,
    },
}
1293
/// The protocol versions supported by the server.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::Win7,
    Version::Win8,
    Version::Win8_1,
    Version::Win10,
    Version::Win10Rs3_0,
    Version::Win10Rs3_1,
    Version::Win10Rs4,
    Version::Win10Rs5,
    Version::Iron,
    Version::Copper,
];
1308
/// The feature flags supported by the server.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true)
    .with_server_specified_monitor_pages(true);
1318
/// Callbacks through which the server reports actions and sends messages.
///
/// NOTE(review): the call sites are outside this view; the descriptions
/// below are based on the signatures and should be confirmed there.
pub trait Notifier: Send {
    /// Requests that `action` be performed for the channel `offer_id`.
    fn notify(&mut self, offer_id: OfferId, action: Action);

    /// Hands an initiate-contact request to the notifier.
    // NOTE(review): presumably used for requests this server does not handle
    // itself — confirm.
    fn forward_unhandled(&mut self, request: InitiateContactRequest);

    /// Asks the notifier to apply a connection modification.
    fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()>;

    /// Adds notifier-specific inspect data for a channel. The default
    /// implementation ignores its arguments and reports nothing.
    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
        let _ = (version, offer_id, req);
    }

    /// Sends `message` to `target`. The returned flag must not be ignored.
    // NOTE(review): presumably `true` means the message was sent — confirm.
    #[must_use]
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool;

    /// Passes on an hvsocket connect request.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest);

    /// Signals that a reset has completed.
    fn reset_complete(&mut self);

    /// Signals that an unload has completed.
    fn unload_complete(&mut self);
}
1353
1354impl Server {
1355 pub fn new(
1357 vtl: Vtl,
1358 child_connection_id: u32,
1359 channel_id_offset: u16,
1360 use_absolute_channel_order: bool,
1361 ) -> Self {
1362 Server {
1363 state: ConnectionState::Disconnected,
1364 channels: ChannelList::new(),
1365 assigned_channels: AssignedChannels::new(vtl, channel_id_offset),
1366 assigned_monitors: AssignedMonitors::new(),
1367 gpadls: Default::default(),
1368 incomplete_gpadls: Default::default(),
1369 child_connection_id,
1370 max_version: None,
1371 delayed_max_version: None,
1372 pending_messages: PendingMessages(VecDeque::new()),
1373 require_server_allocated_mnf: false,
1374 use_absolute_channel_order,
1375 }
1376 }
1377
1378 pub fn with_notifier<'a, T: Notifier>(
1380 &'a mut self,
1381 notifier: &'a mut T,
1382 ) -> ServerWithNotifier<'a, T> {
1383 self.validate();
1384 ServerWithNotifier {
1385 inner: self,
1386 notifier,
1387 }
1388 }
1389
1390 pub fn set_require_server_allocated_mnf(&mut self, require: bool) {
1393 self.require_server_allocated_mnf = require;
1394 }
1395
1396 fn validate(&self) {
1397 #[cfg(debug_assertions)]
1398 for (_, channel) in self.channels.iter() {
1399 let should_have_info = !channel.state.is_released();
1400 if channel.info.is_some() != should_have_info {
1401 panic!("channel invariant violation: {channel:?}");
1402 }
1403 }
1404 }
1405
1406 pub fn set_compatibility_version(&mut self, version: MaxVersionInfo, delay: bool) {
1408 if delay {
1409 self.delayed_max_version = Some(version)
1410 } else {
1411 tracing::info!(?version, "Limiting VmBus connections to version");
1412 self.max_version = Some(version);
1413 }
1414 }
1415
1416 pub fn channel_gpadls(&self, offer_id: OfferId) -> Vec<RestoredGpadl> {
1417 self.gpadls
1418 .iter()
1419 .filter_map(|(&(gpadl_id, gpadl_offer_id), gpadl)| {
1420 if offer_id != gpadl_offer_id {
1421 return None;
1422 }
1423 let accepted = match gpadl.state {
1424 GpadlState::Offered | GpadlState::OfferedTearingDown => false,
1425 GpadlState::Accepted => true,
1426 GpadlState::InProgress | GpadlState::TearingDown => return None,
1427 };
1428 Some(RestoredGpadl {
1429 request: GpadlRequest {
1430 id: gpadl_id,
1431 count: gpadl.count,
1432 buf: gpadl.buf.clone(),
1433 },
1434 accepted,
1435 })
1436 })
1437 .collect()
1438 }
1439
1440 pub fn get_version(&self) -> Option<VersionInfo> {
1441 self.state.get_version()
1442 }
1443
1444 pub fn get_restore_open_params(&self, offer_id: OfferId) -> Result<OpenParams, RestoreError> {
1445 let channel = &self.channels[offer_id];
1446
1447 match channel.restore_state {
1449 RestoreState::New => {
1450 return Err(RestoreError::MissingChannel(channel.offer.key()));
1454 }
1455 RestoreState::Restoring => {}
1456 RestoreState::Unmatched => unreachable!(),
1457 RestoreState::Restored => {
1458 return Err(RestoreError::AlreadyRestored(channel.offer.key()));
1459 }
1460 }
1461
1462 let info = channel
1463 .info
1464 .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;
1465
1466 let (request, reserved_state) = match channel.state {
1467 ChannelState::Closed => {
1468 return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
1469 }
1470 ChannelState::Closing { params, .. } | ChannelState::ClosingReopen { params, .. } => {
1471 (params, None)
1472 }
1473 ChannelState::Opening {
1474 request,
1475 reserved_state,
1476 } => (request, reserved_state),
1477 ChannelState::Open {
1478 params,
1479 reserved_state,
1480 ..
1481 } => (params, reserved_state),
1482 ChannelState::ClientReleased | ChannelState::Reoffered => {
1483 return Err(RestoreError::MissingChannel(channel.offer.key()));
1484 }
1485 ChannelState::Revoked
1486 | ChannelState::ClosingClientRelease
1487 | ChannelState::OpeningClientRelease => unreachable!(),
1488 };
1489
1490 Ok(OpenParams::from_request(
1491 &info,
1492 &request,
1493 channel.handled_monitor_info(),
1494 reserved_state.map(|state| state.target),
1495 ))
1496 }
1497
1498 pub fn has_pending_messages(&self) -> bool {
1500 !self.pending_messages.0.is_empty() && !self.state.is_paused()
1501 }
1502
1503 pub fn poll_flush_pending_messages(
1505 &mut self,
1506 mut send: impl FnMut(&OutgoingMessage) -> Poll<()>,
1507 ) -> Poll<()> {
1508 if !self.state.is_paused() {
1509 while let Some(message) = self.pending_messages.0.front() {
1510 ready!(send(message));
1511 self.pending_messages.0.pop_front();
1512 }
1513 }
1514
1515 Poll::Ready(())
1516 }
1517}
1518
1519impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
    /// Completes the restore of a single channel, reconciling the channel's stored state with
    /// whether the caller reports the channel as `open`.
    ///
    /// On success the channel's restore state becomes `Restored` (except for channels whose
    /// restore state is `New`, which are left untouched).
    ///
    /// # Errors
    ///
    /// * `MissingChannel` — the channel is not being restored but `open` is set, it has no
    ///   assigned info, or it is client-released (or reoffered while `open` is set).
    /// * `AlreadyRestored` — the channel was already restored.
    /// * `DuplicateMonitorId` — the channel's monitor ID could not be claimed.
    /// * `MismatchedOpenState` — `open` disagrees with a stored `Closed`/`Open` state.
    ///
    /// # Panics
    ///
    /// Panics if the restore state is `Unmatched`, if the channel state is `Revoked`,
    /// `ClosingClientRelease` or `OpeningClientRelease`, or if an open must be re-issued while
    /// the connection has no version.
    pub fn restore_channel(&mut self, offer_id: OfferId, open: bool) -> Result<(), RestoreError> {
        let channel = &mut self.inner.channels[offer_id];

        match channel.restore_state {
            // The channel is not part of a restore: an open channel cannot be matched, while a
            // closed one needs no work.
            RestoreState::New => {
                if open {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                } else {
                    return Ok(());
                }
            }
            RestoreState::Restoring => {}
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        // Claim the channel's monitor ID, if it has one; a failed claim is reported as a
        // duplicate.
        if let Some(monitor_info) = channel.handled_monitor_info() {
            if !self
                .inner
                .assigned_monitors
                .claim_monitor(monitor_info.monitor_id)
            {
                return Err(RestoreError::DuplicateMonitorId(monitor_info.monitor_id.0));
            }
        }

        if open {
            // The caller says the channel is open.
            match channel.state {
                ChannelState::Closed => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                // A close was in flight: re-issue the close notification.
                ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                    self.notifier.notify(offer_id, Action::Close);
                }
                // An open was in flight but the channel is already open: report success to the
                // guest and move to the open state.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused())
                        .send_open_result(
                            info.channel_id,
                            &request,
                            protocol::STATUS_SUCCESS,
                            MessageTarget::for_offer(offer_id, &reserved_state),
                        );
                    channel.state = ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    };
                }
                ChannelState::Open { .. } => {}
                ChannelState::ClientReleased | ChannelState::Reoffered => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            };
        } else {
            // The caller says the channel is closed.
            match channel.state {
                ChannelState::Closed => {}
                ChannelState::Reoffered => {}
                // The close that was in flight has effectively completed.
                ChannelState::Closing { .. } => {
                    channel.state = ChannelState::Closed;
                }
                // The close is done; issue the queued reopen (without a reserved target).
                ChannelState::ClosingReopen { request, .. } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                None,
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                    channel.state = ChannelState::Opening {
                        request,
                        reserved_state: None,
                    };
                }
                // The open was in flight: re-issue the open notification; the state stays
                // `Opening`.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                reserved_state.map(|state| state.target),
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                }
                ChannelState::Open { .. } => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                ChannelState::ClientReleased => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            }
        }

        channel.restore_state = RestoreState::Restored;
        Ok(())
    }
1658
    /// Finishes a restore pass over all channels and GPADLs.
    ///
    /// Channels whose restore never completed are revoked, channels that were not part of the
    /// restore are offered to an already-connected guest, and notifications for GPADLs in the
    /// `Offered` or `TearingDown` state are re-issued. Finally checks whether a pending
    /// disconnect can now proceed.
    ///
    /// # Panics
    ///
    /// Panics if revoking a `Restoring`/`Unmatched` channel reports that the channel should not
    /// be retained, or if a GPADL is in the `OfferedTearingDown` state.
    pub fn revoke_unclaimed_channels(&mut self) {
        for (offer_id, channel) in self.inner.channels.iter_mut() {
            match channel.restore_state {
                RestoreState::Restored => {
                    // Fully restored; nothing to do.
                }
                RestoreState::New => {
                    // Not part of the restore. If the guest is connected and offers have
                    // already been sent, a released channel is prepared and offered now.
                    if let ConnectionState::Connected(info) = &self.inner.state {
                        if info.offers_sent && matches!(channel.state, ChannelState::ClientReleased)
                        {
                            channel.prepare_channel(
                                offer_id,
                                &mut self.inner.assigned_channels,
                                &mut self.inner.assigned_monitors,
                            );
                            channel.state = ChannelState::Closed;
                            self.inner
                                .pending_messages
                                .sender(self.notifier, self.inner.state.is_paused())
                                .send_offer(channel, info);
                        }
                    }
                }
                RestoreState::Restoring => {
                    // An offer was matched but `restore_channel` never completed for it:
                    // revoke the channel and mark it to be reoffered.
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                    channel.state = ChannelState::Reoffered;
                }
                RestoreState::Unmatched => {
                    // No offer was ever matched to this channel: revoke it.
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                }
            }
        }

        // Re-issue notifications for GPADL operations that were still outstanding.
        for (&(gpadl_id, offer_id), gpadl) in self.inner.gpadls.iter_mut() {
            match gpadl.state {
                GpadlState::InProgress | GpadlState::Accepted => {}
                GpadlState::Offered => {
                    self.notifier.notify(
                        offer_id,
                        Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
                    );
                }
                GpadlState::TearingDown => {
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: true,
                        },
                    );
                }
                GpadlState::OfferedTearingDown => unreachable!(),
            }
        }

        self.check_disconnected();
    }
1744
1745 pub fn reset(&mut self) {
1750 assert!(!self.is_resetting());
1751 if self.request_disconnect(ConnectionAction::Reset) {
1752 self.complete_reset();
1753 }
1754 }
1755
1756 fn complete_reset(&mut self) {
1757 for (_, channel) in self.inner.channels.iter_mut() {
1759 channel.restore_state = RestoreState::New;
1760 }
1761 self.inner.pending_messages.0.clear();
1762 self.notifier.reset_complete();
1763 }
1764
    /// Registers a channel offer and returns its `OfferId`.
    ///
    /// If a channel with the same key already exists, the offer replaces it only when that
    /// channel is an unmatched restored channel or has been revoked. Otherwise a new channel
    /// is created, and if the guest is connected and offers were already sent, the new channel
    /// is prepared and offered immediately.
    ///
    /// # Errors
    ///
    /// * `AlreadyExists` — a channel with this key exists and is neither unmatched nor revoked.
    /// * `MismatchedMonitorId` — a relayed monitor ID differs from the one assigned to the
    ///   matched channel.
    /// * `TooManyChannels` — the channel count has reached the allowable maximum.
    ///
    /// # Panics
    ///
    /// Panics if the existing channel has no assigned info, or if it is both unmatched and
    /// revoked.
    pub fn offer_channel(&mut self, offer: OfferParamsInternal) -> Result<OfferId, OfferError> {
        if let Some((offer_id, channel)) = self.inner.channels.get_by_key_mut(&offer.key()) {
            // Only an unmatched restored channel or a revoked channel may be replaced.
            if channel.restore_state != RestoreState::Unmatched
                && !matches!(channel.state, ChannelState::Revoked)
            {
                return Err(OfferError::AlreadyExists(offer.key()));
            }

            let info = channel.info.expect("assigned");
            if channel.restore_state == RestoreState::Unmatched {
                tracing::debug!(
                    offer_id = offer_id.0,
                    key = %channel.offer.key(),
                    "matched channel"
                );

                assert!(!matches!(channel.state, ChannelState::Revoked));
                // The offer now matches this channel; it still needs `restore_channel`.
                channel.restore_state = RestoreState::Restoring;

                // A relayed monitor ID must agree with the ID already assigned to the channel.
                // Note the restore state has already been updated when this error is returned.
                if let MnfUsage::Relayed { monitor_id } = offer.use_mnf {
                    if info.monitor_id != Some(MonitorId(monitor_id)) {
                        return Err(OfferError::MismatchedMonitorId(
                            info.monitor_id,
                            MonitorId(monitor_id),
                        ));
                    }
                }
            } else {
                // The channel was revoked: keep the new offer and mark the channel for reoffer.
                channel.state = ChannelState::Reoffered;
                tracing::info!(?offer_id, key = %channel.offer.key(), "channel marked for reoffer");
            }

            channel.offer = offer;
            return Ok(offer_id);
        }

        // New channel: it starts `Closed` only when the guest is connected and offers have
        // already been sent; in every other case it starts `ClientReleased`.
        let mut connected_info = None;
        let state = match &self.inner.state {
            ConnectionState::Connected(info) => {
                if info.offers_sent {
                    connected_info = Some(info);
                    ChannelState::Closed
                } else {
                    ChannelState::ClientReleased
                }
            }
            ConnectionState::Connecting { .. }
            | ConnectionState::Disconnecting { .. }
            | ConnectionState::Disconnected => ChannelState::ClientReleased,
        };

        if self.inner.channels.len() >= self.inner.assigned_channels.allowable_channel_count() {
            return Err(OfferError::TooManyChannels);
        }

        // Capture the values needed for logging before `offer` is moved into the channel.
        let key = offer.key();
        let confidential_ring_buffer = offer.flags.confidential_ring_buffer();
        let confidential_external_memory = offer.flags.confidential_external_memory();
        let channel = Channel {
            info: None,
            offer,
            state,
            restore_state: RestoreState::New,
        };

        let offer_id = self.inner.channels.offer(channel);
        if let Some(info) = connected_info {
            // Offers were already sent to the guest, so prepare and send this one right away.
            let channel = &mut self.inner.channels[offer_id];
            channel.prepare_channel(
                offer_id,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
            );

            self.inner
                .pending_messages
                .sender(self.notifier, self.inner.state.is_paused())
                .send_offer(channel, info);
        }

        tracing::info!(?offer_id, %key, confidential_ring_buffer, confidential_external_memory, "new channel");
        Ok(offer_id)
    }
1862
1863 pub fn revoke_channel(&mut self, offer_id: OfferId) {
1865 let channel = &mut self.inner.channels[offer_id];
1866 let retain = revoke(
1867 self.inner
1868 .pending_messages
1869 .sender(self.notifier, self.inner.state.is_paused()),
1870 offer_id,
1871 channel,
1872 &mut self.inner.gpadls,
1873 );
1874 if !retain {
1875 self.inner.channels.remove(offer_id);
1876 }
1877
1878 self.check_disconnected();
1879 }
1880
    /// Handles completion of an open for the channel identified by `offer_id`.
    ///
    /// `result` is treated as success when it is non-negative. For a channel in the `Opening`
    /// state the open result is sent to the guest and the channel becomes `Open` on success or
    /// `Closed` on failure. For a channel whose client released it while opening, a successful
    /// open is immediately followed by a close request. In any other state the completion is
    /// logged as invalid and otherwise ignored.
    ///
    /// # Panics
    ///
    /// Panics if an `Opening` channel has no assigned info.
    pub fn open_complete(&mut self, offer_id: OfferId, result: i32) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::debug!(offer_id = offer_id.0, key = %channel.offer.key(), result, "open complete");

        match channel.state {
            ChannelState::Opening {
                request,
                reserved_state,
            } => {
                let channel_id = channel.info.expect("assigned").channel_id;
                if result >= 0 {
                    tracelimit::info_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "opened channel"
                    );
                } else {
                    tracelimit::error_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "failed to open channel"
                    );
                }

                // Report the outcome to the guest, targeting the reserved-channel target when
                // the channel has one.
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_open_result(
                        channel_id,
                        &request,
                        result,
                        MessageTarget::for_offer(offer_id, &reserved_state),
                    );
                channel.state = if result >= 0 {
                    ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    }
                } else {
                    ChannelState::Closed
                };
            }
            ChannelState::OpeningClientRelease => {
                tracing::info!(
                    offer_id = offer_id.0,
                    key = %channel.offer.key(),
                    result,
                    "opened channel (client released)"
                );

                if result >= 0 {
                    // The open succeeded but the client no longer wants the channel: close it.
                    channel.state = ChannelState::ClosingClientRelease;
                    self.notifier.notify(offer_id, Action::Close);
                } else {
                    // The open failed, so the channel is fully released; a pending disconnect
                    // may now be able to proceed.
                    channel.state = ChannelState::ClientReleased;
                    self.check_disconnected();
                }
            }

            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::ClosingClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid open complete")
            }
        }
    }
1959
1960 fn are_channels_reset(&self, include_reserved: bool) -> bool {
1963 self.inner.gpadls.keys().all(|(_, offer_id)| {
1964 !include_reserved && self.inner.channels[*offer_id].state.is_reserved()
1965 }) && self.inner.channels.iter().all(|(_, channel)| {
1966 matches!(channel.state, ChannelState::ClientReleased)
1967 || (!include_reserved && channel.state.is_reserved())
1968 })
1969 }
1970
1971 fn check_disconnected(&mut self) {
1975 match self.inner.state {
1976 ConnectionState::Disconnecting {
1977 next_action,
1978 modify_sent: false,
1979 } => {
1980 if self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)) {
1981 self.notify_disconnect(next_action);
1982 }
1983 }
1984 ConnectionState::Disconnecting {
1985 modify_sent: true, ..
1986 }
1987 | ConnectionState::Disconnected
1988 | ConnectionState::Connected { .. }
1989 | ConnectionState::Connecting { .. } => (),
1990 }
1991 }
1992
1993 fn notify_disconnect(&mut self, next_action: ConnectionAction) {
1995 debug_assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
1997 self.inner.state = ConnectionState::Disconnecting {
1998 next_action,
1999 modify_sent: true,
2000 };
2001
2002 self.notifier
2004 .modify_connection(ModifyConnectionRequest {
2005 monitor_page: Update::Reset,
2006 interrupt_page: Update::Reset,
2007 ..Default::default()
2008 })
2009 .expect("resetting state should not fail");
2010 }
2011
2012 fn is_resetting(&self) -> bool {
2015 matches!(
2016 &self.inner.state,
2017 ConnectionState::Connecting {
2018 next_action: ConnectionAction::Reset,
2019 ..
2020 } | ConnectionState::Disconnecting {
2021 next_action: ConnectionAction::Reset,
2022 ..
2023 }
2024 )
2025 }
2026
    /// Handles completion of a close for the channel identified by `offer_id`.
    ///
    /// * `Closing` with a reserved state: the channel becomes `Closed`; while connected, a
    ///   close-reserved-channel response is sent to the reserved target, otherwise the channel
    ///   is released (and removed if the release says so).
    /// * `Closing` without a reserved state: the channel becomes `Closed`.
    /// * `ClosingClientRelease`: the channel becomes `ClientReleased` and a pending disconnect
    ///   is re-checked.
    /// * `ClosingReopen`: the channel becomes `Closed` and the queued open is started.
    /// * Any other state: the completion is logged as invalid and otherwise ignored.
    ///
    /// # Panics
    ///
    /// Panics if a reserved channel closed while connected has no assigned info.
    pub fn close_complete(&mut self, offer_id: OfferId) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::info!(offer_id = offer_id.0, key = %channel.offer.key(), "closed channel");
        match channel.state {
            ChannelState::Closing {
                reserved_state: Some(reserved_state),
                ..
            } => {
                channel.state = ChannelState::Closed;
                if matches!(self.inner.state, ConnectionState::Connected { .. }) {
                    let channel_id = channel.info.expect("assigned").channel_id;
                    self.send_close_reserved_channel_response(
                        channel_id,
                        offer_id,
                        reserved_state.target,
                    );
                } else {
                    // Not connected: no response is sent; release the channel instead and
                    // drop it from the list if the release reports it can be removed.
                    if Self::client_release_channel(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                        &mut self.inner.incomplete_gpadls,
                        &mut self.inner.assigned_channels,
                        &mut self.inner.assigned_monitors,
                        None,
                    ) {
                        self.inner.channels.remove(offer_id);
                    }
                }
            }
            ChannelState::Closing { .. } => {
                channel.state = ChannelState::Closed;
            }
            ChannelState::ClosingClientRelease => {
                channel.state = ChannelState::ClientReleased;
                self.check_disconnected();
            }
            ChannelState::ClosingReopen { request, .. } => {
                // The close finished; start the open that was queued behind it.
                channel.state = ChannelState::Closed;
                self.open_channel(offer_id, &request, None);
            }

            ChannelState::Closed
            | ChannelState::ClientReleased
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::OpeningClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid close complete")
            }
        }
    }
2086
2087 fn send_close_reserved_channel_response(
2088 &mut self,
2089 channel_id: ChannelId,
2090 offer_id: OfferId,
2091 target: ConnectionTarget,
2092 ) {
2093 self.sender().send_message_with_target(
2094 &protocol::CloseReservedChannelResponse { channel_id },
2095 MessageTarget::ReservedChannel(offer_id, target),
2096 );
2097 }
2098
2099 fn handle_initiate_contact(
2102 &mut self,
2103 input: &protocol::InitiateContact2,
2104 message: &SynicMessage,
2105 includes_client_id: bool,
2106 ) -> Result<(), ChannelError> {
2107 let target_info =
2108 protocol::TargetInfo::from(input.initiate_contact.interrupt_page_or_target_info);
2109
2110 let target_sint = if message.multiclient
2111 && input.initiate_contact.version_requested >= Version::Win10Rs3_1 as u32
2112 {
2113 target_info.sint()
2114 } else {
2115 VMBUS_SINT
2116 };
2117
2118 let target_vtl = if message.multiclient
2119 && input.initiate_contact.version_requested >= Version::Win10Rs4 as u32
2120 {
2121 target_info.vtl()
2122 } else {
2123 0
2124 };
2125
2126 let feature_flags = if input.initiate_contact.version_requested >= Version::Copper as u32 {
2127 target_info.feature_flags()
2128 } else {
2129 0
2130 };
2131
2132 let target_message_vp =
2137 if input.initiate_contact.version_requested >= Version::Win8_1 as u32 {
2138 input.initiate_contact.target_message_vp
2139 } else {
2140 0
2141 };
2142
2143 let interrupt_page = (input.initiate_contact.version_requested < Version::Win8 as u32
2150 && input.initiate_contact.interrupt_page_or_target_info != 0)
2151 .then_some(input.initiate_contact.interrupt_page_or_target_info);
2152
2153 let monitor_page = if (input.initiate_contact.parent_to_child_monitor_page_gpa == 0)
2156 != (input.initiate_contact.child_to_parent_monitor_page_gpa == 0)
2157 {
2158 MonitorPageRequest::Invalid
2159 } else if input.initiate_contact.parent_to_child_monitor_page_gpa != 0 {
2160 MonitorPageRequest::Some(MonitorPageGpas {
2161 parent_to_child: input.initiate_contact.parent_to_child_monitor_page_gpa,
2162 child_to_parent: input.initiate_contact.child_to_parent_monitor_page_gpa,
2163 })
2164 } else {
2165 MonitorPageRequest::None
2166 };
2167
2168 let client_id = if FeatureFlags::from(feature_flags).client_id() {
2171 if includes_client_id {
2172 input.client_id
2173 } else {
2174 return Err(ChannelError::ParseError(
2175 protocol::ParseError::MessageTooSmall(Some(
2176 protocol::MessageType::INITIATE_CONTACT,
2177 )),
2178 ));
2179 }
2180 } else {
2181 Guid::ZERO
2182 };
2183
2184 let request = InitiateContactRequest {
2185 version_requested: input.initiate_contact.version_requested,
2186 target_message_vp,
2187 monitor_page,
2188 target_sint,
2189 target_vtl,
2190 feature_flags,
2191 interrupt_page,
2192 client_id,
2193 trusted: message.trusted,
2194 };
2195 self.initiate_contact(request);
2196 Ok(())
2197 }
2198
    /// Processes an initiate-contact request from the guest.
    ///
    /// The steps, in order:
    ///
    /// 1. A request targeting a different VTL is handed to the notifier to forward.
    /// 2. A request targeting a SINT other than `VMBUS_SINT` gets an "unsupported" version
    ///    response sent to the requested VP/SINT.
    /// 3. Any existing connection is disconnected; if that does not complete synchronously,
    ///    the request is queued as a `Reconnect` action and this function returns.
    /// 4. The requested version and feature flags are checked; unsupported versions get an
    ///    "unsupported" response.
    /// 5. The monitor page request is validated, the state moves to `Connecting`, and the
    ///    notifier is asked to apply the connection. The version response for the successful
    ///    case is sent later by `complete_initiate_contact`.
    pub fn initiate_contact(&mut self, request: InitiateContactRequest) {
        let vtl = self.inner.assigned_channels.vtl as u8;
        if request.target_vtl != vtl {
            // Not for this server's VTL: let the notifier forward it.
            self.notifier.forward_unhandled(request);
            return;
        }

        if request.target_sint != VMBUS_SINT {
            tracelimit::warn_ratelimited!(
                target_vtl = request.target_vtl,
                target_sint = request.target_sint,
                version = request.version_requested,
                "unsupported multiclient request",
            );

            // Reply "unsupported" to the VP and SINT the guest asked for.
            self.send_version_response_with_target(
                None,
                MessageTarget::Custom(ConnectionTarget {
                    vp: request.target_message_vp,
                    sint: request.target_sint,
                }),
            );

            return;
        }

        // Tear down any existing connection first. If that cannot finish now, the request has
        // been recorded as the next action and will be replayed once the disconnect completes.
        if !self.request_disconnect(ConnectionAction::Reconnect {
            initiate_contact: request,
        }) {
            return;
        }

        let Some(version) = self.check_version_supported(&request) else {
            tracelimit::warn_ratelimited!(
                vtl,
                version = request.version_requested,
                client_id = ?request.client_id,
                "Guest requested unsupported version"
            );

            self.send_version_response(None);
            return;
        };

        tracelimit::info_ratelimited!(
            vtl,
            ?version,
            client_id = ?request.client_id,
            trusted = request.trusted,
            "Guest negotiated version"
        );

        let monitor_page = match request.monitor_page {
            MonitorPageRequest::Some(mp) => {
                if self.inner.require_server_allocated_mnf {
                    // Guest-supplied pages are not accepted in this mode. The warning is only
                    // emitted when server-specified pages were not negotiated either.
                    if !version.feature_flags.server_specified_monitor_pages() {
                        tracelimit::warn_ratelimited!(
                            "guest-supplied monitor pages not supported; MNF will be disabled"
                        );
                    }

                    None
                } else {
                    Some(mp)
                }
            }
            MonitorPageRequest::None => None,
            MonitorPageRequest::Invalid => {
                // Only one of the two monitor pages was supplied: fail the connection.
                self.send_version_response(Some(VersionResponseData::new(
                    version,
                    protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
                )));

                return;
            }
        };

        self.inner.state = ConnectionState::Connecting {
            info: ConnectionInfo {
                version,
                trusted: request.trusted,
                interrupt_page: request.interrupt_page,
                monitor_page: monitor_page.map(MonitorPageGpaInfo::from_guest_gpas),
                target_message_vp: request.target_message_vp,
                modifying: false,
                offers_sent: false,
                client_id: request.client_id,
                paused: false,
            },
            next_action: ConnectionAction::None,
        };

        // Ask the notifier to apply the new connection settings. On failure, fall back to the
        // disconnected state and report the failure to the guest.
        if let Err(err) = self.notifier.modify_connection(ModifyConnectionRequest {
            version: Some(version),
            monitor_page: monitor_page.into(),
            interrupt_page: request.interrupt_page.into(),
            target_message_vp: Some(request.target_message_vp),
            notify_relay: true,
        }) {
            tracelimit::error_ratelimited!(?err, "server failed to change state");
            self.inner.state = ConnectionState::Disconnected;
            self.send_version_response(Some(VersionResponseData::new(
                version,
                protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
            )));
        }
    }
2316
    /// Completes an initiate-contact once the connection modification has been answered.
    ///
    /// On a successful response the negotiated feature flags are narrowed to those reported in
    /// the response plus the locally handled ones, the state becomes `Connected`, and a
    /// successful version response is sent. On a failed or unsupported response the guest is
    /// told so and the state returns to `Disconnected`. If another action was queued while
    /// connecting, it is started afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the state is not `Connecting`, if the response is `Modified`, or if a
    /// server-specified monitor page is returned although that feature was not negotiated.
    pub(crate) fn complete_initiate_contact(&mut self, response: ModifyConnectionResponse) {
        let ConnectionState::Connecting {
            mut info,
            next_action,
        } = self.inner.state
        else {
            panic!("Invalid state for completing InitiateContact.");
        };

        // Feature flags that stay enabled regardless of what the response reports.
        const LOCAL_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
            .with_client_id(true)
            .with_confidential_channels(true);

        let (relay_feature_flags, server_specified_monitor_page) = match response {
            // Success: continue below with the reported flags and monitor page.
            ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                feature_flags,
                server_specified_monitor_page,
            ) => (feature_flags, server_specified_monitor_page),
            // The version is supported but the request failed: pass the failure state on to
            // the guest.
            ModifyConnectionResponse::Supported(
                connection_state,
                feature_flags,
                server_specified_monitor_page,
            ) => {
                tracelimit::error_ratelimited!(
                    ?connection_state,
                    "initiate contact failed because relay request failed"
                );

                // The failure response still carries feature flags, so narrow them first.
                info.version.feature_flags &= (feature_flags | LOCAL_FEATURE_FLAGS)
                    .with_server_specified_monitor_pages(server_specified_monitor_page.is_some());

                self.send_version_response(Some(VersionResponseData::new(
                    info.version,
                    connection_state,
                )));
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            ModifyConnectionResponse::Unsupported => {
                self.send_version_response(None);
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            ModifyConnectionResponse::Modified(_) => {
                panic!("Invalid response for completing InitiateContact.");
            }
        };

        // A server-specified monitor page is only valid if the feature was negotiated.
        assert!(
            info.version.feature_flags.server_specified_monitor_pages()
                || server_specified_monitor_page.is_none()
        );

        info.version.feature_flags &= relay_feature_flags | LOCAL_FEATURE_FLAGS;

        // The server-specified monitor page flag reflects whether pages were actually returned.
        if let Some(gpas) = server_specified_monitor_page {
            info.monitor_page = Some(MonitorPageGpaInfo::from_server_gpas(gpas));
            info.version
                .feature_flags
                .set_server_specified_monitor_pages(true);
        } else {
            info.version
                .feature_flags
                .set_server_specified_monitor_pages(false);
        }

        let version = info.version;
        self.inner.state = ConnectionState::Connected(info);

        self.send_version_response(Some(
            VersionResponseData::new(version, protocol::ConnectionState::SUCCESSFUL)
                .with_monitor_pages(server_specified_monitor_page),
        ));
        // Run any action that was queued while the connection was being established.
        if !matches!(next_action, ConnectionAction::None) && self.request_disconnect(next_action) {
            self.do_next_action(next_action);
        }
    }
2411
2412 fn check_version_supported(&self, request: &InitiateContactRequest) -> Option<VersionInfo> {
2414 let version = SUPPORTED_VERSIONS
2415 .iter()
2416 .find(|v| request.version_requested == **v as u32)
2417 .copied()?;
2418
2419 if let Some(max_version) = self.inner.max_version {
2421 if version as u32 > max_version.version {
2422 return None;
2423 }
2424 }
2425
2426 let supported_flags = if version >= Version::Copper {
2427 let max_supported_flags =
2429 SUPPORTED_FEATURE_FLAGS.with_confidential_channels(request.trusted);
2430
2431 if let Some(max_version) = self.inner.max_version {
2433 max_supported_flags & max_version.feature_flags
2434 } else {
2435 max_supported_flags
2436 }
2437 } else {
2438 FeatureFlags::new()
2439 };
2440
2441 let feature_flags = supported_flags & request.feature_flags.into();
2442
2443 assert!(version >= Version::Copper || feature_flags == FeatureFlags::new());
2444 if feature_flags.into_bits() != request.feature_flags {
2445 tracelimit::warn_ratelimited!(
2446 supported = feature_flags.into_bits(),
2447 requested = request.feature_flags,
2448 "Guest requested unsupported feature flags."
2449 );
2450 }
2451
2452 Some(VersionInfo {
2453 version,
2454 feature_flags,
2455 })
2456 }
2457
2458 fn send_version_response(&mut self, data: Option<VersionResponseData>) {
2459 self.send_version_response_with_target(data, MessageTarget::Default);
2460 }
2461
    /// Builds and sends a version response to `target`.
    ///
    /// With `data == None` an all-zero response is sent, i.e. `version_supported` is `0`.
    /// Otherwise the response is filled in according to the negotiated version, and the
    /// message layout sent depends on that version:
    ///
    /// * before `Copper`: the base version response;
    /// * `Copper` or later: the response that also carries supported feature flags;
    /// * `Copper` or later with monitor pages in `data`: the response that additionally
    ///   carries the monitor page GPAs.
    ///
    /// # Panics
    ///
    /// Panics if `data` carries monitor pages while the server-specified monitor pages feature
    /// flag is not set.
    fn send_version_response_with_target(
        &mut self,
        data: Option<VersionResponseData>,
        target: MessageTarget,
    ) {
        enum VersionResponseType {
            PreCopper,
            Copper,
            CopperWithServerMnf,
        }

        // The three layouts are nested, so one zeroed instance of the largest is filled in
        // through views of its inner parts and the chosen layout is sent at the end.
        let mut response_copper_with_mnf = protocol::VersionResponse3::new_zeroed();
        let response_copper = &mut response_copper_with_mnf.version_response2;
        let response = &mut response_copper.version_response;
        let mut response_type = VersionResponseType::PreCopper;
        if let Some(data) = data {
            // A non-successful state is only reported for `Win8` and later; for older versions
            // the response is left zeroed, which reads as "unsupported".
            if data.state == protocol::ConnectionState::SUCCESSFUL
                || data.version.version >= Version::Win8
            {
                response.version_supported = 1;
                response.connection_state = data.state;
                // From `Win10Rs3_1` on this field carries the child connection ID; before
                // that it carries the selected version.
                response.selected_version_or_connection_id =
                    if data.version.version >= Version::Win10Rs3_1 {
                        self.inner.child_connection_id
                    } else {
                        data.version.version as u32
                    };

                if data.version.version >= Version::Copper {
                    response_copper.supported_features = data.version.feature_flags.into();
                    response_type = VersionResponseType::Copper;
                    if let Some(monitor_page) = data.monitor_pages {
                        assert!(data.version.feature_flags.server_specified_monitor_pages());
                        response_copper_with_mnf.child_to_parent_monitor_page_gpa =
                            monitor_page.child_to_parent;
                        response_copper_with_mnf.parent_to_child_monitor_page_gpa =
                            monitor_page.parent_to_child;
                        response_type = VersionResponseType::CopperWithServerMnf;
                    }
                }
            }
        }

        match response_type {
            VersionResponseType::PreCopper => {
                self.sender().send_message_with_target(response, target)
            }
            VersionResponseType::Copper => self
                .sender()
                .send_message_with_target(response_copper, target),
            VersionResponseType::CopperWithServerMnf => self
                .sender()
                .send_message_with_target(&response_copper_with_mnf, target),
        }
    }
2520
    /// Begins disconnecting from the guest, recording `new_action` to run once the disconnect
    /// completes.
    ///
    /// Returns `true` if the server is in the `Disconnected` state when this function returns,
    /// in which case the caller runs the action itself. Returns `false` if the disconnect is
    /// still pending or a connection attempt is in flight.
    ///
    /// # Panics
    ///
    /// Panics if a reset is already in progress, or if the server is disconnected, the action
    /// is not a reset, and some non-reserved channel or GPADL is still in use.
    fn request_disconnect(&mut self, new_action: ConnectionAction) -> bool {
        assert!(!self.is_resetting());

        let gpadls = &mut self.inner.gpadls;
        let vm_reset = matches!(new_action, ConnectionAction::Reset);
        // Release every channel, dropping those for which the release reports removal.
        // Reserved channels are kept untouched unless this is a reset.
        self.inner.channels.retain(|offer_id, channel| {
            (!vm_reset && channel.state.is_reserved())
                || !Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    gpadls,
                    &mut self.inner.incomplete_gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    None,
                )
        });

        match &mut self.inner.state {
            ConnectionState::Disconnected => {
                if vm_reset {
                    // A reset while disconnected must still wait if channels (including
                    // reserved ones) are not yet reset.
                    if !self.are_channels_reset(true) {
                        self.inner.state = ConnectionState::Disconnecting {
                            next_action: ConnectionAction::Reset,
                            modify_sent: false,
                        };
                    }
                } else {
                    assert!(self.are_channels_reset(false));
                }
            }

            ConnectionState::Connected { .. } => {
                // Send the connection reset now if nothing is outstanding; otherwise wait for
                // the remaining channels to be released.
                if self.are_channels_reset(vm_reset) {
                    self.notify_disconnect(new_action);
                } else {
                    self.inner.state = ConnectionState::Disconnecting {
                        next_action: new_action,
                        modify_sent: false,
                    };
                }
            }

            // A connect or disconnect is already in flight: replace its queued action.
            ConnectionState::Connecting { next_action, .. }
            | ConnectionState::Disconnecting { next_action, .. } => {
                *next_action = new_action;
            }
        }

        matches!(self.inner.state, ConnectionState::Disconnected)
    }
2583
2584 pub(crate) fn complete_disconnect(&mut self) {
2585 if let ConnectionState::Disconnecting {
2586 next_action,
2587 modify_sent,
2588 } = std::mem::replace(&mut self.inner.state, ConnectionState::Disconnected)
2589 {
2590 assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
2591 if !modify_sent {
2592 tracelimit::warn_ratelimited!("unexpected modify response");
2593 }
2594
2595 self.inner.state = ConnectionState::Disconnected;
2596 self.do_next_action(next_action);
2597 } else {
2598 unreachable!("not ready for disconnect");
2599 }
2600 }
2601
2602 fn do_next_action(&mut self, action: ConnectionAction) {
2603 match action {
2604 ConnectionAction::None => {}
2605 ConnectionAction::Reset => {
2606 self.complete_reset();
2607 }
2608 ConnectionAction::SendUnloadComplete => {
2609 self.complete_unload();
2610 }
2611 ConnectionAction::Reconnect { initiate_contact } => {
2612 self.initiate_contact(initiate_contact);
2613 }
2614 ConnectionAction::SendFailedVersionResponse => {
2615 self.send_version_response(None);
2618 }
2619 }
2620 }
2621
2622 fn handle_unload(&mut self) {
2624 tracing::debug!(
2625 vtl = self.inner.assigned_channels.vtl as u8,
2626 state = ?self.inner.state,
2627 "VmBus received unload request from guest",
2628 );
2629
2630 if self.request_disconnect(ConnectionAction::SendUnloadComplete) {
2631 self.complete_unload();
2632 }
2633 }
2634
2635 fn complete_unload(&mut self) {
2636 self.notifier.unload_complete();
2637 if let Some(version) = self.inner.delayed_max_version.take() {
2638 self.inner.set_compatibility_version(version, false);
2639 }
2640
2641 self.sender().send_message(&protocol::UnloadComplete {});
2642 tracelimit::info_ratelimited!("Vmbus disconnected");
2643 }
2644
    /// Handles a request-offers message: sends an offer for every non-reserved channel, followed
    /// by `AllOffersDelivered`.
    ///
    /// Offers are sent in sorted order. With `use_absolute_channel_order` set, channels are
    /// ordered by offer order first, then interface ID, then instance ID. Otherwise they are
    /// ordered by interface ID, then offer order, then instance ID. A channel with no offer
    /// order sorts as `u64::MAX`.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::OffersAlreadySent`] if offers were already sent on this
    /// connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection is not in the `Connected` state, or if a channel being offered
    /// is not in the `ClientReleased` state.
    fn handle_request_offers(&mut self) -> Result<(), ChannelError> {
        let ConnectionState::Connected(info) = &mut self.inner.state else {
            unreachable!(
                "in unexpected state {:?}, should be prevented by Message::parse()",
                self.inner.state
            );
        };

        if info.offers_sent {
            return Err(ChannelError::OffersAlreadySent);
        }

        info.offers_sent = true;

        // Reserved channels are excluded from the offers.
        let mut sorted_channels: Vec<_> = self
            .inner
            .channels
            .iter_mut()
            .filter(|(_, channel)| !channel.state.is_reserved())
            .collect();

        if self.inner.use_absolute_channel_order {
            sorted_channels.sort_unstable_by_key(|(_, channel)| {
                (
                    channel.offer.offer_order.unwrap_or(u64::MAX),
                    channel.offer.interface_id,
                    channel.offer.instance_id,
                )
            });
        } else {
            sorted_channels.sort_unstable_by_key(|(_, channel)| {
                (
                    channel.offer.interface_id,
                    channel.offer.offer_order.unwrap_or(u64::MAX),
                    channel.offer.instance_id,
                )
            });
        }

        for (offer_id, channel) in sorted_channels {
            assert!(matches!(channel.state, ChannelState::ClientReleased));

            channel.prepare_channel(
                offer_id,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
            );

            // The channel becomes `Closed` before its offer is sent.
            channel.state = ChannelState::Closed;
            self.inner
                .pending_messages
                .sender(self.notifier, info.paused)
                .send_offer(channel, info);
        }
        self.sender().send_message(&protocol::AllOffersDelivered {});

        Ok(())
    }
2706
2707 #[must_use]
2710 fn gpadl_updated(
2711 mut sender: MessageSender<'_, N>,
2712 offer_id: OfferId,
2713 channel: &Channel,
2714 gpadl_id: GpadlId,
2715 gpadl: &Gpadl,
2716 ) -> bool {
2717 if channel.state.is_revoked() {
2718 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
2719 sender.send_gpadl_created(channel_id, gpadl_id, protocol::STATUS_UNSUCCESSFUL);
2720 false
2721 } else {
2722 sender.notifier.notify(
2724 offer_id,
2725 Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
2726 );
2727 true
2728 }
2729 }
2730
    /// Processes a GPADL header message, creating a new GPADL entry for the target channel.
    ///
    /// `range` holds the part of the GPADL description carried by the header. If it completes
    /// the GPADL, `gpadl_updated` is called. Otherwise the GPADL ID is recorded in
    /// `incomplete_gpadls` so later body messages can find it.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID lookup fails, the channel is reserved, appending
    /// `range` fails, or the GPADL ID is already in use, either for this channel or as an
    /// in-progress GPADL. The caller is responsible for reporting the failure to the guest.
    fn handle_gpadl_header_core(
        &mut self,
        input: &protocol::GpadlHeader,
        range: &[u8],
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        if channel.state.is_reserved() {
            return Err(ChannelError::ChannelReserved);
        }

        // NOTE(review): `len / 8` presumably converts a byte length into a count of 8-byte
        // entries; confirm against `Gpadl::new`.
        let mut gpadl = Gpadl::new(input.count, input.len as usize / 8);
        let done = gpadl.append(range)?;

        // GPADLs are keyed by (GPADL ID, offer ID).
        let gpadl = match self.inner.gpadls.entry((input.gpadl_id, offer_id)) {
            Entry::Vacant(entry) => entry.insert(gpadl),
            Entry::Occupied(_) => return Err(ChannelError::DuplicateGpadlId),
        };

        if !done {
            // In-progress GPADLs are tracked by GPADL ID alone, so an ID that is already in
            // progress is rejected and the entry just inserted is rolled back.
            match self.inner.incomplete_gpadls.entry(input.gpadl_id) {
                Entry::Vacant(entry) => {
                    entry.insert(offer_id);
                }
                Entry::Occupied(_) => {
                    self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    tracelimit::error_ratelimited!(
                        channel_id = ?input.channel_id,
                        key = %channel.offer.key(),
                        gpadl_id = ?input.gpadl_id,
                        "duplicate in-progress gpadl ID",
                    );
                    return Err(ChannelError::DuplicateGpadlId);
                }
            }
        }

        // A complete GPADL that `gpadl_updated` rejects is removed again.
        if done
            && !Self::gpadl_updated(
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused()),
                offer_id,
                channel,
                input.gpadl_id,
                gpadl,
            )
        {
            self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
        }
        Ok(())
    }
2797
2798 fn handle_gpadl_header(&mut self, input: &protocol::GpadlHeader, range: &[u8]) {
2800 if let Err(err) = self.handle_gpadl_header_core(input, range) {
2801 tracelimit::warn_ratelimited!(
2802 err = &err as &dyn std::error::Error,
2803 channel_id = ?input.channel_id,
2804 key = %self.inner.channels.get_by_channel_id(&self.inner.assigned_channels, input.channel_id).map(|(_, c)| c.offer.key()).unwrap_or_default(),
2805 gpadl_id = ?input.gpadl_id,
2806 "error handling gpadl header"
2807 );
2808
2809 self.sender().send_gpadl_created(
2811 input.channel_id,
2812 input.gpadl_id,
2813 protocol::STATUS_UNSUCCESSFUL,
2814 );
2815 }
2816 }
2817
    /// Processes a GPADL body message, appending `range` to an in-progress GPADL.
    ///
    /// The owning channel is found through `incomplete_gpadls`. When the append completes the
    /// GPADL, it is no longer tracked as incomplete and `gpadl_updated` is called. When the
    /// append fails, the GPADL is discarded, the failure is logged, and a failed `GpadlCreated`
    /// is sent to the guest; the function still returns `Ok(())` in that case.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::UnknownGpadlId`] if the GPADL ID is not an in-progress GPADL.
    ///
    /// # Panics
    ///
    /// Panics if the append fails and the channel has no assigned channel info.
    fn handle_gpadl_body(
        &mut self,
        input: &protocol::GpadlBody,
        range: &[u8],
    ) -> Result<(), ChannelError> {
        let &offer_id = self
            .inner
            .incomplete_gpadls
            .get(&input.gpadl_id)
            .ok_or(ChannelError::UnknownGpadlId)?;
        let gpadl = self
            .inner
            .gpadls
            .get_mut(&(input.gpadl_id, offer_id))
            .ok_or(ChannelError::UnknownGpadlId)?;
        let channel = &mut self.inner.channels[offer_id];

        match gpadl.append(range) {
            Ok(done) => {
                if done {
                    self.inner.incomplete_gpadls.remove(&input.gpadl_id);
                    // A complete GPADL that `gpadl_updated` rejects is removed again.
                    if !Self::gpadl_updated(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        input.gpadl_id,
                        gpadl,
                    ) {
                        self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    }
                }
            }
            Err(err) => {
                // Drop all trace of the GPADL and tell the guest that creation failed.
                self.inner.incomplete_gpadls.remove(&input.gpadl_id);
                self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                let channel_id = channel.info.as_ref().expect("assigned").channel_id;
                tracelimit::warn_ratelimited!(
                    err = &err as &dyn std::error::Error,
                    channel_id = channel_id.0,
                    key = %channel.offer.key(),
                    gpadl_id = input.gpadl_id.0,
                    "error handling gpadl body"
                );
                self.sender().send_gpadl_created(
                    channel_id,
                    input.gpadl_id,
                    protocol::STATUS_UNSUCCESSFUL,
                );
            }
        }

        Ok(())
    }
2881
    /// Handles a GPADL teardown request from the guest.
    ///
    /// Only a GPADL in the `Accepted` state can be torn down. For a revoked channel the GPADL
    /// is removed and `GpadlTorndown` is sent immediately. Otherwise the GPADL moves to
    /// `TearingDown` and the notifier is asked to tear it down.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID lookup fails, the GPADL is not known for that
    /// channel, the GPADL is in any state other than `Accepted`, the channel's assigned ID does
    /// not match the request, or the channel is reserved.
    fn handle_gpadl_teardown(
        &mut self,
        input: &protocol::GpadlTeardown,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        tracing::debug!(
            channel_id = input.channel_id.0,
            key = %channel.offer.key(),
            gpadl_id = input.gpadl_id.0,
            "Received GPADL teardown request"
        );

        let gpadl = self
            .inner
            .gpadls
            .get_mut(&(input.gpadl_id, offer_id))
            .ok_or(ChannelError::UnknownGpadlId)?;

        match gpadl.state {
            // Not yet accepted, or already being torn down.
            GpadlState::InProgress
            | GpadlState::Offered
            | GpadlState::OfferedTearingDown
            | GpadlState::TearingDown => {
                return Err(ChannelError::InvalidGpadlState);
            }
            GpadlState::Accepted => {
                if channel.info.as_ref().map(|info| info.channel_id) != Some(input.channel_id) {
                    return Err(ChannelError::WrongGpadlChannelId);
                }

                if channel.state.is_reserved() {
                    return Err(ChannelError::ChannelReserved);
                }

                if channel.state.is_revoked() {
                    tracing::trace!(
                        channel_id = input.channel_id.0,
                        key = %channel.offer.key(),
                        gpadl_id = input.gpadl_id.0,
                        "Gpadl teardown for revoked channel"
                    );

                    // The notifier is not involved for a revoked channel; complete the
                    // teardown right away.
                    self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    self.sender().send_gpadl_torndown(input.gpadl_id);
                } else {
                    gpadl.state = GpadlState::TearingDown;
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id: input.gpadl_id,
                            post_restore: false,
                        },
                    );
                }
            }
        }
        Ok(())
    }
2948
2949 fn open_channel(
2952 &mut self,
2953 offer_id: OfferId,
2954 input: &OpenRequest,
2955 reserved_state: Option<ReservedState>,
2956 ) {
2957 let channel = &mut self.inner.channels[offer_id];
2958 assert!(matches!(channel.state, ChannelState::Closed));
2959
2960 channel.state = ChannelState::Opening {
2961 request: *input,
2962 reserved_state,
2963 };
2964
2965 let info = channel.info.as_ref().expect("assigned");
2968 self.notifier.notify(
2969 offer_id,
2970 Action::Open(
2971 OpenParams::from_request(
2972 info,
2973 input,
2974 channel.handled_monitor_info(),
2975 reserved_state.map(|state| state.target),
2976 ),
2977 self.inner.state.get_version().expect("must be connected"),
2978 ),
2979 );
2980 }
2981
2982 fn handle_open_channel(&mut self, input: &protocol::OpenChannel2) -> Result<(), ChannelError> {
2984 let (offer_id, channel) = self
2985 .inner
2986 .channels
2987 .get_by_channel_id_mut(&self.inner.assigned_channels, input.open_channel.channel_id)?;
2988
2989 let guest_specified_interrupt_info = self
2990 .inner
2991 .state
2992 .check_feature_flags(|ff| ff.guest_specified_signal_parameters())
2993 .then_some(SignalInfo {
2994 event_flag: input.event_flag,
2995 connection_id: input.connection_id,
2996 });
2997
2998 let flags = if self
2999 .inner
3000 .state
3001 .check_feature_flags(|ff| ff.channel_interrupt_redirection())
3002 {
3003 input.flags
3004 } else {
3005 Default::default()
3006 };
3007
3008 let request = OpenRequest {
3009 open_id: input.open_channel.open_id,
3010 ring_buffer_gpadl_id: input.open_channel.ring_buffer_gpadl_id,
3011 target_vp: input.open_channel.target_vp,
3012 downstream_ring_buffer_page_offset: input
3013 .open_channel
3014 .downstream_ring_buffer_page_offset,
3015 user_data: input.open_channel.user_data,
3016 guest_specified_interrupt_info,
3017 flags,
3018 };
3019
3020 match channel.state {
3021 ChannelState::Closed => self.open_channel(offer_id, &request, None),
3022 ChannelState::Closing { params, .. } => {
3023 channel.state = ChannelState::ClosingReopen { params, request }
3027 }
3028 ChannelState::Revoked | ChannelState::Reoffered => {}
3029
3030 ChannelState::Open { .. }
3031 | ChannelState::Opening { .. }
3032 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelAlreadyOpen),
3033
3034 ChannelState::ClientReleased
3035 | ChannelState::ClosingClientRelease
3036 | ChannelState::OpeningClientRelease => unreachable!(),
3037 }
3038 Ok(())
3039 }
3040
    /// Handles a close channel request from the guest.
    ///
    /// An open, non-reserved channel moves to `Closing` and the notifier is asked to close it.
    /// A warning is logged if a modify operation is still in progress. Requests for revoked or
    /// reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID lookup fails, [`ChannelError::ChannelReserved`] if
    /// the channel is open as a reserved channel, or [`ChannelError::ChannelNotOpen`] if the
    /// channel is closed, opening, or already closing.
    fn handle_close_channel(&mut self, input: &protocol::CloseChannel) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        match channel.state {
            ChannelState::Open {
                params,
                modify_state,
                reserved_state: None,
            } => {
                if modify_state.is_modifying() {
                    tracelimit::warn_ratelimited!(
                        key = %channel.offer.key(),
                        ?modify_state,
                        "Client is closing the channel with a modify in progress"
                    )
                }

                channel.state = ChannelState::Closing {
                    params,
                    reserved_state: None,
                };
                self.notifier.notify(offer_id, Action::Close);
            }

            // Reserved channels are not closed through this message.
            ChannelState::Open {
                reserved_state: Some(_),
                ..
            } => return Err(ChannelError::ChannelReserved),

            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }

        Ok(())
    }
3088
3089 fn handle_open_reserved_channel(
3092 &mut self,
3093 input: &protocol::OpenReservedChannel,
3094 version: VersionInfo,
3095 ) -> Result<(), ChannelError> {
3096 let (offer_id, channel) = self
3097 .inner
3098 .channels
3099 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
3100
3101 let target = ConnectionTarget {
3102 vp: input.target_vp,
3103 sint: input.target_sint as u8,
3104 };
3105
3106 let reserved_state = Some(ReservedState { version, target });
3107
3108 let request = OpenRequest {
3109 ring_buffer_gpadl_id: input.ring_buffer_gpadl,
3110 target_vp: protocol::VP_INDEX_DISABLE_INTERRUPT,
3112 downstream_ring_buffer_page_offset: input.downstream_page_offset,
3113 open_id: 0,
3114 user_data: UserDefinedData::new_zeroed(),
3115 guest_specified_interrupt_info: None,
3116 flags: Default::default(),
3117 };
3118
3119 match channel.state {
3120 ChannelState::Closed => self.open_channel(offer_id, &request, reserved_state),
3121 ChannelState::Revoked | ChannelState::Reoffered => {}
3122
3123 ChannelState::Open { .. } | ChannelState::Opening { .. } => {
3124 return Err(ChannelError::ChannelAlreadyOpen);
3125 }
3126
3127 ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
3128 return Err(ChannelError::InvalidChannelState);
3129 }
3130
3131 ChannelState::ClientReleased
3132 | ChannelState::ClosingClientRelease
3133 | ChannelState::OpeningClientRelease => unreachable!(),
3134 }
3135 Ok(())
3136 }
3137
    /// Handles a request to close a reserved channel.
    ///
    /// A channel that is open as a reserved channel moves to `Closing`, with the reserved
    /// target replaced by the VP and SINT from the request, and the notifier is asked to close
    /// it. Requests for revoked or reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID lookup fails,
    /// [`ChannelError::ChannelNotReserved`] if the channel is open but not reserved, or
    /// [`ChannelError::ChannelNotOpen`] if the channel is closed, opening, or closing.
    fn handle_close_reserved_channel(
        &mut self,
        input: &protocol::CloseReservedChannel,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        match channel.state {
            ChannelState::Open {
                params,
                reserved_state: Some(mut resvd),
                ..
            } => {
                // `resvd` is a copy of the stored state; the updated copy is written back as
                // part of the new `Closing` state.
                resvd.target.vp = input.target_vp;
                resvd.target.sint = input.target_sint as u8;
                channel.state = ChannelState::Closing {
                    params,
                    reserved_state: Some(resvd),
                };
                self.notifier.notify(offer_id, Action::Close);
            }

            ChannelState::Open {
                reserved_state: None,
                ..
            } => return Err(ChannelError::ChannelNotReserved),

            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }

        Ok(())
    }
3183
    /// Releases a channel on behalf of the client, cleaning up its GPADLs and advancing the
    /// channel to one of the released states.
    ///
    /// GPADLs belonging to `offer_id` are handled by state:
    /// - `InProgress`: dropped, and removed from `incomplete_gpadls`.
    /// - `Offered`: kept, and marked `OfferedTearingDown`.
    /// - `Accepted`: dropped if the channel is revoked; otherwise marked `TearingDown` and the
    ///   notifier is asked to tear it down.
    /// - Already tearing down: kept unchanged.
    ///
    /// If the channel is `Reoffered` and `info` is provided, the channel is instead reset to
    /// `Closed`, its restore state is set to `New`, a fresh offer is sent, and the function
    /// returns `false` early without releasing the channel's assigned resources.
    ///
    /// Returns `true` if the caller should remove the channel from the channel map. This is
    /// only the case when the channel was in the `Revoked` state.
    #[must_use]
    fn client_release_channel(
        mut sender: MessageSender<'_, N>,
        offer_id: OfferId,
        channel: &mut Channel,
        gpadls: &mut GpadlMap,
        incomplete_gpadls: &mut IncompleteGpadlMap,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
        info: Option<&ConnectionInfo>,
    ) -> bool {
        tracelimit::info_ratelimited!(?offer_id, key = %channel.offer.key(), "client released channel");
        gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
            // GPADLs of other channels are not affected.
            if gpadl_offer_id != offer_id {
                return true;
            }
            match gpadl.state {
                GpadlState::InProgress => {
                    incomplete_gpadls.remove(&gpadl_id);
                    false
                }
                GpadlState::Offered => {
                    gpadl.state = GpadlState::OfferedTearingDown;
                    true
                }
                GpadlState::Accepted => {
                    if channel.state.is_revoked() {
                        false
                    } else {
                        gpadl.state = GpadlState::TearingDown;
                        sender.notifier.notify(
                            offer_id,
                            Action::TeardownGpadl {
                                gpadl_id,
                                post_restore: false,
                            },
                        );
                        true
                    }
                }
                GpadlState::OfferedTearingDown | GpadlState::TearingDown => true,
            }
        });

        let remove = match &mut channel.state {
            ChannelState::Closed => {
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Reoffered => {
                if let Some(info) = info {
                    // Re-offer the channel to the client instead of releasing it.
                    channel.state = ChannelState::Closed;
                    channel.restore_state = RestoreState::New;
                    sender.send_offer(channel, info);
                    return false;
                }
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Revoked => {
                channel.state = ChannelState::ClientReleased;
                true
            }
            ChannelState::Opening { .. } => {
                channel.state = ChannelState::OpeningClientRelease;
                false
            }
            ChannelState::Open { .. } => {
                // An open channel must also be closed by the notifier.
                channel.state = ChannelState::ClosingClientRelease;
                sender.notifier.notify(offer_id, Action::Close);
                false
            }
            ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                false
            }

            // Already released; nothing changes.
            ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease
            | ChannelState::ClientReleased => false,
        };

        assert!(channel.state.is_released());

        channel.release_channel(offer_id, assigned_channels, assigned_monitors);
        remove
    }
3277
    /// Handles the guest's notification that it has released a channel ID.
    ///
    /// For a channel that is closed, revoked, closing, or reoffered, the channel is released
    /// through `client_release_channel` and removed from the channel map if that call asks for
    /// it. `check_disconnected` is called afterwards.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID lookup fails, or
    /// [`ChannelError::InvalidChannelState`] if the channel is opening, open, or closing with a
    /// reopen pending.
    fn handle_rel_id_released(
        &mut self,
        input: &protocol::RelIdReleased,
    ) -> Result<(), ChannelError> {
        let channel_id = input.channel_id;
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, channel_id)?;

        match channel.state {
            ChannelState::Closed
            | ChannelState::Revoked
            | ChannelState::Closing { .. }
            | ChannelState::Reoffered => {
                // The connection info, when available, is passed along so that a reoffered
                // channel can be offered again.
                if Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    &mut self.inner.gpadls,
                    &mut self.inner.incomplete_gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    self.inner.state.get_connected_info(),
                ) {
                    self.inner.channels.remove(offer_id);
                }

                self.check_disconnected();
            }

            ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::InvalidChannelState),

            ChannelState::ClientReleased
            | ChannelState::OpeningClientRelease
            | ChannelState::ClosingClientRelease => unreachable!(),
        }
        Ok(())
    }
3322
3323 fn handle_tl_connect_request(&mut self, request: protocol::TlConnectRequest2) {
3326 let version = self
3327 .inner
3328 .state
3329 .get_version()
3330 .expect("must be connected")
3331 .version;
3332
3333 let hosted_silo_unaware = version < Version::Win10Rs5;
3334 self.notifier
3335 .notify_hvsock(&HvsockConnectRequest::from_message(
3336 request,
3337 hosted_silo_unaware,
3338 ));
3339 }
3340
3341 pub fn send_tl_connect_result(&mut self, result: HvsockConnectResult) {
3343 if !result.success && self.inner.state.check_version(Version::Win10Rs3_0) {
3347 self.sender().send_message(&protocol::TlConnectResult {
3351 service_id: result.service_id,
3352 endpoint_id: result.endpoint_id,
3353 status: protocol::STATUS_CONNECTION_REFUSED,
3354 })
3355 }
3356 }
3357
3358 fn handle_modify_channel(
3361 &mut self,
3362 request: &protocol::ModifyChannel,
3363 ) -> Result<(), ChannelError> {
3364 let result = self.modify_channel(request);
3365 if result.is_err() {
3366 self.send_modify_channel_response(request.channel_id, protocol::STATUS_UNSUCCESSFUL);
3367 }
3368
3369 result
3370 }
3371
    /// Starts a modify operation that changes the target VP of an open, non-reserved channel.
    ///
    /// If no modify is in progress, the notifier is asked to perform it, the stored open
    /// parameters are updated with the new target VP, and the channel is marked as modifying.
    ///
    /// If a modify is already in progress, the behavior depends on the version check against
    /// `Version::Iron`: when it passes, the request is only logged as a warning; otherwise the
    /// requested target VP is stored as pending (replacing any previous pending value) so it
    /// can be applied when the current operation completes.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID lookup fails, or
    /// [`ChannelError::InvalidChannelState`] if the channel is not open or is reserved.
    fn modify_channel(&mut self, request: &protocol::ModifyChannel) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, request.channel_id)?;

        let (open_request, modify_state) = match &mut channel.state {
            ChannelState::Open {
                params,
                modify_state,
                reserved_state: None,
            } => (params, modify_state),
            _ => return Err(ChannelError::InvalidChannelState),
        };

        if let ModifyState::Modifying { pending_target_vp } = modify_state {
            if self.inner.state.check_version(Version::Iron) {
                // Note that this path still returns `Ok(())`, so no response is generated for
                // the extra request here.
                tracelimit::warn_ratelimited!(
                    key = %channel.offer.key(),
                    "Client sent new ModifyChannel before receiving ModifyChannelResponse."
                );
            } else {
                *pending_target_vp = Some(request.target_vp);
            }
        } else {
            self.notifier.notify(
                offer_id,
                Action::Modify {
                    target_vp: request.target_vp,
                },
            );

            // Record the new target in the stored open parameters right away.
            open_request.target_vp = request.target_vp;
            *modify_state = ModifyState::Modifying {
                pending_target_vp: None,
            };
        }

        Ok(())
    }
3418
3419 pub fn modify_channel_complete(&mut self, offer_id: OfferId, status: i32) {
3426 let channel = &mut self.inner.channels[offer_id];
3427
3428 if let ChannelState::Open {
3429 params,
3430 modify_state: ModifyState::Modifying { pending_target_vp },
3431 reserved_state: None,
3432 } = channel.state
3433 {
3434 channel.state = ChannelState::Open {
3435 params,
3436 modify_state: ModifyState::NotModifying,
3437 reserved_state: None,
3438 };
3439
3440 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
3442 let key = channel.offer.key();
3443 self.send_modify_channel_response(channel_id, status);
3444
3445 if let Some(target_vp) = pending_target_vp {
3447 let request = protocol::ModifyChannel {
3448 channel_id,
3449 target_vp,
3450 };
3451
3452 if let Err(error) = self.handle_modify_channel(&request) {
3453 tracelimit::warn_ratelimited!(?error, %key, "Pending ModifyChannel request failed.")
3454 }
3455 }
3456 }
3457 }
3458
3459 fn send_modify_channel_response(&mut self, channel_id: ChannelId, status: i32) {
3460 if self.inner.state.check_version(Version::Iron) {
3461 self.sender()
3462 .send_message(&protocol::ModifyChannelResponse { channel_id, status });
3463 }
3464 }
3465
3466 fn handle_modify_connection(&mut self, request: protocol::ModifyConnection) {
3467 if let Err(err) = self.modify_connection(request) {
3468 tracelimit::error_ratelimited!(?err, "modifying connection failed");
3469 self.complete_modify_connection(ModifyConnectionResponse::Modified(
3470 protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
3471 ));
3472 }
3473 }
3474
    /// Validates a modify connection request and forwards it to the notifier.
    ///
    /// On success the connection is marked as modifying and the stored monitor page info is
    /// replaced by the pages from the request (or `None` when both GPAs are zero).
    ///
    /// # Errors
    ///
    /// Fails if the connection is not in the `Connected` state, a modify is already in
    /// progress, the current monitor pages are server-allocated, only one of the two monitor
    /// page GPAs is non-zero, or the notifier rejects the request. In the last case the
    /// connection has already been marked as modifying when the error is returned.
    fn modify_connection(&mut self, request: protocol::ModifyConnection) -> anyhow::Result<()> {
        let ConnectionState::Connected(info) = &mut self.inner.state else {
            anyhow::bail!(
                "Invalid state for ModifyConnection request: {:?}",
                self.inner.state
            );
        };

        if info.modifying {
            anyhow::bail!(
                "Duplicate ModifyConnection request, state: {:?}",
                self.inner.state
            );
        }

        if matches!(
            info.monitor_page,
            Some(MonitorPageGpaInfo {
                server_allocated: true,
                ..
            })
        ) {
            anyhow::bail!("Cannot modify server-allocated monitor pages");
        }

        // The two GPAs must either both be zero or both be non-zero.
        if (request.child_to_parent_monitor_page_gpa == 0)
            != (request.parent_to_child_monitor_page_gpa == 0)
        {
            anyhow::bail!("Guest must specify either both or no monitor pages, {request:?}");
        }

        let monitor_page = (request.child_to_parent_monitor_page_gpa != 0).then_some(
            MonitorPageGpaInfo::from_guest_gpas(MonitorPageGpas {
                child_to_parent: request.child_to_parent_monitor_page_gpa,
                parent_to_child: request.parent_to_child_monitor_page_gpa,
            }),
        );

        // State is updated before the notifier call; `complete_modify_connection` expects
        // `modifying` to be set when it runs for a connected state.
        info.modifying = true;
        info.monitor_page = monitor_page;
        tracing::debug!("modifying connection parameters.");
        self.notifier.modify_connection(request.into())?;

        Ok(())
    }
3520
    /// Completes a modify connection operation with the given `response`.
    ///
    /// What happens depends on the connection state:
    /// - `Connecting`: the response is passed to `complete_initiate_contact`.
    /// - `Disconnecting`: the disconnect is completed.
    /// - `Connected`: the modifying flag is cleared and a `ModifyConnectionResponse` with the
    ///   resulting connection state is sent to the guest.
    ///
    /// # Panics
    ///
    /// Panics if the connection is in any other state, if the connection is `Connected` but
    /// `response` is not `ModifyConnectionResponse::Modified`, or if no modify was in progress.
    pub fn complete_modify_connection(&mut self, response: ModifyConnectionResponse) {
        tracing::debug!(?response, "modifying connection parameters complete");

        match &mut self.inner.state {
            ConnectionState::Connecting { .. } => self.complete_initiate_contact(response),
            ConnectionState::Disconnecting { .. } => self.complete_disconnect(),
            ConnectionState::Connected(info) => {
                let ModifyConnectionResponse::Modified(connection_state) = response else {
                    panic!(
                        "Relay should not return {:?} for a modify request with no version.",
                        response
                    );
                };

                if !info.modifying {
                    panic!(
                        "ModifyConnection response while not modifying, state: {:?}",
                        self.inner.state
                    );
                }

                info.modifying = false;
                self.sender()
                    .send_message(&protocol::ModifyConnectionResponse { connection_state });
            }
            _ => panic!(
                "Invalid state for ModifyConnection response: {:?}",
                self.inner.state
            ),
        }
    }
3555
    /// Handles a pause request from the guest by sending `PauseResponse` and marking the
    /// connection as paused.
    ///
    /// # Panics
    ///
    /// Panics if the connection is not in the `Connected` state.
    fn handle_pause(&mut self) {
        tracelimit::info_ratelimited!("pausing sending messages");
        // The response is sent before the paused flag is set, so the sender is still created
        // with the unpaused state. NOTE(review): presumably this keeps the response itself from
        // being held back; confirm against `MessageSender`.
        self.sender().send_message(&protocol::PauseResponse {});
        let ConnectionState::Connected(info) = &mut self.inner.state else {
            unreachable!(
                "in unexpected state {:?}, should be prevented by Message::parse()",
                self.inner.state
            );
        };
        info.paused = true;
    }
3567
    /// Parses a message received from the guest and dispatches it to the matching handler.
    ///
    /// The message is parsed using the current connection version, if any. While the
    /// connection is paused, only `Resume`, `Unload`, and the two initiate contact messages are
    /// accepted; receiving any of them clears the paused flag.
    ///
    /// # Errors
    ///
    /// Returns an error if parsing fails, if the connection is trusted but the message is not
    /// ([`ChannelError::UntrustedMessage`]), if a disallowed message arrives while paused
    /// ([`ChannelError::Paused`]), or if the selected handler fails.
    ///
    /// # Panics
    ///
    /// Panics if a reset is in progress, or if the parsed message is one that only a server
    /// sends to a client.
    pub fn handle_synic_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
        assert!(!self.is_resetting());

        let version = self.inner.state.get_version();
        let msg = Message::parse(&message.data, version)?;
        tracing::trace!(?msg, message.trusted, "received vmbus message");
        // A trusted connection only accepts trusted messages.
        if self.inner.state.is_trusted() && !message.trusted {
            tracelimit::warn_ratelimited!(?msg, "Received untrusted message");
            return Err(ChannelError::UntrustedMessage);
        }

        match &mut self.inner.state {
            ConnectionState::Connected(info) if info.paused => {
                if !matches!(
                    msg,
                    Message::Resume(..)
                        | Message::Unload(..)
                        | Message::InitiateContact { .. }
                        | Message::InitiateContact2 { .. }
                ) {
                    tracelimit::warn_ratelimited!(?msg, "Received message while paused");
                    return Err(ChannelError::Paused);
                }
                // Any of the allowed messages ends the pause, not just `Resume`.
                tracelimit::info_ratelimited!("resuming sending messages");
                info.paused = false;
            }
            _ => {}
        }

        match msg {
            Message::InitiateContact2(input, ..) => {
                self.handle_initiate_contact(&input, &message, true)?
            }
            Message::InitiateContact(input, ..) => {
                self.handle_initiate_contact(&input.into(), &message, false)?
            }
            Message::Unload(..) => self.handle_unload(),
            Message::RequestOffers(..) => self.handle_request_offers()?,
            // The header handler reports failures to the guest itself and returns nothing.
            Message::GpadlHeader(input, range) => self.handle_gpadl_header(&input, range),
            Message::GpadlBody(input, range) => self.handle_gpadl_body(&input, range)?,
            Message::GpadlTeardown(input, ..) => self.handle_gpadl_teardown(&input)?,
            Message::OpenChannel(input, ..) => self.handle_open_channel(&input.into())?,
            Message::OpenChannel2(input, ..) => self.handle_open_channel(&input)?,
            Message::CloseChannel(input, ..) => self.handle_close_channel(&input)?,
            Message::RelIdReleased(input, ..) => self.handle_rel_id_released(&input)?,
            Message::TlConnectRequest(input, ..) => self.handle_tl_connect_request(input.into()),
            Message::TlConnectRequest2(input, ..) => self.handle_tl_connect_request(input),
            Message::ModifyChannel(input, ..) => self.handle_modify_channel(&input)?,
            Message::ModifyConnection(input, ..) => self.handle_modify_connection(input),
            Message::OpenReservedChannel(input, ..) => self.handle_open_reserved_channel(
                &input,
                version.expect("version validated by Message::parse"),
            )?,
            Message::CloseReservedChannel(input, ..) => {
                self.handle_close_reserved_channel(&input)?
            }
            Message::Pause(protocol::Pause, ..) => self.handle_pause(),
            // The paused flag was already cleared above; nothing else to do.
            Message::Resume(protocol::Resume, ..) => {}
            // Messages that flow from server to client should never arrive here.
            Message::OfferChannel(..)
            | Message::RescindChannelOffer(..)
            | Message::AllOffersDelivered(..)
            | Message::OpenResult(..)
            | Message::GpadlCreated(..)
            | Message::GpadlTorndown(..)
            | Message::VersionResponse(..)
            | Message::VersionResponse2(..)
            | Message::VersionResponse3(..)
            | Message::UnloadComplete(..)
            | Message::CloseReservedChannelResponse(..)
            | Message::TlConnectResult(..)
            | Message::ModifyChannelResponse(..)
            | Message::ModifyConnectionResponse(..)
            | Message::PauseResponse(..) => {
                unreachable!("Server received client message {:?}", msg);
            }
        }
        Ok(())
    }
3653
    /// Completes the creation of a GPADL that was previously reported to the notifier.
    ///
    /// A `status` of zero or greater is treated as success.
    ///
    /// - For an `Offered` GPADL, `GpadlCreated` with `status` is sent to the guest; on success
    ///   the GPADL becomes `Accepted`, otherwise it is removed.
    /// - For an `OfferedTearingDown` GPADL, nothing is sent to the guest; on success the
    ///   notifier is asked to tear the GPADL down and it becomes `TearingDown`, otherwise it is
    ///   removed.
    /// - An unknown GPADL, or one in any other state, is logged as an error and ignored.
    ///
    /// Whenever a GPADL is removed, `check_disconnected` is called.
    ///
    /// # Panics
    ///
    /// Panics if the GPADL is `Offered` and the channel has no assigned channel info.
    pub fn gpadl_create_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId, status: i32) {
        let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
            tracelimit::error_ratelimited!(
                ?offer_id,
                key = %self.inner.channels[offer_id].offer.key(),
                ?gpadl_id,
                "invalid gpadl ID for channel"
            );
            return;
        };
        let retain = match gpadl.state {
            GpadlState::InProgress | GpadlState::TearingDown | GpadlState::Accepted => {
                tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
                return;
            }
            GpadlState::Offered => {
                let channel_id = self.inner.channels[offer_id]
                    .info
                    .as_ref()
                    .expect("assigned")
                    .channel_id;
                // The guest is told the outcome whether creation succeeded or failed.
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_gpadl_created(channel_id, gpadl_id, status);
                if status >= 0 {
                    gpadl.state = GpadlState::Accepted;
                    true
                } else {
                    false
                }
            }
            GpadlState::OfferedTearingDown => {
                if status >= 0 {
                    // The GPADL was created but is no longer wanted; tear it down again.
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: false,
                        },
                    );
                    gpadl.state = GpadlState::TearingDown;
                    true
                } else {
                    false
                }
            }
        };
        if !retain {
            self.inner
                .gpadls
                .remove(&(gpadl_id, offer_id))
                .expect("gpadl validated above");

            self.check_disconnected();
        }
    }
3713
    /// Completes the teardown of a GPADL.
    ///
    /// For a GPADL in the `TearingDown` state, `GpadlTorndown` is sent to the guest unless the
    /// channel is in a released state, the GPADL is removed, and `check_disconnected` is
    /// called. An unknown GPADL, or one in any other state, is logged as an error and left
    /// unchanged.
    pub fn gpadl_teardown_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
        let channel = &mut self.inner.channels[offer_id];
        let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
            tracelimit::error_ratelimited!(
                ?offer_id,
                key = %channel.offer.key(),
                ?gpadl_id,
                "invalid gpadl ID for channel"
            );
            return;
        };
        tracing::debug!(
            offer_id = offer_id.0,
            key = %channel.offer.key(),
            gpadl_id = gpadl_id.0,
            "Gpadl teardown complete"
        );
        match gpadl.state {
            GpadlState::InProgress
            | GpadlState::Offered
            | GpadlState::OfferedTearingDown
            | GpadlState::Accepted => {
                tracelimit::error_ratelimited!(?offer_id, key = %channel.offer.key(), ?gpadl_id, ?gpadl, "invalid gpadl state");
            }
            GpadlState::TearingDown => {
                // No message is sent once the channel has been released.
                if !channel.state.is_released() {
                    self.sender().send_gpadl_torndown(gpadl_id);
                }
                self.inner
                    .gpadls
                    .remove(&(gpadl_id, offer_id))
                    .expect("gpadl validated above");

                self.check_disconnected();
            }
        }
    }
3752
3753 fn sender(&mut self) -> MessageSender<'_, N> {
3758 self.inner
3759 .pending_messages
3760 .sender(self.notifier, self.inner.state.is_paused())
3761 }
3762}
3763
3764fn revoke<N: Notifier>(
3765 mut sender: MessageSender<'_, N>,
3766 offer_id: OfferId,
3767 channel: &mut Channel,
3768 gpadls: &mut GpadlMap,
3769) -> bool {
3770 let info = match channel.state {
3771 ChannelState::Closed
3772 | ChannelState::Open { .. }
3773 | ChannelState::Opening { .. }
3774 | ChannelState::Closing { .. }
3775 | ChannelState::ClosingReopen { .. } => {
3776 channel.state = ChannelState::Revoked;
3777 Some(channel.info.as_ref().expect("assigned"))
3778 }
3779 ChannelState::Reoffered => {
3780 channel.state = ChannelState::Revoked;
3781 None
3782 }
3783 ChannelState::ClientReleased
3784 | ChannelState::OpeningClientRelease
3785 | ChannelState::ClosingClientRelease => None,
3786 ChannelState::Revoked => return true,
3788 };
3789 let retain = !channel.state.is_released();
3790
3791 gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
3793 if gpadl_offer_id != offer_id {
3794 return true;
3795 }
3796
3797 match gpadl.state {
3798 GpadlState::InProgress => true,
3799 GpadlState::Offered => {
3800 if let Some(info) = info {
3801 sender.send_gpadl_created(
3802 info.channel_id,
3803 gpadl_id,
3804 protocol::STATUS_UNSUCCESSFUL,
3805 );
3806 }
3807 false
3808 }
3809 GpadlState::OfferedTearingDown => false,
3810 GpadlState::Accepted => true,
3811 GpadlState::TearingDown => {
3812 if info.is_some() {
3813 sender.send_gpadl_torndown(gpadl_id);
3814 }
3815 false
3816 }
3817 }
3818 });
3819 if let Some(info) = info {
3820 sender.send_rescind(info);
3821 }
3822 if channel.restore_state != RestoreState::New {
3824 channel.restore_state = RestoreState::Restored;
3825 }
3826 retain
3827}
3828
3829struct PendingMessages(VecDeque<OutgoingMessage>);
3830
3831impl PendingMessages {
3832 fn sender<'a, N: Notifier>(
3834 &'a mut self,
3835 notifier: &'a mut N,
3836 is_paused: bool,
3837 ) -> MessageSender<'a, N> {
3838 MessageSender {
3839 notifier,
3840 pending_messages: self,
3841 is_paused,
3842 }
3843 }
3844}
3845
3846struct MessageSender<'a, N> {
3849 notifier: &'a mut N,
3850 pending_messages: &'a mut PendingMessages,
3851 is_paused: bool,
3852}
3853
3854impl<N: Notifier> MessageSender<'_, N> {
3855 fn send_message<
3857 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3858 >(
3859 &mut self,
3860 msg: &T,
3861 ) {
3862 let message = OutgoingMessage::new(msg);
3863
3864 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3865 if !self.pending_messages.0.is_empty()
3867 || self.is_paused
3868 || !self.notifier.send_message(&message, MessageTarget::Default)
3869 {
3870 tracing::trace!("message queued");
3871 self.pending_messages.0.push_back(message);
3873 }
3874 }
3875
3876 fn send_message_with_target<
3878 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3879 >(
3880 &mut self,
3881 msg: &T,
3882 target: MessageTarget,
3883 ) {
3884 if target == MessageTarget::Default {
3885 self.send_message(msg);
3886 } else {
3887 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3888 let message = OutgoingMessage::new(msg);
3891 if !self.notifier.send_message(&message, target) {
3892 tracelimit::warn_ratelimited!(?target, "failed to send message");
3893 }
3894 }
3895 }
3896
3897 fn send_offer(&mut self, channel: &mut Channel, connection_info: &ConnectionInfo) {
3899 let info = channel.info.as_ref().expect("assigned");
3900 let mut flags = channel.offer.flags;
3901 if !connection_info
3902 .version
3903 .feature_flags
3904 .confidential_channels()
3905 {
3906 flags.set_confidential_ring_buffer(false);
3907 flags.set_confidential_external_memory(false);
3908 }
3909
3910 let monitor_id = connection_info.monitor_page.and(info.monitor_id);
3917 let msg = protocol::OfferChannel {
3918 interface_id: channel.offer.interface_id,
3919 instance_id: channel.offer.instance_id,
3920 rsvd: [0; 4],
3921 flags,
3922 mmio_megabytes: channel.offer.mmio_megabytes,
3923 user_defined: channel.offer.user_defined,
3924 subchannel_index: channel.offer.subchannel_index,
3925 mmio_megabytes_optional: channel.offer.mmio_megabytes_optional,
3926 channel_id: info.channel_id,
3927 monitor_id: monitor_id.unwrap_or(MonitorId::INVALID).0,
3928 monitor_allocated: monitor_id.is_some().into(),
3929 is_dedicated: 1,
3932 connection_id: info.connection_id,
3933 };
3934 tracing::info!(
3935 channel_id = msg.channel_id.0,
3936 connection_id = msg.connection_id,
3937 key = %channel.offer.key(),
3938 "sending offer to guest"
3939 );
3940
3941 self.send_message(&msg);
3942 }
3943
3944 fn send_open_result(
3945 &mut self,
3946 channel_id: ChannelId,
3947 open_request: &OpenRequest,
3948 result: i32,
3949 target: MessageTarget,
3950 ) {
3951 self.send_message_with_target(
3952 &protocol::OpenResult {
3953 channel_id,
3954 open_id: open_request.open_id,
3955 status: result as u32,
3956 },
3957 target,
3958 );
3959 }
3960
3961 fn send_gpadl_created(&mut self, channel_id: ChannelId, gpadl_id: GpadlId, status: i32) {
3962 self.send_message(&protocol::GpadlCreated {
3963 channel_id,
3964 gpadl_id,
3965 status,
3966 });
3967 }
3968
3969 fn send_gpadl_torndown(&mut self, gpadl_id: GpadlId) {
3970 self.send_message(&protocol::GpadlTorndown { gpadl_id });
3971 }
3972
3973 fn send_rescind(&mut self, info: &OfferedInfo) {
3974 tracing::info!(
3975 channel_id = info.channel_id.0,
3976 "rescinding channel from guest"
3977 );
3978
3979 self.send_message(&protocol::RescindChannelOffer {
3980 channel_id: info.channel_id,
3981 });
3982 }
3983}
3984
3985struct VersionResponseData {
3987 version: VersionInfo,
3988 state: protocol::ConnectionState,
3989 monitor_pages: Option<MonitorPageGpas>,
3990}
3991
3992impl VersionResponseData {
3993 fn new(version: VersionInfo, state: protocol::ConnectionState) -> Self {
3995 VersionResponseData {
3996 version,
3997 state,
3998 monitor_pages: None,
3999 }
4000 }
4001
4002 fn with_monitor_pages(mut self, monitor_pages: Option<MonitorPageGpas>) -> Self {
4004 self.monitor_pages = monitor_pages;
4005 self
4006 }
4007}