1pub mod saved_state;
5#[cfg(test)]
6mod tests;
7
8use crate::Guid;
9use crate::SynicMessage;
10use crate::monitor::AssignedMonitors;
11use crate::protocol::Version;
12use hvdef::Vtl;
13use inspect::Inspect;
14pub use saved_state::RestoreError;
15pub use saved_state::SavedState;
16pub use saved_state::SavedStateData;
17use slab::Slab;
18use std::cmp::min;
19use std::collections::VecDeque;
20use std::collections::hash_map::Entry;
21use std::collections::hash_map::HashMap;
22use std::fmt::Display;
23use std::ops::Index;
24use std::ops::IndexMut;
25use std::task::Poll;
26use std::task::ready;
27use std::time::Duration;
28use thiserror::Error;
29use vmbus_channel::bus::ChannelType;
30use vmbus_channel::bus::GpadlRequest;
31use vmbus_channel::bus::OfferKey;
32use vmbus_channel::bus::OfferParams;
33use vmbus_channel::bus::OpenData;
34use vmbus_channel::bus::RestoredGpadl;
35use vmbus_core::HvsockConnectRequest;
36use vmbus_core::HvsockConnectResult;
37use vmbus_core::MaxVersionInfo;
38use vmbus_core::OutgoingMessage;
39use vmbus_core::VMBUS_SINT;
40use vmbus_core::VersionInfo;
41use vmbus_core::protocol;
42use vmbus_core::protocol::ChannelId;
43use vmbus_core::protocol::ConnectionId;
44use vmbus_core::protocol::FeatureFlags;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::Message;
47use vmbus_core::protocol::OfferFlags;
48use vmbus_core::protocol::UserDefinedData;
49use vmbus_ring::gparange;
50use vmcore::monitor::MonitorId;
51use vmcore::synic::MonitorInfo;
52use vmcore::synic::MonitorPageGpas;
53use zerocopy::FromZeros;
54use zerocopy::Immutable;
55use zerocopy::IntoBytes;
56use zerocopy::KnownLayout;
57
58#[derive(Debug, Error)]
60pub enum ChannelError {
61 #[error("unknown channel ID")]
62 UnknownChannelId,
63 #[error("unknown GPADL ID")]
64 UnknownGpadlId,
65 #[error("parse error")]
66 ParseError(#[from] protocol::ParseError),
67 #[error("invalid gpa range")]
68 InvalidGpaRange(#[source] gparange::Error),
69 #[error("duplicate GPADL ID")]
70 DuplicateGpadlId,
71 #[error("GPADL is already complete")]
72 GpadlAlreadyComplete,
73 #[error("GPADL channel ID mismatch")]
74 WrongGpadlChannelId,
75 #[error("trying to open an open channel")]
76 ChannelAlreadyOpen,
77 #[error("trying to close a closed channel")]
78 ChannelNotOpen,
79 #[error("invalid GPADL state for operation")]
80 InvalidGpadlState,
81 #[error("invalid channel state for operation")]
82 InvalidChannelState,
83 #[error("channel ID has already been released")]
84 ChannelReleased,
85 #[error("channel offers have already been sent")]
86 OffersAlreadySent,
87 #[error("invalid operation on reserved channel")]
88 ChannelReserved,
89 #[error("invalid operation on non-reserved channel")]
90 ChannelNotReserved,
91 #[error("received untrusted message for trusted connection")]
92 UntrustedMessage,
93 #[error("received a non-resuming message while paused")]
94 Paused,
95 #[error("invalid target VP")]
96 InvalidTargetVp,
97 #[error("interrupts are disabled for this channel")]
98 InterruptsDisabled,
99}
100
101#[derive(Debug, Error)]
102pub enum OfferError {
103 #[error("the channel ID {} is not valid for this operation", (.0).0)]
104 InvalidChannelId(ChannelId),
105 #[error("the channel ID {} is already in use", (.0).0)]
106 ChannelIdInUse(ChannelId),
107 #[error("offer {0} already exists")]
108 AlreadyExists(OfferKey),
109 #[error("specified resources do not match those of the existing saved or revoked offer")]
110 IncompatibleResources,
111 #[error("too many channels have been offered")]
112 TooManyChannels,
113 #[error("mismatched monitor ID from saved state; expected {0:?}, actual {1:?}")]
114 MismatchedMonitorId(Option<MonitorId>, MonitorId),
115}
116
/// Opaque handle for an offered channel.
///
/// Wraps the key returned when the channel was inserted into the channel list.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct OfferId(usize);
120
121type IncompleteGpadlMap = HashMap<GpadlId, OfferId>;
122
123type GpadlMap = HashMap<(GpadlId, OfferId), Gpadl>;
124
125pub struct Server {
127 state: ConnectionState,
128 channels: ChannelList,
129 assigned_channels: AssignedChannels,
130 assigned_monitors: AssignedMonitors,
131 gpadls: GpadlMap,
132 incomplete_gpadls: IncompleteGpadlMap,
133 child_connection_id: u32,
134 max_version: Option<MaxVersionInfo>,
136 delayed_max_version: Option<MaxVersionInfo>,
140 max_restore_version: Option<MaxVersionInfo>,
143 pending_messages: PendingMessages,
146 require_server_allocated_mnf: bool,
150 use_absolute_channel_order: bool,
151}
152
153pub struct ServerWithNotifier<'a, T> {
154 inner: &'a mut Server,
155 notifier: &'a mut T,
156}
157
158impl<T> Drop for ServerWithNotifier<'_, T> {
159 fn drop(&mut self) {
160 self.inner.validate();
161 }
162}
163
164impl<T: Notifier> Inspect for ServerWithNotifier<'_, T> {
165 fn inspect(&self, req: inspect::Request<'_>) {
166 let mut resp = req.respond();
167 let (state, info, next_action) = match &self.inner.state {
168 ConnectionState::Disconnected => ("disconnected", None, None),
169 ConnectionState::Connecting { info, .. } => ("connecting", Some(info), None),
170 ConnectionState::Connected(info) => (
171 if info.offers_sent {
172 "connected"
173 } else {
174 "negotiated"
175 },
176 Some(info),
177 None,
178 ),
179 ConnectionState::Disconnecting { next_action, .. } => {
180 ("disconnecting", None, Some(next_action))
181 }
182 };
183
184 resp.field("connection_info", info);
185 let next_action = next_action.map(|a| match a {
186 ConnectionAction::None => "disconnect",
187 ConnectionAction::Reset => "reset",
188 ConnectionAction::SendUnloadComplete => "unload",
189 ConnectionAction::Reconnect { .. } => "reconnect",
190 ConnectionAction::SendFailedVersionResponse => "send_version_response",
191 });
192 resp.field("state", state)
193 .field("next_action", next_action)
194 .field(
195 "assigned_monitors_bitmap",
196 format_args!("{:x}", self.inner.assigned_monitors.bitmap()),
197 )
198 .child("channels", |req| {
199 let mut resp = req.respond();
200 self.inner
201 .channels
202 .inspect(self.notifier, self.inner.get_version(), &mut resp);
203 for ((gpadl_id, offer_id), gpadl) in &self.inner.gpadls {
204 let channel = &self.inner.channels[*offer_id];
205 resp.field(
206 &channel_inspect_path(
207 &channel.offer,
208 format_args!("/gpadls/{}", gpadl_id.0),
209 ),
210 gpadl,
211 );
212 }
213 });
214 }
215}
216
217#[derive(Debug, Copy, Clone, Inspect)]
219struct MonitorPageGpaInfo {
220 gpas: MonitorPageGpas,
221 server_allocated: bool,
222}
223
224impl MonitorPageGpaInfo {
225 fn from_guest_gpas(gpas: MonitorPageGpas) -> Self {
227 Self {
228 gpas,
229 server_allocated: false,
230 }
231 }
232
233 fn from_server_gpas(gpas: MonitorPageGpas) -> Self {
235 Self {
236 gpas,
237 server_allocated: true,
238 }
239 }
240}
241
242#[derive(Debug, Copy, Clone, Inspect)]
243struct ConnectionInfo {
244 version: VersionInfo,
245 trusted: bool,
248 offers_sent: bool,
249 interrupt_page: Option<u64>,
250 monitor_page: Option<MonitorPageGpaInfo>,
251 target_message_vp: u32,
252 modifying: bool,
253 client_id: Guid,
254 paused: bool,
255}
256
257#[derive(Debug)]
259enum ConnectionState {
260 Disconnected,
261 Disconnecting {
262 next_action: ConnectionAction,
263 modify_sent: bool,
264 },
265 Connecting {
266 info: ConnectionInfo,
267 next_action: ConnectionAction,
268 },
269 Connected(ConnectionInfo),
270}
271
272impl ConnectionState {
273 fn check_version(&self, min_version: Version) -> bool {
275 matches!(self, ConnectionState::Connected(info) if info.version.version >= min_version)
276 }
277
278 fn check_feature_flags(&self, flags: impl Fn(FeatureFlags) -> bool) -> bool {
281 matches!(self, ConnectionState::Connected(info) if flags(info.version.feature_flags))
282 }
283
284 fn get_version(&self) -> Option<VersionInfo> {
285 if let ConnectionState::Connected(info) = self {
286 Some(info.version)
287 } else {
288 None
289 }
290 }
291
292 fn get_connected_info(&self) -> Option<&ConnectionInfo> {
294 if let ConnectionState::Connected(info) = self {
295 Some(info)
296 } else {
297 None
298 }
299 }
300
301 fn is_trusted(&self) -> bool {
302 match self {
303 ConnectionState::Connected(info) => info.trusted,
304 ConnectionState::Connecting { info, .. } => info.trusted,
305 _ => false,
306 }
307 }
308
309 fn is_paused(&self) -> bool {
310 if let ConnectionState::Connected(info) = self {
311 info.paused
312 } else {
313 false
314 }
315 }
316}
317
318#[derive(Debug, Copy, Clone)]
319enum ConnectionAction {
320 None,
321 Reset,
322 SendUnloadComplete,
323 Reconnect {
324 initiate_contact: InitiateContactRequest,
325 },
326 SendFailedVersionResponse,
327}
328
329#[derive(PartialEq, Eq, Debug, Copy, Clone)]
330pub enum MonitorPageRequest {
331 None,
332 Some(MonitorPageGpas),
333 Invalid,
334}
335
336#[derive(PartialEq, Eq, Debug, Copy, Clone)]
337pub struct InitiateContactRequest {
338 pub version_requested: u32,
339 pub target_message_vp: u32,
340 pub monitor_page: MonitorPageRequest,
341 pub target_sint: u8,
342 pub target_vtl: u8,
343 pub feature_flags: u32,
344 pub interrupt_page: Option<u64>,
345 pub client_id: Guid,
346 pub trusted: bool,
347}
348
349#[derive(Debug, Copy, Clone)]
350pub struct OpenRequest {
351 pub open_id: u32,
352 pub ring_buffer_gpadl_id: GpadlId,
353 pub target_vp: Option<u32>,
354 pub downstream_ring_buffer_page_offset: u32,
355 pub user_data: UserDefinedData,
356 pub guest_specified_interrupt_info: Option<SignalInfo>,
357 pub flags: protocol::OpenChannelFlags,
358}
359
/// A three-way change to a setting: leave it, clear it, or set a new value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Update<T: std::fmt::Debug + Copy + Clone> {
    /// Keep the current value.
    Unchanged,
    /// Clear the current value.
    Reset,
    /// Replace the current value.
    Set(T),
}

impl<T: std::fmt::Debug + Copy + Clone> From<Option<T>> for Update<T> {
    /// `Some(v)` maps to `Set(v)`; `None` maps to `Reset` (never `Unchanged`).
    fn from(value: Option<T>) -> Self {
        value.map_or(Self::Reset, Self::Set)
    }
}
375
376#[derive(Debug, Copy, Clone, Eq, PartialEq)]
377pub struct ModifyConnectionRequest {
378 pub version: Option<VersionInfo>,
379 pub monitor_page: Update<MonitorPageGpas>,
380 pub interrupt_page: Update<u64>,
381 pub target_message_vp: Option<u32>,
382 pub notify_relay: bool,
383}
384
385impl Default for ModifyConnectionRequest {
387 fn default() -> Self {
388 Self {
389 version: None,
390 monitor_page: Update::Unchanged,
391 interrupt_page: Update::Unchanged,
392 target_message_vp: None,
393 notify_relay: true,
394 }
395 }
396}
397
398impl From<protocol::ModifyConnection> for ModifyConnectionRequest {
399 fn from(value: protocol::ModifyConnection) -> Self {
400 let monitor_page = if value.parent_to_child_monitor_page_gpa != 0 {
401 Update::Set(MonitorPageGpas {
402 parent_to_child: value.parent_to_child_monitor_page_gpa,
403 child_to_parent: value.child_to_parent_monitor_page_gpa,
404 })
405 } else {
406 Update::Reset
407 };
408
409 Self {
410 monitor_page,
411 ..Default::default()
412 }
413 }
414}
415
416#[derive(Debug, Copy, Clone)]
418pub enum ModifyConnectionResponse {
419 Supported(
425 protocol::ConnectionState,
426 FeatureFlags,
427 Option<MonitorPageGpas>,
428 ),
429 Unsupported,
431 Modified(protocol::ConnectionState),
434}
435
/// Whether an open channel has a modification in flight.
#[derive(Clone, Copy, Debug)]
pub enum ModifyState {
    NotModifying,
    /// A modification is in flight. NOTE(review): `pending_target_vp` presumably
    /// holds a retarget queued behind it — confirm.
    Modifying { pending_target_vp: Option<u32> },
}

impl ModifyState {
    /// True for any `Modifying` state, regardless of the pending target.
    pub fn is_modifying(&self) -> bool {
        match self {
            Self::Modifying { .. } => true,
            Self::NotModifying => false,
        }
    }
}
447
/// Event flag and connection ID used to signal a channel.
#[derive(Clone, Copy, Debug)]
pub struct SignalInfo {
    pub event_flag: u16,
    pub connection_id: u32,
}
453
/// Where a channel stands with respect to restoring from saved state.
///
/// NOTE(review): the transitions between these values happen outside this chunk.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RestoreState {
    New,
    Restoring,
    Unmatched,
    Restored,
}
468
469#[derive(Debug, Clone)]
471enum ChannelState {
472 ClientReleased,
476
477 Closed,
479
480 Opening {
483 request: OpenRequest,
484 reserved_state: Option<ReservedState>,
485 },
486
487 Open {
489 params: OpenRequest,
490 modify_state: ModifyState,
491 reserved_state: Option<ReservedState>,
492 },
493
494 Closing {
496 params: OpenRequest,
497 reserved_state: Option<ReservedState>,
498 },
499
500 ClosingReopen {
503 params: OpenRequest,
504 request: OpenRequest,
505 },
506
507 Revoked,
509
510 Reoffered,
513
514 ClosingClientRelease,
517
518 OpeningClientRelease,
521}
522
523impl ChannelState {
524 fn is_released(&self) -> bool {
527 match self {
528 ChannelState::Closed
529 | ChannelState::Opening { .. }
530 | ChannelState::Open { .. }
531 | ChannelState::Closing { .. }
532 | ChannelState::ClosingReopen { .. }
533 | ChannelState::Revoked
534 | ChannelState::Reoffered => false,
535
536 ChannelState::ClientReleased
537 | ChannelState::ClosingClientRelease
538 | ChannelState::OpeningClientRelease => true,
539 }
540 }
541
542 fn is_revoked(&self) -> bool {
544 match self {
545 ChannelState::Revoked | ChannelState::Reoffered => true,
546
547 ChannelState::ClientReleased
548 | ChannelState::Closed
549 | ChannelState::Opening { .. }
550 | ChannelState::Open { .. }
551 | ChannelState::Closing { .. }
552 | ChannelState::ClosingReopen { .. }
553 | ChannelState::ClosingClientRelease
554 | ChannelState::OpeningClientRelease => false,
555 }
556 }
557
558 fn is_reserved(&self) -> bool {
559 match self {
560 ChannelState::Open {
562 reserved_state: Some(_),
563 ..
564 }
565 | ChannelState::Opening {
566 reserved_state: Some(_),
567 ..
568 }
569 | ChannelState::Closing {
570 reserved_state: Some(_),
571 ..
572 } => true,
573
574 ChannelState::Opening { .. }
575 | ChannelState::Open { .. }
576 | ChannelState::Closing { .. }
577 | ChannelState::ClientReleased
578 | ChannelState::Closed
579 | ChannelState::ClosingReopen { .. }
580 | ChannelState::Revoked
581 | ChannelState::Reoffered
582 | ChannelState::ClosingClientRelease
583 | ChannelState::OpeningClientRelease => false,
584 }
585 }
586}
587
588impl Display for ChannelState {
589 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
590 let state = match self {
591 Self::ClientReleased => "ClientReleased",
592 Self::Closed => "Closed",
593 Self::Opening { .. } => "Opening",
594 Self::Open { .. } => "Open",
595 Self::Closing { .. } => "Closing",
596 Self::ClosingReopen { .. } => "ClosingReopen",
597 Self::Revoked => "Revoked",
598 Self::Reoffered => "Reoffered",
599 Self::ClosingClientRelease => "ClosingClientRelease",
600 Self::OpeningClientRelease => "OpeningClientRelease",
601 };
602 write!(f, "{}", state)
603 }
604}
605
606#[derive(Debug, Clone, Default, mesh::MeshPayload)]
608pub enum MnfUsage {
609 #[default]
611 Disabled,
612 Enabled { latency: Duration },
614 Relayed { monitor_id: u8 },
617}
618
619impl MnfUsage {
620 pub fn is_enabled(&self) -> bool {
621 matches!(self, Self::Enabled { .. })
622 }
623
624 pub fn is_relayed(&self) -> bool {
625 matches!(self, Self::Relayed { .. })
626 }
627
628 pub fn enabled_and_then<T>(&self, f: impl FnOnce(Duration) -> Option<T>) -> Option<T> {
629 if let Self::Enabled { latency } = self {
630 f(*latency)
631 } else {
632 None
633 }
634 }
635}
636
637impl From<Option<Duration>> for MnfUsage {
638 fn from(value: Option<Duration>) -> Self {
639 match value {
640 None => Self::Disabled,
641 Some(latency) => Self::Enabled { latency },
642 }
643 }
644}
645
646#[derive(Debug, Clone, Default, mesh::MeshPayload)]
647pub struct OfferParamsInternal {
648 pub interface_name: String,
650 pub instance_id: Guid,
651 pub interface_id: Guid,
652 pub mmio_megabytes: u16,
653 pub mmio_megabytes_optional: u16,
654 pub subchannel_index: u16,
655 pub use_mnf: MnfUsage,
656 pub offer_order: Option<u64>,
657 pub flags: OfferFlags,
658 pub user_defined: UserDefinedData,
659}
660
661impl OfferParamsInternal {
662 pub fn key(&self) -> OfferKey {
664 OfferKey {
665 interface_id: self.interface_id,
666 instance_id: self.instance_id,
667 subchannel_index: self.subchannel_index,
668 }
669 }
670}
671
672impl From<OfferParams> for OfferParamsInternal {
673 fn from(value: OfferParams) -> Self {
674 let mut user_defined = UserDefinedData::new_zeroed();
675
676 let mut flags = OfferFlags::new()
679 .with_confidential_ring_buffer(true)
680 .with_confidential_external_memory(value.allow_confidential_external_memory);
681
682 match value.channel_type {
683 ChannelType::Device { pipe_packets } => {
684 if pipe_packets {
685 flags.set_named_pipe_mode(true);
686 user_defined.as_pipe_params_mut().pipe_type = protocol::PipeType::MESSAGE;
687 }
688 }
689 ChannelType::Interface {
690 user_defined: interface_user_defined,
691 } => {
692 flags.set_enumerate_device_interface(true);
693 user_defined = interface_user_defined;
694 }
695 ChannelType::Pipe { message_mode } => {
696 flags.set_enumerate_device_interface(true);
697 flags.set_named_pipe_mode(true);
698 user_defined.as_pipe_params_mut().pipe_type = if message_mode {
699 protocol::PipeType::MESSAGE
700 } else {
701 protocol::PipeType::BYTE
702 };
703 }
704 ChannelType::HvSocket {
705 is_connect,
706 is_for_container,
707 silo_id,
708 } => {
709 flags.set_enumerate_device_interface(true);
710 flags.set_tlnpi_provider(true);
711 flags.set_named_pipe_mode(true);
712 *user_defined.as_hvsock_params_mut() = protocol::HvsockUserDefinedParameters::new(
713 is_connect,
714 is_for_container,
715 silo_id,
716 );
717 }
718 };
719
720 Self {
721 interface_name: value.interface_name,
722 instance_id: value.instance_id,
723 interface_id: value.interface_id,
724 mmio_megabytes: value.mmio_megabytes,
725 mmio_megabytes_optional: value.mmio_megabytes_optional,
726 subchannel_index: value.subchannel_index,
727 use_mnf: value.mnf_interrupt_latency.into(),
728 offer_order: value.offer_order,
729 user_defined,
730 flags,
731 }
732 }
733}
734
735#[derive(Debug, Copy, Clone, Inspect, PartialEq, Eq)]
736pub struct ConnectionTarget {
737 pub vp: u32,
738 pub sint: u8,
739}
740
741#[derive(Debug, Copy, Clone, PartialEq, Eq)]
742pub enum MessageTarget {
743 Default,
744 ReservedChannel(OfferId, ConnectionTarget),
745 Custom(ConnectionTarget),
746}
747
748impl MessageTarget {
749 pub fn for_offer(offer_id: OfferId, reserved_state: &Option<ReservedState>) -> Self {
750 if let Some(state) = reserved_state {
751 Self::ReservedChannel(offer_id, state.target)
752 } else {
753 Self::Default
754 }
755 }
756}
757
758#[derive(Debug, Copy, Clone)]
759pub struct ReservedState {
760 version: VersionInfo,
761 target: ConnectionTarget,
762}
763
764#[derive(Debug)]
766struct Channel {
767 info: Option<OfferedInfo>,
768 offer: OfferParamsInternal,
769 state: ChannelState,
770 restore_state: RestoreState,
771}
772
773#[derive(Debug, Copy, Clone)]
774struct OfferedInfo {
775 channel_id: ChannelId,
776 connection_id: u32,
777 monitor_id: Option<MonitorId>,
778}
779
780impl Channel {
781 fn inspect_state(&self, resp: &mut inspect::Response<'_>) {
782 let mut target_vp = None;
783 let mut event_flag = None;
784 let mut connection_id = None;
785 let mut reserved_target = None;
786 let state = match &self.state {
787 ChannelState::ClientReleased => "client_released",
788 ChannelState::Closed => "closed",
789 ChannelState::Opening { reserved_state, .. } => {
790 reserved_target = reserved_state.map(|state| state.target);
791 "opening"
792 }
793 ChannelState::Open {
794 params,
795 reserved_state,
796 ..
797 } => {
798 target_vp = Some(params.target_vp);
799 if let Some(id) = params.guest_specified_interrupt_info {
800 event_flag = Some(id.event_flag);
801 connection_id = Some(id.connection_id);
802 }
803 reserved_target = reserved_state.map(|state| state.target);
804 "open"
805 }
806 ChannelState::Closing { reserved_state, .. } => {
807 reserved_target = reserved_state.map(|state| state.target);
808 "closing"
809 }
810 ChannelState::ClosingReopen { .. } => "closing_reopen",
811 ChannelState::Revoked => "revoked",
812 ChannelState::Reoffered => "reoffered",
813 ChannelState::ClosingClientRelease => "closing_client_release",
814 ChannelState::OpeningClientRelease => "opening_client_release",
815 };
816 let restore_state = match self.restore_state {
817 RestoreState::New => "new",
818 RestoreState::Restoring => "restoring",
819 RestoreState::Restored => "restored",
820 RestoreState::Unmatched => "unmatched",
821 };
822 if let Some(info) = &self.info {
823 resp.field("channel_id", info.channel_id.0)
824 .field("offered_connection_id", info.connection_id)
825 .field("monitor_id", info.monitor_id.map(|id| id.0));
826 }
827 resp.field("state", state)
828 .field("restore_state", restore_state)
829 .field("interface_name", self.offer.interface_name.clone())
830 .display("instance_id", &self.offer.instance_id)
831 .display("interface_id", &self.offer.interface_id)
832 .field("mmio_megabytes", self.offer.mmio_megabytes)
833 .field("target_vp", target_vp)
834 .field("guest_specified_event_flag", event_flag)
835 .field("guest_specified_connection_id", connection_id)
836 .field("reserved_connection_target", reserved_target)
837 .binary("offer_flags", self.offer.flags.into_bits());
838 }
839
840 fn handled_monitor_info(&self) -> Option<MonitorInfo> {
850 self.offer.use_mnf.enabled_and_then(|latency| {
851 if self.state.is_reserved() {
852 None
853 } else {
854 self.info.and_then(|info| {
855 info.monitor_id.map(|monitor_id| MonitorInfo {
856 monitor_id,
857 latency,
858 })
859 })
860 }
861 })
862 }
863
864 fn prepare_channel(
867 &mut self,
868 offer_id: OfferId,
869 assigned_channels: &mut AssignedChannels,
870 assigned_monitors: &mut AssignedMonitors,
871 ) {
872 assert!(self.info.is_none());
873
874 let entry = assigned_channels
876 .allocate()
877 .expect("there are enough channel IDs for everything in ChannelList");
878
879 let channel_id = entry.id();
880 entry.insert(offer_id);
881 let connection_id = ConnectionId::new(channel_id.0, assigned_channels.vtl, VMBUS_SINT);
882
883 let monitor_id = match self.offer.use_mnf {
888 MnfUsage::Enabled { .. } => {
889 let monitor_id = assigned_monitors.assign_monitor();
890 if monitor_id.is_none() {
891 tracelimit::warn_ratelimited!("Out of monitor IDs.");
892 }
893
894 monitor_id
895 }
896 MnfUsage::Relayed { monitor_id } => Some(MonitorId(monitor_id)),
897 MnfUsage::Disabled => None,
898 };
899
900 self.info = Some(OfferedInfo {
901 channel_id,
902 connection_id: connection_id.0,
903 monitor_id,
904 });
905 }
906
907 fn release_channel(
909 &mut self,
910 offer_id: OfferId,
911 assigned_channels: &mut AssignedChannels,
912 assigned_monitors: &mut AssignedMonitors,
913 ) {
914 if let Some(info) = self.info.take() {
915 assigned_channels.free(info.channel_id, offer_id);
916
917 if let Some(monitor_id) = info.monitor_id {
919 if self.offer.use_mnf.is_enabled() {
920 assigned_monitors.release_monitor(monitor_id);
921 }
922 }
923 }
924 }
925}
926
927#[derive(Debug)]
928struct AssignedChannels {
929 assignments: Vec<Option<OfferId>>,
930 vtl: Vtl,
931 reserved_offset: usize,
932 count_in_reserved_range: usize,
934}
935
936impl AssignedChannels {
937 fn new(vtl: Vtl, channel_id_offset: u16) -> Self {
938 Self {
939 assignments: vec![None; MAX_CHANNELS],
940 vtl,
941 reserved_offset: channel_id_offset as usize,
942 count_in_reserved_range: 0,
943 }
944 }
945
946 fn allowable_channel_count(&self) -> usize {
947 MAX_CHANNELS - self.reserved_offset + self.count_in_reserved_range
948 }
949
950 fn get(&self, channel_id: ChannelId) -> Option<OfferId> {
951 self.assignments
952 .get(Self::index(channel_id))
953 .copied()
954 .flatten()
955 }
956
957 fn set(&mut self, channel_id: ChannelId) -> Result<AssignmentEntry<'_>, OfferError> {
958 let index = Self::index(channel_id);
959 if self
960 .assignments
961 .get(index)
962 .ok_or(OfferError::InvalidChannelId(channel_id))?
963 .is_some()
964 {
965 return Err(OfferError::ChannelIdInUse(channel_id));
966 }
967 Ok(AssignmentEntry { list: self, index })
968 }
969
970 fn allocate(&mut self) -> Option<AssignmentEntry<'_>> {
971 let index = self.reserved_offset
972 + self.assignments[self.reserved_offset..]
973 .iter()
974 .position(|x| x.is_none())?;
975 Some(AssignmentEntry { list: self, index })
976 }
977
978 fn free(&mut self, channel_id: ChannelId, offer_id: OfferId) {
979 let index = Self::index(channel_id);
980 let slot = &mut self.assignments[index];
981 assert_eq!(slot.take(), Some(offer_id));
982 if index < self.reserved_offset {
983 self.count_in_reserved_range -= 1;
984 }
985 }
986
987 fn index(channel_id: ChannelId) -> usize {
988 channel_id.0.wrapping_sub(1) as usize
989 }
990}
991
992struct AssignmentEntry<'a> {
993 list: &'a mut AssignedChannels,
994 index: usize,
995}
996
997impl AssignmentEntry<'_> {
998 pub fn id(&self) -> ChannelId {
999 ChannelId(self.index as u32 + 1)
1000 }
1001
1002 pub fn insert(self, offer_id: OfferId) {
1003 assert!(
1004 self.list.assignments[self.index]
1005 .replace(offer_id)
1006 .is_none()
1007 );
1008
1009 if self.index < self.list.reserved_offset {
1010 self.list.count_in_reserved_range += 1;
1011 }
1012 }
1013}
1014
1015struct ChannelList {
1016 channels: Slab<Channel>,
1017}
1018
1019fn channel_inspect_path(offer: &OfferParamsInternal, suffix: std::fmt::Arguments<'_>) -> String {
1020 if offer.subchannel_index == 0 {
1021 format!("{}{}", offer.instance_id, suffix)
1022 } else {
1023 format!(
1024 "{}/subchannels/{}{}",
1025 offer.instance_id, offer.subchannel_index, suffix
1026 )
1027 }
1028}
1029
1030impl ChannelList {
1031 fn inspect(
1032 &self,
1033 notifier: &impl Notifier,
1034 version: Option<VersionInfo>,
1035 resp: &mut inspect::Response<'_>,
1036 ) {
1037 for (offer_id, channel) in self.iter() {
1038 resp.child(
1039 &channel_inspect_path(&channel.offer, format_args!("")),
1040 |req| {
1041 let mut resp = req.respond();
1042 channel.inspect_state(&mut resp);
1043
1044 resp.merge(inspect::adhoc(|req| {
1048 if !matches!(channel.state, ChannelState::Revoked) {
1049 notifier.inspect(version, offer_id, req);
1050 }
1051 }));
1052 },
1053 );
1054 }
1055 }
1056}
1057
/// Number of channel ID slots the server tracks; channel IDs range over `1..=MAX_CHANNELS`.
pub const MAX_CHANNELS: usize = 2_047;
1061
1062impl ChannelList {
1063 fn new() -> Self {
1064 Self {
1065 channels: Slab::new(),
1066 }
1067 }
1068
1069 fn len(&self) -> usize {
1071 self.channels.len()
1072 }
1073
1074 fn offer(&mut self, new_channel: Channel) -> OfferId {
1076 OfferId(self.channels.insert(new_channel))
1077 }
1078
1079 fn remove(&mut self, offer_id: OfferId) {
1081 let channel = self.channels.remove(offer_id.0);
1082 assert!(channel.info.is_none());
1083 }
1084
1085 fn get_by_channel_id_mut(
1087 &mut self,
1088 assigned_channels: &AssignedChannels,
1089 channel_id: ChannelId,
1090 ) -> Result<(OfferId, &mut Channel), ChannelError> {
1091 let offer_id = assigned_channels
1092 .get(channel_id)
1093 .ok_or(ChannelError::UnknownChannelId)?;
1094 let channel = &mut self[offer_id];
1095 if channel.state.is_released() {
1096 return Err(ChannelError::ChannelReleased);
1097 }
1098 assert_eq!(
1099 channel.info.as_ref().map(|info| info.channel_id),
1100 Some(channel_id)
1101 );
1102 Ok((offer_id, channel))
1103 }
1104
1105 fn get_by_channel_id(
1107 &self,
1108 assigned_channels: &AssignedChannels,
1109 channel_id: ChannelId,
1110 ) -> Result<(OfferId, &Channel), ChannelError> {
1111 let offer_id = assigned_channels
1112 .get(channel_id)
1113 .ok_or(ChannelError::UnknownChannelId)?;
1114 let channel = &self[offer_id];
1115 if channel.state.is_released() {
1116 return Err(ChannelError::ChannelReleased);
1117 }
1118 assert_eq!(
1119 channel.info.as_ref().map(|info| info.channel_id),
1120 Some(channel_id)
1121 );
1122 Ok((offer_id, channel))
1123 }
1124
1125 fn get_by_key_mut(&mut self, key: &OfferKey) -> Option<(OfferId, &mut Channel)> {
1128 for (offer_id, channel) in self.iter_mut() {
1129 if channel.offer.instance_id == key.instance_id
1130 && channel.offer.interface_id == key.interface_id
1131 && channel.offer.subchannel_index == key.subchannel_index
1132 {
1133 return Some((offer_id, channel));
1134 }
1135 }
1136 None
1137 }
1138
1139 fn iter(&self) -> impl Iterator<Item = (OfferId, &Channel)> {
1141 self.channels
1142 .iter()
1143 .map(|(id, channel)| (OfferId(id), channel))
1144 }
1145
1146 fn iter_mut(&mut self) -> impl Iterator<Item = (OfferId, &mut Channel)> {
1148 self.channels
1149 .iter_mut()
1150 .map(|(id, channel)| (OfferId(id), channel))
1151 }
1152
1153 fn retain<F>(&mut self, mut f: F)
1155 where
1156 F: FnMut(OfferId, &mut Channel) -> bool,
1157 {
1158 self.channels.retain(|id, channel| {
1159 let retain = f(OfferId(id), channel);
1160 if !retain {
1161 assert!(channel.info.is_none());
1162 }
1163 retain
1164 })
1165 }
1166}
1167
1168impl Index<OfferId> for ChannelList {
1169 type Output = Channel;
1170
1171 fn index(&self, offer_id: OfferId) -> &Self::Output {
1172 &self.channels[offer_id.0]
1173 }
1174}
1175
1176impl IndexMut<OfferId> for ChannelList {
1177 fn index_mut(&mut self, offer_id: OfferId) -> &mut Self::Output {
1178 &mut self.channels[offer_id.0]
1179 }
1180}
1181
1182#[derive(Debug, Inspect)]
1184struct Gpadl {
1185 count: u16,
1186 #[inspect(skip)]
1187 buf: Vec<u64>,
1188 state: GpadlState,
1189}
1190
1191#[derive(Debug, Copy, Clone, PartialEq, Eq, Inspect)]
1192enum GpadlState {
1193 InProgress,
1195 Offered,
1197 OfferedTearingDown,
1200 Accepted,
1202 TearingDown,
1204}
1205
1206impl Gpadl {
1207 fn new(count: u16, len: usize) -> Self {
1210 Self {
1211 state: GpadlState::InProgress,
1212 count,
1213 buf: Vec::with_capacity(len),
1214 }
1215 }
1216
1217 fn append(&mut self, data: &[u8]) -> Result<bool, ChannelError> {
1219 if self.state == GpadlState::InProgress {
1220 let buf = &mut self.buf;
1221 let len = min(data.len() & !7, (buf.capacity() - buf.len()) * 8);
1226 let data = &data[..len];
1227 let start = buf.len();
1228 buf.resize(buf.len() + data.len() / 8, 0);
1229 buf[start..].as_mut_bytes().copy_from_slice(data);
1230 Ok(if buf.len() == buf.capacity() {
1231 gparange::validate_gpa_ranges(self.count as usize, buf)
1232 .map_err(ChannelError::InvalidGpaRange)?;
1233 self.state = GpadlState::Offered;
1234 true
1235 } else {
1236 false
1237 })
1238 } else {
1239 Err(ChannelError::GpadlAlreadyComplete)
1240 }
1241 }
1242}
1243
1244#[derive(Debug, Copy, Clone)]
1246pub struct OpenParams {
1247 pub open_data: OpenData,
1248 pub connection_id: u32,
1249 pub event_flag: u16,
1250 pub monitor_info: Option<MonitorInfo>,
1251 pub flags: protocol::OpenChannelFlags,
1252 pub reserved_target: Option<ConnectionTarget>,
1253 pub channel_id: ChannelId,
1254}
1255
1256impl OpenParams {
1257 fn from_request(
1258 info: &OfferedInfo,
1259 request: &OpenRequest,
1260 monitor_info: Option<MonitorInfo>,
1261 reserved_target: Option<ConnectionTarget>,
1262 ) -> Self {
1263 let (event_flag, connection_id) = if let Some(id) = request.guest_specified_interrupt_info {
1266 (id.event_flag, id.connection_id)
1267 } else {
1268 (info.channel_id.0 as u16, info.connection_id)
1269 };
1270
1271 Self {
1272 open_data: OpenData {
1273 target_vp: request.target_vp,
1274 ring_offset: request.downstream_ring_buffer_page_offset,
1275 ring_gpadl_id: request.ring_buffer_gpadl_id,
1276 user_data: request.user_data,
1277 event_flag,
1278 connection_id,
1279 },
1280 connection_id,
1281 event_flag,
1282 monitor_info: request.target_vp.and(monitor_info),
1284 flags: request.flags.with_unused(0),
1285 reserved_target,
1286 channel_id: info.channel_id,
1287 }
1288 }
1289}
1290
1291#[derive(Debug)]
1293pub enum Action {
1294 Open(OpenParams, VersionInfo),
1295 Close,
1296 Gpadl(GpadlId, u16, Vec<u64>),
1297 TeardownGpadl {
1298 gpadl_id: GpadlId,
1299 post_restore: bool,
1300 },
1301 Modify {
1302 target_vp: u32,
1303 },
1304}
1305
1306static SUPPORTED_VERSIONS: &[Version] = &[
1308 Version::V1,
1309 Version::Win7,
1310 Version::Win8,
1311 Version::Win8_1,
1312 Version::Win10,
1313 Version::Win10Rs3_0,
1314 Version::Win10Rs3_1,
1315 Version::Win10Rs4,
1316 Version::Win10Rs5,
1317 Version::Iron,
1318 Version::Copper,
1319];
1320
1321const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
1324 .with_guest_specified_signal_parameters(true)
1325 .with_channel_interrupt_redirection(true)
1326 .with_modify_connection(true)
1327 .with_client_id(true)
1328 .with_pause_resume(true)
1329 .with_server_specified_monitor_pages(true);
1330
/// Callbacks used by [`ServerWithNotifier`] to carry out channel actions,
/// change connection parameters, and send messages.
pub trait Notifier: Send {
    /// Performs `action` for the channel identified by `offer_id`.
    fn notify(&mut self, offer_id: OfferId, action: Action);

    /// Receives an InitiateContact request this server does not handle itself,
    /// such as one targeting a different VTL.
    fn forward_unhandled(&mut self, request: InitiateContactRequest);

    /// Applies a change to the connection parameters (version, monitor and
    /// interrupt pages, target message VP).
    ///
    /// Returns an error if the change could not be made.
    fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()>;

    /// Adds implementation-specific inspection data for channel `offer_id`.
    ///
    /// The default implementation adds nothing.
    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
        let _ = (version, offer_id, req);
    }

    /// Sends `message` to `target`.
    ///
    /// NOTE(review): the `bool` presumably reports whether the message was
    /// actually sent (so the caller can queue it otherwise) — confirm against
    /// implementations.
    #[must_use]
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool;

    /// Notifies the implementation of an hvsocket connection request.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest);

    /// Called once a server reset has finished.
    fn reset_complete(&mut self);

    /// Called once an unload has finished.
    fn unload_complete(&mut self);
}
1365
1366impl Server {
1367 pub fn new(
1369 vtl: Vtl,
1370 child_connection_id: u32,
1371 channel_id_offset: u16,
1372 use_absolute_channel_order: bool,
1373 ) -> Self {
1374 Server {
1375 state: ConnectionState::Disconnected,
1376 channels: ChannelList::new(),
1377 assigned_channels: AssignedChannels::new(vtl, channel_id_offset),
1378 assigned_monitors: AssignedMonitors::new(),
1379 gpadls: Default::default(),
1380 incomplete_gpadls: Default::default(),
1381 child_connection_id,
1382 max_version: None,
1383 delayed_max_version: None,
1384 max_restore_version: None,
1385 pending_messages: PendingMessages(VecDeque::new()),
1386 require_server_allocated_mnf: false,
1387 use_absolute_channel_order,
1388 }
1389 }
1390
1391 pub fn with_notifier<'a, T: Notifier>(
1393 &'a mut self,
1394 notifier: &'a mut T,
1395 ) -> ServerWithNotifier<'a, T> {
1396 self.validate();
1397 ServerWithNotifier {
1398 inner: self,
1399 notifier,
1400 }
1401 }
1402
1403 pub fn set_require_server_allocated_mnf(&mut self, require: bool) {
1406 self.require_server_allocated_mnf = require;
1407 }
1408
1409 fn validate(&self) {
1410 #[cfg(debug_assertions)]
1411 for (_, channel) in self.channels.iter() {
1412 let should_have_info = !channel.state.is_released();
1413 if channel.info.is_some() != should_have_info {
1414 panic!("channel invariant violation: {channel:?}");
1415 }
1416 }
1417 }
1418
1419 pub fn set_compatibility_version(&mut self, version: MaxVersionInfo, delay: bool) {
1424 if delay {
1425 self.delayed_max_version = Some(version)
1426 } else {
1427 tracing::info!(?version, "Limiting VmBus connections to version");
1428 self.max_version = Some(version);
1429 }
1430 }
1431
1432 pub fn set_restore_compatibility_version(&mut self, version: MaxVersionInfo) {
1440 tracing::info!(?version, "Limiting VmBus restore to version");
1441 self.max_restore_version = Some(version);
1442 }
1443
1444 pub fn channel_gpadls(&self, offer_id: OfferId) -> Vec<RestoredGpadl> {
1445 self.gpadls
1446 .iter()
1447 .filter_map(|(&(gpadl_id, gpadl_offer_id), gpadl)| {
1448 if offer_id != gpadl_offer_id {
1449 return None;
1450 }
1451 let accepted = match gpadl.state {
1452 GpadlState::Offered | GpadlState::OfferedTearingDown => false,
1453 GpadlState::Accepted => true,
1454 GpadlState::InProgress | GpadlState::TearingDown => return None,
1455 };
1456 Some(RestoredGpadl {
1457 request: GpadlRequest {
1458 id: gpadl_id,
1459 count: gpadl.count,
1460 buf: gpadl.buf.clone(),
1461 },
1462 accepted,
1463 })
1464 })
1465 .collect()
1466 }
1467
1468 pub fn get_version(&self) -> Option<VersionInfo> {
1469 self.state.get_version()
1470 }
1471
1472 pub fn get_restore_open_params(&self, offer_id: OfferId) -> Result<OpenParams, RestoreError> {
1473 let channel = &self.channels[offer_id];
1474
1475 match channel.restore_state {
1477 RestoreState::New => {
1478 return Err(RestoreError::MissingChannel(channel.offer.key()));
1482 }
1483 RestoreState::Restoring => {}
1484 RestoreState::Unmatched => unreachable!(),
1485 RestoreState::Restored => {
1486 return Err(RestoreError::AlreadyRestored(channel.offer.key()));
1487 }
1488 }
1489
1490 let info = channel
1491 .info
1492 .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;
1493
1494 let (request, reserved_state) = match channel.state {
1495 ChannelState::Closed => {
1496 return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
1497 }
1498 ChannelState::Closing { params, .. } | ChannelState::ClosingReopen { params, .. } => {
1499 (params, None)
1500 }
1501 ChannelState::Opening {
1502 request,
1503 reserved_state,
1504 } => (request, reserved_state),
1505 ChannelState::Open {
1506 params,
1507 reserved_state,
1508 ..
1509 } => (params, reserved_state),
1510 ChannelState::ClientReleased | ChannelState::Reoffered => {
1511 return Err(RestoreError::MissingChannel(channel.offer.key()));
1512 }
1513 ChannelState::Revoked
1514 | ChannelState::ClosingClientRelease
1515 | ChannelState::OpeningClientRelease => unreachable!(),
1516 };
1517
1518 Ok(OpenParams::from_request(
1519 &info,
1520 &request,
1521 channel.handled_monitor_info(),
1522 reserved_state.map(|state| state.target),
1523 ))
1524 }
1525
1526 pub fn has_pending_messages(&self) -> bool {
1528 !self.pending_messages.0.is_empty() && !self.state.is_paused()
1529 }
1530
1531 pub fn poll_flush_pending_messages(
1533 &mut self,
1534 mut send: impl FnMut(&OutgoingMessage) -> Poll<()>,
1535 ) -> Poll<()> {
1536 if !self.state.is_paused() {
1537 while let Some(message) = self.pending_messages.0.front() {
1538 ready!(send(message));
1539 self.pending_messages.0.pop_front();
1540 }
1541 }
1542
1543 Poll::Ready(())
1544 }
1545}
1546
1547impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
    /// Reconciles the saved state of channel `offer_id` with the device after a
    /// restore, then marks the channel `Restored`.
    ///
    /// `open` is the caller's view of whether the channel is open (presumably
    /// taken from the device's own restored state — confirm with callers).
    /// Opens and closes that were in flight in the saved state are completed or
    /// re-issued so both sides agree.
    ///
    /// # Errors
    ///
    /// - `MissingChannel`: `open` is set but the channel was not in the saved
    ///   state; the channel has no assigned info; or it was released (or, when
    ///   `open` is set, reoffered).
    /// - `AlreadyRestored`: the channel was already restored.
    /// - `DuplicateMonitorId`: the channel's monitor ID is already claimed.
    /// - `MismatchedOpenState`: the saved open state contradicts `open`.
    pub fn restore_channel(&mut self, offer_id: OfferId, open: bool) -> Result<(), RestoreError> {
        let channel = &mut self.inner.channels[offer_id];

        match channel.restore_state {
            // Not present in the saved state: only consistent if the device
            // also considers the channel closed.
            RestoreState::New => {
                if open {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                } else {
                    return Ok(());
                }
            }
            RestoreState::Restoring => {}
            // NOTE(review): presumably callers only restore channels matched by
            // `offer_channel`, so this cannot be reached — confirm.
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        // Re-reserve the monitor ID recorded for this channel.
        // NOTE(review): the claim is not released if a state check below
        // returns an error.
        if let Some(monitor_info) = channel.handled_monitor_info() {
            if !self
                .inner
                .assigned_monitors
                .claim_monitor(monitor_info.monitor_id)
            {
                return Err(RestoreError::DuplicateMonitorId(monitor_info.monitor_id.0));
            }
        }

        if open {
            match channel.state {
                ChannelState::Closed => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                // The device still has it open, but the guest had asked for a
                // close: ask the device to close it now.
                ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                    self.notifier.notify(offer_id, Action::Close);
                }
                // The open was in flight when state was saved and the device
                // has it open: report success to the guest and mark it open.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused())
                        .send_open_result(
                            info.channel_id,
                            &request,
                            protocol::STATUS_SUCCESS,
                            MessageTarget::for_offer(offer_id, &reserved_state),
                        );
                    channel.state = ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    };
                }
                ChannelState::Open { .. } => {}
                ChannelState::ClientReleased | ChannelState::Reoffered => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            };
        } else {
            match channel.state {
                ChannelState::Closed => {}
                ChannelState::Reoffered => {}
                // The device already finished the close.
                ChannelState::Closing { .. } => {
                    channel.state = ChannelState::Closed;
                }
                // The close finished; issue the reopen that was waiting on it.
                ChannelState::ClosingReopen { request, .. } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                None,
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                    channel.state = ChannelState::Opening {
                        request,
                        reserved_state: None,
                    };
                }
                // The device never saw the open complete: issue it again.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                reserved_state.map(|state| state.target),
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                }
                ChannelState::Open { .. } => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                ChannelState::ClientReleased => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            }
        }

        channel.restore_state = RestoreState::Restored;
        Ok(())
    }
1686
1687 pub fn revoke_unclaimed_channels(&mut self) {
1690 for (offer_id, channel) in self.inner.channels.iter_mut() {
1691 match channel.restore_state {
1692 RestoreState::Restored => {
1693 }
1695 RestoreState::New => {
1696 if let ConnectionState::Connected(info) = &self.inner.state {
1701 if info.offers_sent && matches!(channel.state, ChannelState::ClientReleased)
1702 {
1703 channel.prepare_channel(
1704 offer_id,
1705 &mut self.inner.assigned_channels,
1706 &mut self.inner.assigned_monitors,
1707 );
1708 channel.state = ChannelState::Closed;
1709 self.inner
1710 .pending_messages
1711 .sender(self.notifier, self.inner.state.is_paused())
1712 .send_offer(channel, info);
1713 }
1714 }
1715 }
1716 RestoreState::Restoring => {
1717 if matches!(channel.state, ChannelState::Closed) {
1725 channel.restore_state = RestoreState::Restored;
1726 } else {
1727 let retain = revoke(
1728 self.inner
1729 .pending_messages
1730 .sender(self.notifier, self.inner.state.is_paused()),
1731 offer_id,
1732 channel,
1733 &mut self.inner.gpadls,
1734 );
1735 assert!(retain, "channel has not been released");
1736 channel.state = ChannelState::Reoffered;
1737 }
1738 }
1739 RestoreState::Unmatched => {
1740 let retain = revoke(
1743 self.inner
1744 .pending_messages
1745 .sender(self.notifier, self.inner.state.is_paused()),
1746 offer_id,
1747 channel,
1748 &mut self.inner.gpadls,
1749 );
1750 assert!(retain, "channel has not been released");
1751 }
1752 }
1753 }
1754
1755 for (&(gpadl_id, offer_id), gpadl) in self.inner.gpadls.iter_mut() {
1757 match gpadl.state {
1758 GpadlState::InProgress | GpadlState::Accepted => {}
1759 GpadlState::Offered => {
1760 self.notifier.notify(
1761 offer_id,
1762 Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
1763 );
1764 }
1765 GpadlState::TearingDown => {
1766 self.notifier.notify(
1767 offer_id,
1768 Action::TeardownGpadl {
1769 gpadl_id,
1770 post_restore: true,
1771 },
1772 );
1773 }
1774 GpadlState::OfferedTearingDown => unreachable!(),
1775 }
1776 }
1777
1778 self.check_disconnected();
1779 }
1780
1781 pub fn reset(&mut self) {
1786 assert!(!self.is_resetting());
1787 if self.request_disconnect(ConnectionAction::Reset) {
1788 self.complete_reset();
1789 }
1790 }
1791
1792 fn complete_reset(&mut self) {
1793 for (_, channel) in self.inner.channels.iter_mut() {
1795 channel.restore_state = RestoreState::New;
1796 }
1797 self.inner.pending_messages.0.clear();
1798 self.notifier.reset_complete();
1799 }
1800
1801 pub fn offer_channel(&mut self, offer: OfferParamsInternal) -> Result<OfferId, OfferError> {
1803 if let Some((offer_id, channel)) = self.inner.channels.get_by_key_mut(&offer.key()) {
1805 if channel.restore_state != RestoreState::Unmatched
1809 && !matches!(channel.state, ChannelState::Revoked)
1810 {
1811 return Err(OfferError::AlreadyExists(offer.key()));
1812 }
1813
1814 let info = channel.info.expect("assigned");
1815 if channel.restore_state == RestoreState::Unmatched {
1816 tracing::debug!(
1817 offer_id = offer_id.0,
1818 key = %channel.offer.key(),
1819 "matched channel"
1820 );
1821
1822 assert!(!matches!(channel.state, ChannelState::Revoked));
1823 channel.restore_state = RestoreState::Restoring;
1827
1828 if let MnfUsage::Relayed { monitor_id } = offer.use_mnf {
1831 if info.monitor_id != Some(MonitorId(monitor_id)) {
1832 return Err(OfferError::MismatchedMonitorId(
1833 info.monitor_id,
1834 MonitorId(monitor_id),
1835 ));
1836 }
1837 }
1838 } else {
1839 channel.state = ChannelState::Reoffered;
1843 tracing::info!(?offer_id, key = %channel.offer.key(), "channel marked for reoffer");
1844 }
1845
1846 channel.offer = offer;
1847 return Ok(offer_id);
1848 }
1849
1850 let mut connected_info = None;
1851 let state = match &self.inner.state {
1852 ConnectionState::Connected(info) => {
1853 if info.offers_sent {
1854 connected_info = Some(info);
1855 ChannelState::Closed
1856 } else {
1857 ChannelState::ClientReleased
1858 }
1859 }
1860 ConnectionState::Connecting { .. }
1861 | ConnectionState::Disconnecting { .. }
1862 | ConnectionState::Disconnected => ChannelState::ClientReleased,
1863 };
1864
1865 if self.inner.channels.len() >= self.inner.assigned_channels.allowable_channel_count() {
1867 return Err(OfferError::TooManyChannels);
1868 }
1869
1870 let key = offer.key();
1871 let confidential_ring_buffer = offer.flags.confidential_ring_buffer();
1872 let confidential_external_memory = offer.flags.confidential_external_memory();
1873 let channel = Channel {
1874 info: None,
1875 offer,
1876 state,
1877 restore_state: RestoreState::New,
1878 };
1879
1880 let offer_id = self.inner.channels.offer(channel);
1881 if let Some(info) = connected_info {
1882 let channel = &mut self.inner.channels[offer_id];
1883 channel.prepare_channel(
1884 offer_id,
1885 &mut self.inner.assigned_channels,
1886 &mut self.inner.assigned_monitors,
1887 );
1888
1889 self.inner
1890 .pending_messages
1891 .sender(self.notifier, self.inner.state.is_paused())
1892 .send_offer(channel, info);
1893 }
1894
1895 tracing::info!(?offer_id, %key, confidential_ring_buffer, confidential_external_memory, "new channel");
1896 Ok(offer_id)
1897 }
1898
1899 pub fn revoke_channel(&mut self, offer_id: OfferId) {
1901 let channel = &mut self.inner.channels[offer_id];
1902 let retain = revoke(
1903 self.inner
1904 .pending_messages
1905 .sender(self.notifier, self.inner.state.is_paused()),
1906 offer_id,
1907 channel,
1908 &mut self.inner.gpadls,
1909 );
1910 if !retain {
1911 self.inner.channels.remove(offer_id);
1912 }
1913
1914 self.check_disconnected();
1915 }
1916
    /// Handles the device finishing an open request for `offer_id`.
    ///
    /// `result` is a status code, and non-negative values mean success.
    ///
    /// For a normal open, the result is sent to the guest and the channel moves
    /// to `Open` on success or `Closed` on failure. If the client released the
    /// channel while the open was in flight, a successful open is closed again
    /// right away, and a failed one leaves the channel released.
    ///
    /// Completions arriving in any other state are logged and ignored.
    pub fn open_complete(&mut self, offer_id: OfferId, result: i32) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::debug!(offer_id = offer_id.0, key = %channel.offer.key(), result, "open complete");

        match channel.state {
            ChannelState::Opening {
                request,
                reserved_state,
            } => {
                let channel_id = channel.info.expect("assigned").channel_id;
                if result >= 0 {
                    tracelimit::info_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "opened channel"
                    );
                } else {
                    tracelimit::error_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "failed to open channel"
                    );
                }

                // Report the outcome to the guest (or reserved target).
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_open_result(
                        channel_id,
                        &request,
                        result,
                        MessageTarget::for_offer(offer_id, &reserved_state),
                    );
                channel.state = if result >= 0 {
                    ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    }
                } else {
                    ChannelState::Closed
                };
            }
            // The client released the channel while the open was pending; no
            // open result is reported here.
            ChannelState::OpeningClientRelease => {
                tracing::info!(
                    offer_id = offer_id.0,
                    key = %channel.offer.key(),
                    result,
                    "opened channel (client released)"
                );

                if result >= 0 {
                    // Undo the successful open.
                    channel.state = ChannelState::ClosingClientRelease;
                    self.notifier.notify(offer_id, Action::Close);
                } else {
                    channel.state = ChannelState::ClientReleased;
                    self.check_disconnected();
                }
            }

            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::ClosingClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid open complete")
            }
        }
    }
1995
1996 fn are_channels_reset(&self, include_reserved: bool) -> bool {
1999 self.inner.gpadls.keys().all(|(_, offer_id)| {
2000 !include_reserved && self.inner.channels[*offer_id].state.is_reserved()
2001 }) && self.inner.channels.iter().all(|(_, channel)| {
2002 matches!(channel.state, ChannelState::ClientReleased)
2003 || (!include_reserved && channel.state.is_reserved())
2004 })
2005 }
2006
2007 fn check_disconnected(&mut self) {
2011 match self.inner.state {
2012 ConnectionState::Disconnecting {
2013 next_action,
2014 modify_sent: false,
2015 } => {
2016 if self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)) {
2017 self.notify_disconnect(next_action);
2018 }
2019 }
2020 ConnectionState::Disconnecting {
2021 modify_sent: true, ..
2022 }
2023 | ConnectionState::Disconnected
2024 | ConnectionState::Connected { .. }
2025 | ConnectionState::Connecting { .. } => (),
2026 }
2027 }
2028
2029 fn notify_disconnect(&mut self, next_action: ConnectionAction) {
2031 debug_assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
2033 self.inner.state = ConnectionState::Disconnecting {
2034 next_action,
2035 modify_sent: true,
2036 };
2037
2038 self.notifier
2040 .modify_connection(ModifyConnectionRequest {
2041 monitor_page: Update::Reset,
2042 interrupt_page: Update::Reset,
2043 ..Default::default()
2044 })
2045 .expect("resetting state should not fail");
2046 }
2047
2048 fn is_resetting(&self) -> bool {
2051 matches!(
2052 &self.inner.state,
2053 ConnectionState::Connecting {
2054 next_action: ConnectionAction::Reset,
2055 ..
2056 } | ConnectionState::Disconnecting {
2057 next_action: ConnectionAction::Reset,
2058 ..
2059 }
2060 )
2061 }
2062
    /// Handles the device finishing a close request for `offer_id`.
    ///
    /// - For a reserved channel, a `CloseReservedChannelResponse` is sent to
    ///   the reserved target. If the connection is no longer `Connected`, the
    ///   channel is then released and removed if `client_release_channel`
    ///   returns `true`.
    /// - If the client released the channel meanwhile, it becomes
    ///   `ClientReleased` and a pending disconnect is re-checked.
    /// - If a reopen was queued behind the close, the open is started.
    ///
    /// Completions arriving in any other state are logged and ignored.
    pub fn close_complete(&mut self, offer_id: OfferId) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::info!(offer_id = offer_id.0, key = %channel.offer.key(), "closed channel");
        match channel.state {
            ChannelState::Closing {
                reserved_state: Some(reserved_state),
                ..
            } => {
                channel.state = ChannelState::Closed;
                let channel_id = channel.info.expect("assigned").channel_id;
                self.send_close_reserved_channel_response(
                    channel_id,
                    offer_id,
                    reserved_state.target,
                );

                // Reserved channels can outlive the connection. Once closed
                // with no connection, release them now.
                if !matches!(self.inner.state, ConnectionState::Connected { .. }) {
                    let channel = &mut self.inner.channels[offer_id];
                    if Self::client_release_channel(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                        &mut self.inner.incomplete_gpadls,
                        &mut self.inner.assigned_channels,
                        &mut self.inner.assigned_monitors,
                        None,
                        false,
                    ) {
                        self.inner.channels.remove(offer_id);
                    }
                }
            }
            ChannelState::Closing { .. } => {
                channel.state = ChannelState::Closed;
            }
            ChannelState::ClosingClientRelease => {
                channel.state = ChannelState::ClientReleased;
                self.check_disconnected();
            }
            // The guest asked to reopen while the close was in progress.
            ChannelState::ClosingReopen { request, .. } => {
                channel.state = ChannelState::Closed;
                self.open_channel(offer_id, &request, None);
            }

            ChannelState::Closed
            | ChannelState::ClientReleased
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::OpeningClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid close complete")
            }
        }
    }
2128
2129 fn send_close_reserved_channel_response(
2130 &mut self,
2131 channel_id: ChannelId,
2132 offer_id: OfferId,
2133 target: ConnectionTarget,
2134 ) {
2135 self.sender().send_message_with_target(
2136 &protocol::CloseReservedChannelResponse { channel_id },
2137 MessageTarget::ReservedChannel(offer_id, target),
2138 );
2139 }
2140
2141 fn handle_initiate_contact(
2144 &mut self,
2145 input: &protocol::InitiateContact2,
2146 message: &SynicMessage,
2147 includes_client_id: bool,
2148 ) -> Result<(), ChannelError> {
2149 let target_info =
2150 protocol::TargetInfo::from(input.initiate_contact.interrupt_page_or_target_info);
2151
2152 let target_sint = if message.multiclient
2153 && input.initiate_contact.version_requested >= Version::Win10Rs3_1 as u32
2154 {
2155 target_info.sint()
2156 } else {
2157 VMBUS_SINT
2158 };
2159
2160 let target_vtl = if message.multiclient
2161 && input.initiate_contact.version_requested >= Version::Win10Rs4 as u32
2162 {
2163 target_info.vtl()
2164 } else {
2165 0
2166 };
2167
2168 let feature_flags = if input.initiate_contact.version_requested >= Version::Copper as u32 {
2169 target_info.feature_flags()
2170 } else {
2171 0
2172 };
2173
2174 let target_message_vp =
2179 if input.initiate_contact.version_requested >= Version::Win8_1 as u32 {
2180 input.initiate_contact.target_message_vp
2181 } else {
2182 0
2183 };
2184
2185 let interrupt_page = (input.initiate_contact.version_requested < Version::Win8 as u32
2192 && input.initiate_contact.interrupt_page_or_target_info != 0)
2193 .then_some(input.initiate_contact.interrupt_page_or_target_info);
2194
2195 let monitor_page = if (input.initiate_contact.parent_to_child_monitor_page_gpa == 0)
2198 != (input.initiate_contact.child_to_parent_monitor_page_gpa == 0)
2199 {
2200 MonitorPageRequest::Invalid
2201 } else if input.initiate_contact.parent_to_child_monitor_page_gpa != 0 {
2202 MonitorPageRequest::Some(MonitorPageGpas {
2203 parent_to_child: input.initiate_contact.parent_to_child_monitor_page_gpa,
2204 child_to_parent: input.initiate_contact.child_to_parent_monitor_page_gpa,
2205 })
2206 } else {
2207 MonitorPageRequest::None
2208 };
2209
2210 let client_id = if FeatureFlags::from(feature_flags).client_id() {
2213 if includes_client_id {
2214 input.client_id
2215 } else {
2216 return Err(ChannelError::ParseError(
2217 protocol::ParseError::MessageTooSmall(Some(
2218 protocol::MessageType::INITIATE_CONTACT,
2219 )),
2220 ));
2221 }
2222 } else {
2223 Guid::ZERO
2224 };
2225
2226 let request = InitiateContactRequest {
2227 version_requested: input.initiate_contact.version_requested,
2228 target_message_vp,
2229 monitor_page,
2230 target_sint,
2231 target_vtl,
2232 feature_flags,
2233 interrupt_page,
2234 client_id,
2235 trusted: message.trusted,
2236 };
2237 self.initiate_contact(request);
2238 Ok(())
2239 }
2240
    /// Handles an InitiateContact request from the guest.
    ///
    /// Steps, in order:
    /// 1. Requests for another VTL are forwarded to the notifier.
    /// 2. Requests for a non-default SINT are rejected.
    /// 3. Any existing connection is torn down. If that cannot finish now, the
    ///    request is stored as the next action and handled later.
    /// 4. The version is negotiated.
    /// 5. The state becomes `Connecting` and the notifier is asked to apply the
    ///    new connection parameters.
    ///
    /// Failures are answered here. The success response is sent by
    /// `complete_initiate_contact`.
    pub fn initiate_contact(&mut self, request: InitiateContactRequest) {
        let vtl = self.inner.assigned_channels.vtl as u8;
        if request.target_vtl != vtl {
            self.notifier.forward_unhandled(request);
            return;
        }

        if request.target_sint != VMBUS_SINT {
            tracelimit::warn_ratelimited!(
                target_vtl = request.target_vtl,
                target_sint = request.target_sint,
                version = request.version_requested,
                "unsupported multiclient request",
            );

            // Reply as "unsupported" directly to the requested VP/SINT.
            self.send_version_response_with_target(
                None,
                MessageTarget::Custom(ConnectionTarget {
                    vp: request.target_message_vp,
                    sint: request.target_sint,
                }),
            );

            return;
        }

        // If the disconnect cannot complete synchronously, the request is
        // queued as the next action.
        if !self.request_disconnect(ConnectionAction::Reconnect {
            initiate_contact: request,
        }) {
            return;
        }

        let Some(version) = self.check_version_supported(&request) else {
            tracelimit::warn_ratelimited!(
                vtl,
                version = request.version_requested,
                client_id = ?request.client_id,
                "Guest requested unsupported version"
            );

            self.send_version_response(None);
            return;
        };

        tracelimit::info_ratelimited!(
            vtl,
            ?version,
            client_id = ?request.client_id,
            trusted = request.trusted,
            "Guest negotiated version"
        );

        let monitor_page = match request.monitor_page {
            MonitorPageRequest::Some(mp) => {
                // Guest-supplied pages are ignored when the server must
                // allocate them.
                if self.inner.require_server_allocated_mnf {
                    if !version.feature_flags.server_specified_monitor_pages() {
                        tracelimit::warn_ratelimited!(
                            "guest-supplied monitor pages not supported; MNF will be disabled"
                        );
                    }

                    None
                } else {
                    Some(mp)
                }
            }
            MonitorPageRequest::None => None,
            MonitorPageRequest::Invalid => {
                self.send_version_response(Some(VersionResponseData::new(
                    version,
                    protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
                )));

                return;
            }
        };

        self.inner.state = ConnectionState::Connecting {
            info: ConnectionInfo {
                version,
                trusted: request.trusted,
                interrupt_page: request.interrupt_page,
                monitor_page: monitor_page.map(MonitorPageGpaInfo::from_guest_gpas),
                target_message_vp: request.target_message_vp,
                modifying: false,
                offers_sent: false,
                client_id: request.client_id,
                paused: false,
            },
            next_action: ConnectionAction::None,
        };

        if let Err(err) = self.notifier.modify_connection(ModifyConnectionRequest {
            version: Some(version),
            monitor_page: monitor_page.into(),
            interrupt_page: request.interrupt_page.into(),
            target_message_vp: Some(request.target_message_vp),
            notify_relay: true,
        }) {
            tracelimit::error_ratelimited!(?err, "server failed to change state");
            self.inner.state = ConnectionState::Disconnected;
            self.send_version_response(Some(VersionResponseData::new(
                version,
                protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
            )));
        }
    }
2358
    /// Finishes an InitiateContact using the notifier's `response` to the
    /// connection change requested by `initiate_contact`.
    ///
    /// On success:
    /// - the negotiated feature flags are narrowed to those the relay supports
    ///   plus locally implemented ones;
    /// - server-specified monitor pages are recorded if provided;
    /// - the state becomes `Connected` and the version response is sent.
    ///
    /// If another action was queued while connecting, a disconnect is then
    /// requested for it, and the action runs if that completes immediately.
    /// On failure, a failure (or unsupported) response is sent and the state
    /// returns to `Disconnected`.
    ///
    /// # Panics
    ///
    /// Panics if not in the `Connecting` state, or if `response` is
    /// `Modified`.
    pub(crate) fn complete_initiate_contact(&mut self, response: ModifyConnectionResponse) {
        let ConnectionState::Connecting {
            mut info,
            next_action,
        } = self.inner.state
        else {
            panic!("Invalid state for completing InitiateContact.");
        };

        // Flags implemented by this server itself. They stay available even
        // when the relay does not report them.
        const LOCAL_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
            .with_client_id(true)
            .with_confidential_channels(true);

        let (relay_feature_flags, server_specified_monitor_page) = match response {
            ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                feature_flags,
                server_specified_monitor_page,
            ) => (feature_flags, server_specified_monitor_page),
            ModifyConnectionResponse::Supported(
                connection_state,
                feature_flags,
                server_specified_monitor_page,
            ) => {
                tracelimit::error_ratelimited!(
                    ?connection_state,
                    "initiate contact failed because relay request failed"
                );

                // Report the relay's failure state with flags narrowed the
                // same way as on success.
                info.version.feature_flags &= (feature_flags | LOCAL_FEATURE_FLAGS)
                    .with_server_specified_monitor_pages(server_specified_monitor_page.is_some());

                self.send_version_response(Some(VersionResponseData::new(
                    info.version,
                    connection_state,
                )));
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            ModifyConnectionResponse::Unsupported => {
                self.send_version_response(None);
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            ModifyConnectionResponse::Modified(_) => {
                panic!("Invalid response for completing InitiateContact.");
            }
        };

        // Server-specified monitor pages may only be returned if that feature
        // was negotiated.
        assert!(
            info.version.feature_flags.server_specified_monitor_pages()
                || server_specified_monitor_page.is_none()
        );

        info.version.feature_flags &= relay_feature_flags | LOCAL_FEATURE_FLAGS;

        // The monitor-page flag reflects whether pages were actually provided,
        // regardless of the relay's flag set.
        if let Some(gpas) = server_specified_monitor_page {
            info.monitor_page = Some(MonitorPageGpaInfo::from_server_gpas(gpas));
            info.version
                .feature_flags
                .set_server_specified_monitor_pages(true);
        } else {
            info.version
                .feature_flags
                .set_server_specified_monitor_pages(false);
        }

        let version = info.version;
        self.inner.state = ConnectionState::Connected(info);

        self.send_version_response(Some(
            VersionResponseData::new(version, protocol::ConnectionState::SUCCESSFUL)
                .with_monitor_pages(server_specified_monitor_page),
        ));
        if !matches!(next_action, ConnectionAction::None) && self.request_disconnect(next_action) {
            self.do_next_action(next_action);
        }
    }
2453
2454 fn check_version_supported(&self, request: &InitiateContactRequest) -> Option<VersionInfo> {
2456 let version = SUPPORTED_VERSIONS
2457 .iter()
2458 .find(|v| request.version_requested == **v as u32)
2459 .copied()?;
2460
2461 if let Some(max_version) = self.inner.max_version {
2463 if version as u32 > max_version.version {
2464 return None;
2465 }
2466 }
2467
2468 let supported_flags = if version >= Version::Copper {
2469 let max_supported_flags =
2471 SUPPORTED_FEATURE_FLAGS.with_confidential_channels(request.trusted);
2472
2473 if let Some(max_version) = self.inner.max_version {
2475 max_supported_flags & max_version.feature_flags
2476 } else {
2477 max_supported_flags
2478 }
2479 } else {
2480 FeatureFlags::new()
2481 };
2482
2483 let feature_flags = supported_flags & request.feature_flags.into();
2484
2485 assert!(version >= Version::Copper || feature_flags == FeatureFlags::new());
2486 if feature_flags.into_bits() != request.feature_flags {
2487 tracelimit::info_ratelimited!(
2490 supported = feature_flags.into_bits(),
2491 requested = request.feature_flags,
2492 "guest requested unsupported feature flags."
2493 );
2494 }
2495
2496 Some(VersionInfo {
2497 version,
2498 feature_flags,
2499 })
2500 }
2501
2502 fn send_version_response(&mut self, data: Option<VersionResponseData>) {
2503 self.send_version_response_with_target(data, MessageTarget::Default);
2504 }
2505
    /// Sends a version response to `target`.
    ///
    /// `None` produces an all-zero response, meaning the version is not
    /// supported. The message layout grows with the negotiated version:
    /// - pre-Copper sends `VersionResponse`;
    /// - Copper adds the supported feature flags (`VersionResponse2`);
    /// - when server-specified monitor pages are present, their GPAs are
    ///   appended (`VersionResponse3`).
    fn send_version_response_with_target(
        &mut self,
        data: Option<VersionResponseData>,
        target: MessageTarget,
    ) {
        enum VersionResponseType {
            PreCopper,
            Copper,
            CopperWithServerMnf,
        }

        // Build the largest layout once; smaller layouts are prefixes of it.
        let mut response_copper_with_mnf = protocol::VersionResponse3::new_zeroed();
        let response_copper = &mut response_copper_with_mnf.version_response2;
        let response = &mut response_copper.version_response;
        let mut response_type = VersionResponseType::PreCopper;
        if let Some(data) = data {
            // Only successes, or any state for Win8 and later, are reported
            // as supported; otherwise the zeroed "unsupported" response is sent.
            // NOTE(review): presumably older guests cannot interpret a failure
            // state — confirm.
            if data.state == protocol::ConnectionState::SUCCESSFUL
                || data.version.version >= Version::Win8
            {
                response.version_supported = 1;
                response.connection_state = data.state;
                // From Win10 RS3.1 this field carries the connection ID
                // instead of the version.
                response.selected_version_or_connection_id =
                    if data.version.version >= Version::Win10Rs3_1 {
                        self.inner.child_connection_id
                    } else {
                        data.version.version as u32
                    };

                if data.version.version >= Version::Copper {
                    response_copper.supported_features = data.version.feature_flags.into();
                    response_type = VersionResponseType::Copper;
                    if let Some(monitor_page) = data.monitor_pages {
                        assert!(data.version.feature_flags.server_specified_monitor_pages());
                        response_copper_with_mnf.child_to_parent_monitor_page_gpa =
                            monitor_page.child_to_parent;
                        response_copper_with_mnf.parent_to_child_monitor_page_gpa =
                            monitor_page.parent_to_child;
                        response_type = VersionResponseType::CopperWithServerMnf;
                    }
                }
            }
        }

        match response_type {
            VersionResponseType::PreCopper => {
                self.sender().send_message_with_target(response, target)
            }
            VersionResponseType::Copper => self
                .sender()
                .send_message_with_target(response_copper, target),
            VersionResponseType::CopperWithServerMnf => self
                .sender()
                .send_message_with_target(&response_copper_with_mnf, target),
        }
    }
2564
    /// Starts tearing down the current connection in preparation for
    /// `new_action`.
    ///
    /// Every channel except reserved ones is passed to
    /// `client_release_channel` and dropped if that returns `true`. For a VM
    /// reset, reserved channels are included too.
    ///
    /// Returns `true` if the server is now `Disconnected` and `new_action` can
    /// proceed immediately. Otherwise `new_action` is recorded as the next
    /// action, to be performed when the disconnect completes.
    ///
    /// # Panics
    ///
    /// Panics if a reset is already in progress.
    fn request_disconnect(&mut self, new_action: ConnectionAction) -> bool {
        assert!(!self.is_resetting());

        let gpadls = &mut self.inner.gpadls;
        let vm_reset = matches!(new_action, ConnectionAction::Reset);
        // Reserved channels survive a reconnect but not a VM reset.
        self.inner.channels.retain(|offer_id, channel| {
            (!vm_reset && channel.state.is_reserved())
                || !Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    gpadls,
                    &mut self.inner.incomplete_gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    None,
                    vm_reset,
                )
        });

        match &mut self.inner.state {
            ConnectionState::Disconnected => {
                // Even without a connection, a reset must wait for reserved
                // channels to finish releasing.
                if vm_reset {
                    if !self.are_channels_reset(true) {
                        self.inner.state = ConnectionState::Disconnecting {
                            next_action: ConnectionAction::Reset,
                            modify_sent: false,
                        };
                    }
                } else {
                    assert!(self.are_channels_reset(false));
                }
            }

            ConnectionState::Connected { .. } => {
                // Notify now if nothing is left to release; otherwise
                // `check_disconnected` finishes once channels are released.
                if self.are_channels_reset(vm_reset) {
                    self.notify_disconnect(new_action);
                } else {
                    self.inner.state = ConnectionState::Disconnecting {
                        next_action: new_action,
                        modify_sent: false,
                    };
                }
            }

            // A transition is already underway; replace whatever was queued.
            ConnectionState::Connecting { next_action, .. }
            | ConnectionState::Disconnecting { next_action, .. } => {
                *next_action = new_action;
            }
        }

        matches!(self.inner.state, ConnectionState::Disconnected)
    }
2628
2629 pub(crate) fn complete_disconnect(&mut self) {
2630 if let ConnectionState::Disconnecting {
2631 next_action,
2632 modify_sent,
2633 } = std::mem::replace(&mut self.inner.state, ConnectionState::Disconnected)
2634 {
2635 assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
2636 if !modify_sent {
2637 tracelimit::warn_ratelimited!("unexpected modify response");
2638 }
2639
2640 self.inner.state = ConnectionState::Disconnected;
2641 self.do_next_action(next_action);
2642 } else {
2643 unreachable!("not ready for disconnect");
2644 }
2645 }
2646
2647 fn do_next_action(&mut self, action: ConnectionAction) {
2648 match action {
2649 ConnectionAction::None => {}
2650 ConnectionAction::Reset => {
2651 self.complete_reset();
2652 }
2653 ConnectionAction::SendUnloadComplete => {
2654 self.complete_unload();
2655 }
2656 ConnectionAction::Reconnect { initiate_contact } => {
2657 self.initiate_contact(initiate_contact);
2658 }
2659 ConnectionAction::SendFailedVersionResponse => {
2660 self.send_version_response(None);
2663 }
2664 }
2665 }
2666
2667 fn handle_unload(&mut self) {
2669 tracing::debug!(
2670 vtl = self.inner.assigned_channels.vtl as u8,
2671 state = ?self.inner.state,
2672 "VmBus received unload request from guest",
2673 );
2674
2675 if self.request_disconnect(ConnectionAction::SendUnloadComplete) {
2676 self.complete_unload();
2677 }
2678 }
2679
2680 fn complete_unload(&mut self) {
2681 self.notifier.unload_complete();
2682 if let Some(version) = self.inner.delayed_max_version.take() {
2683 self.inner.set_compatibility_version(version, false);
2684 }
2685
2686 self.sender().send_message(&protocol::UnloadComplete {});
2687 tracelimit::info_ratelimited!("Vmbus disconnected");
2688 }
2689
/// Handles the guest's `RequestOffers` message.
///
/// Sends an offer for every non-reserved channel, then `AllOffersDelivered`.
///
/// # Errors
/// Returns `ChannelError::OffersAlreadySent` if offers were already sent on
/// this connection.
///
/// # Panics
/// Panics if the connection is not `Connected`; `Message::parse()` is relied
/// on to prevent that.
fn handle_request_offers(&mut self) -> Result<(), ChannelError> {
    let ConnectionState::Connected(info) = &mut self.inner.state else {
        unreachable!(
            "in unexpected state {:?}, should be prevented by Message::parse()",
            self.inner.state
        );
    };

    if info.offers_sent {
        return Err(ChannelError::OffersAlreadySent);
    }

    info.offers_sent = true;

    // Reserved channels are not offered here.
    let mut sorted_channels: Vec<_> = self
        .inner
        .channels
        .iter_mut()
        .filter(|(_, channel)| !channel.state.is_reserved())
        .collect();

    // Choose the offer order. A missing `offer_order` is treated as
    // `u64::MAX`, i.e. after any explicit order. With absolute ordering,
    // `offer_order` is the primary key. Otherwise channels are grouped by
    // interface ID first. Instance ID breaks the remaining ties.
    if self.inner.use_absolute_channel_order {
        sorted_channels.sort_unstable_by_key(|(_, channel)| {
            (
                channel.offer.offer_order.unwrap_or(u64::MAX),
                channel.offer.interface_id,
                channel.offer.instance_id,
            )
        });
    } else {
        sorted_channels.sort_unstable_by_key(|(_, channel)| {
            (
                channel.offer.interface_id,
                channel.offer.offer_order.unwrap_or(u64::MAX),
                channel.offer.instance_id,
            )
        });
    }

    for (offer_id, channel) in sorted_channels {
        // Every non-reserved channel must be client-released before offers
        // are (re)sent.
        assert!(matches!(channel.state, ChannelState::ClientReleased));

        // Assign per-channel resources from the assigned channel/monitor
        // tables before building the offer.
        channel.prepare_channel(
            offer_id,
            &mut self.inner.assigned_channels,
            &mut self.inner.assigned_monitors,
        );

        channel.state = ChannelState::Closed;
        // `self.sender()` cannot be used here because `info` still borrows
        // `self.inner.state`; build the sender from the disjoint fields and
        // use `info.paused` directly.
        self.inner
            .pending_messages
            .sender(self.notifier, info.paused)
            .send_offer(channel, info);
    }
    self.sender().send_message(&protocol::AllOffersDelivered {});

    Ok(())
}
2751
2752 #[must_use]
2755 fn gpadl_updated(
2756 mut sender: MessageSender<'_, N>,
2757 offer_id: OfferId,
2758 channel: &Channel,
2759 gpadl_id: GpadlId,
2760 gpadl: &Gpadl,
2761 ) -> bool {
2762 if channel.state.is_revoked() {
2763 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
2764 sender.send_gpadl_created(channel_id, gpadl_id, protocol::STATUS_UNSUCCESSFUL);
2765 false
2766 } else {
2767 sender.notifier.notify(
2769 offer_id,
2770 Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
2771 );
2772 true
2773 }
2774 }
2775
/// Processes a `GpadlHeader` message.
///
/// Creates a GPADL for the channel identified by `input.channel_id` and
/// appends the data in `range` to it. If the header already carries all the
/// data, the GPADL is reported via `gpadl_updated`. Otherwise it is recorded
/// in `incomplete_gpadls`, so that later `GpadlBody` messages can finish it.
///
/// # Errors
/// - Unknown channel ID, from `get_by_channel_id_mut`.
/// - `ChannelReserved` for reserved channels.
/// - Errors from `Gpadl::append`.
/// - `DuplicateGpadlId` if the (GPADL ID, offer) key already exists, or if
///   the GPADL ID is already in progress.
fn handle_gpadl_header_core(
    &mut self,
    input: &protocol::GpadlHeader,
    range: &[u8],
) -> Result<(), ChannelError> {
    let (offer_id, channel) = self
        .inner
        .channels
        .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

    if channel.state.is_reserved() {
        return Err(ChannelError::ChannelReserved);
    }

    // NOTE(review): `input.len` is divided by 8 to size the GPADL,
    // presumably converting a byte length into a count of u64 entries —
    // confirm against `Gpadl::new`.
    let mut gpadl = Gpadl::new(input.count, input.len as usize / 8);
    let done = gpadl.append(range)?;

    // GPADLs are keyed by (GPADL ID, offer); reject reuse of an existing key.
    let gpadl = match self.inner.gpadls.entry((input.gpadl_id, offer_id)) {
        Entry::Vacant(entry) => entry.insert(gpadl),
        Entry::Occupied(_) => return Err(ChannelError::DuplicateGpadlId),
    };

    if !done {
        // More body messages are expected. `incomplete_gpadls` is keyed by
        // GPADL ID alone, so an in-progress ID must be unique across
        // channels. Undo the insert above if it is not.
        match self.inner.incomplete_gpadls.entry(input.gpadl_id) {
            Entry::Vacant(entry) => {
                entry.insert(offer_id);
            }
            Entry::Occupied(_) => {
                self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                tracelimit::error_ratelimited!(
                    channel_id = ?input.channel_id,
                    key = %channel.offer.key(),
                    gpadl_id = ?input.gpadl_id,
                    "duplicate in-progress gpadl ID",
                );
                return Err(ChannelError::DuplicateGpadlId);
            }
        }
    }

    // The GPADL is complete. Drop it if `gpadl_updated` rejected it
    // (revoked channel).
    if done
        && !Self::gpadl_updated(
            self.inner
                .pending_messages
                .sender(self.notifier, self.inner.state.is_paused()),
            offer_id,
            channel,
            input.gpadl_id,
            gpadl,
        )
    {
        self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
    }
    Ok(())
}
2842
2843 fn handle_gpadl_header(&mut self, input: &protocol::GpadlHeader, range: &[u8]) {
2845 if let Err(err) = self.handle_gpadl_header_core(input, range) {
2846 tracelimit::warn_ratelimited!(
2847 err = &err as &dyn std::error::Error,
2848 channel_id = ?input.channel_id,
2849 key = %self.inner.channels.get_by_channel_id(&self.inner.assigned_channels, input.channel_id).map(|(_, c)| c.offer.key()).unwrap_or_default(),
2850 gpadl_id = ?input.gpadl_id,
2851 "error handling gpadl header"
2852 );
2853
2854 self.sender().send_gpadl_created(
2856 input.channel_id,
2857 input.gpadl_id,
2858 protocol::STATUS_UNSUCCESSFUL,
2859 );
2860 }
2861 }
2862
/// Processes a `GpadlBody` message, appending `range` to an in-progress GPADL.
///
/// When the GPADL becomes complete, it leaves `incomplete_gpadls` and is
/// reported via `gpadl_updated`; it is dropped if that reports a revoked
/// channel. If appending fails, the GPADL is discarded and the guest is sent
/// an unsuccessful `GpadlCreated`; the function still returns `Ok(())` in
/// that case.
///
/// # Errors
/// `UnknownGpadlId` if the GPADL ID is not in progress, or is missing from
/// the GPADL map.
fn handle_gpadl_body(
    &mut self,
    input: &protocol::GpadlBody,
    range: &[u8],
) -> Result<(), ChannelError> {
    let &offer_id = self
        .inner
        .incomplete_gpadls
        .get(&input.gpadl_id)
        .ok_or(ChannelError::UnknownGpadlId)?;
    let gpadl = self
        .inner
        .gpadls
        .get_mut(&(input.gpadl_id, offer_id))
        .ok_or(ChannelError::UnknownGpadlId)?;
    let channel = &mut self.inner.channels[offer_id];

    match gpadl.append(range) {
        Ok(done) => {
            if done {
                self.inner.incomplete_gpadls.remove(&input.gpadl_id);
                if !Self::gpadl_updated(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    input.gpadl_id,
                    gpadl,
                ) {
                    self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                }
            }
        }
        Err(err) => {
            // Malformed body: forget the GPADL entirely and report failure
            // to the guest.
            self.inner.incomplete_gpadls.remove(&input.gpadl_id);
            self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
            let channel_id = channel.info.as_ref().expect("assigned").channel_id;
            tracelimit::warn_ratelimited!(
                err = &err as &dyn std::error::Error,
                channel_id = channel_id.0,
                key = %channel.offer.key(),
                gpadl_id = input.gpadl_id.0,
                "error handling gpadl body"
            );
            self.sender().send_gpadl_created(
                channel_id,
                input.gpadl_id,
                protocol::STATUS_UNSUCCESSFUL,
            );
        }
    }

    Ok(())
}
2926
/// Handles a guest `GpadlTeardown` request.
///
/// Only GPADLs in the `Accepted` state can be torn down. For a revoked
/// channel, the GPADL is removed and `GpadlTorndown` is sent immediately.
/// Otherwise the GPADL moves to `TearingDown` and the notifier is asked to
/// tear it down; `gpadl_teardown_complete` finishes the work later.
///
/// # Errors
/// - Unknown channel ID, or `UnknownGpadlId`.
/// - `InvalidGpadlState` if the GPADL is not `Accepted`.
/// - `WrongGpadlChannelId` if the channel's assigned ID does not match.
/// - `ChannelReserved` for reserved channels.
fn handle_gpadl_teardown(
    &mut self,
    input: &protocol::GpadlTeardown,
) -> Result<(), ChannelError> {
    let (offer_id, channel) = self
        .inner
        .channels
        .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

    tracing::debug!(
        channel_id = input.channel_id.0,
        key = %channel.offer.key(),
        gpadl_id = input.gpadl_id.0,
        "Received GPADL teardown request"
    );

    let gpadl = self
        .inner
        .gpadls
        .get_mut(&(input.gpadl_id, offer_id))
        .ok_or(ChannelError::UnknownGpadlId)?;

    match gpadl.state {
        GpadlState::InProgress
        | GpadlState::Offered
        | GpadlState::OfferedTearingDown
        | GpadlState::TearingDown => {
            return Err(ChannelError::InvalidGpadlState);
        }
        GpadlState::Accepted => {
            if channel.info.as_ref().map(|info| info.channel_id) != Some(input.channel_id) {
                return Err(ChannelError::WrongGpadlChannelId);
            }

            if channel.state.is_reserved() {
                return Err(ChannelError::ChannelReserved);
            }

            if channel.state.is_revoked() {
                tracing::trace!(
                    channel_id = input.channel_id.0,
                    key = %channel.offer.key(),
                    gpadl_id = input.gpadl_id.0,
                    "Gpadl teardown for revoked channel"
                );

                // Nothing to ask the notifier for; acknowledge right away.
                self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                self.sender().send_gpadl_torndown(input.gpadl_id);
            } else {
                gpadl.state = GpadlState::TearingDown;
                self.notifier.notify(
                    offer_id,
                    Action::TeardownGpadl {
                        gpadl_id: input.gpadl_id,
                        post_restore: false,
                    },
                );
            }
        }
    }
    Ok(())
}
2993
2994 fn open_channel(
2997 &mut self,
2998 offer_id: OfferId,
2999 input: &OpenRequest,
3000 reserved_state: Option<ReservedState>,
3001 ) {
3002 let channel = &mut self.inner.channels[offer_id];
3003 assert!(matches!(channel.state, ChannelState::Closed));
3004
3005 channel.state = ChannelState::Opening {
3006 request: *input,
3007 reserved_state,
3008 };
3009
3010 let info = channel.info.as_ref().expect("assigned");
3013 self.notifier.notify(
3014 offer_id,
3015 Action::Open(
3016 OpenParams::from_request(
3017 info,
3018 input,
3019 channel.handled_monitor_info(),
3020 reserved_state.map(|state| state.target),
3021 ),
3022 self.inner.state.get_version().expect("must be connected"),
3023 ),
3024 );
3025 }
3026
/// Handles a guest `OpenChannel`/`OpenChannel2` request.
///
/// Builds an `OpenRequest` from the message. The guest-specified signal
/// parameters and the redirection `flags` are honored only if the
/// corresponding feature flags were negotiated. Then:
/// - a `Closed` channel is opened;
/// - a `Closing` channel is queued to reopen once the close finishes;
/// - requests for revoked or re-offered channels are silently ignored.
///
/// # Errors
/// Unknown channel ID, or `ChannelAlreadyOpen` if the channel is open,
/// opening, or already queued to reopen.
fn handle_open_channel(&mut self, input: &protocol::OpenChannel2) -> Result<(), ChannelError> {
    let (offer_id, channel) = self
        .inner
        .channels
        .get_by_channel_id_mut(&self.inner.assigned_channels, input.open_channel.channel_id)?;

    let guest_specified_interrupt_info = self
        .inner
        .state
        .check_feature_flags(|ff| ff.guest_specified_signal_parameters())
        .then_some(SignalInfo {
            event_flag: input.event_flag,
            connection_id: input.connection_id,
        });

    let flags = if self
        .inner
        .state
        .check_feature_flags(|ff| ff.channel_interrupt_redirection())
    {
        input.flags
    } else {
        Default::default()
    };

    let request = OpenRequest {
        open_id: input.open_channel.open_id,
        ring_buffer_gpadl_id: input.open_channel.ring_buffer_gpadl_id,
        target_vp: protocol::vp_index_if_enabled(input.open_channel.target_vp),
        downstream_ring_buffer_page_offset: input
            .open_channel
            .downstream_ring_buffer_page_offset,
        user_data: input.open_channel.user_data,
        guest_specified_interrupt_info,
        flags,
    };

    match channel.state {
        ChannelState::Closed => self.open_channel(offer_id, &request, None),
        ChannelState::Closing { params, .. } => {
            // Defer the open until the in-progress close completes.
            // NOTE(review): any `reserved_state` on the `Closing` state is
            // not carried into `ClosingReopen` — confirm this is intended.
            channel.state = ChannelState::ClosingReopen { params, request }
        }
        ChannelState::Revoked | ChannelState::Reoffered => {}

        ChannelState::Open { .. }
        | ChannelState::Opening { .. }
        | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelAlreadyOpen),

        ChannelState::ClientReleased
        | ChannelState::ClosingClientRelease
        | ChannelState::OpeningClientRelease => unreachable!(),
    }
    Ok(())
}
3085
/// Handles a guest `CloseChannel` request.
///
/// A non-reserved `Open` channel moves to `Closing` and the notifier is asked
/// to close it. An in-progress modify is logged but does not block the
/// close. Requests for revoked or re-offered channels are ignored.
///
/// # Errors
/// - Unknown channel ID.
/// - `ChannelReserved` for reserved channels, which use
///   `CloseReservedChannel` instead.
/// - `ChannelNotOpen` if the channel is not open.
fn handle_close_channel(&mut self, input: &protocol::CloseChannel) -> Result<(), ChannelError> {
    let (offer_id, channel) = self
        .inner
        .channels
        .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

    match channel.state {
        ChannelState::Open {
            params,
            modify_state,
            reserved_state: None,
        } => {
            if modify_state.is_modifying() {
                tracelimit::warn_ratelimited!(
                    key = %channel.offer.key(),
                    ?modify_state,
                    "Client is closing the channel with a modify in progress"
                )
            }

            channel.state = ChannelState::Closing {
                params,
                reserved_state: None,
            };
            self.notifier.notify(offer_id, Action::Close);
        }

        ChannelState::Open {
            reserved_state: Some(_),
            ..
        } => return Err(ChannelError::ChannelReserved),

        ChannelState::Revoked | ChannelState::Reoffered => {}

        ChannelState::Closed
        | ChannelState::Opening { .. }
        | ChannelState::Closing { .. }
        | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

        ChannelState::ClientReleased
        | ChannelState::ClosingClientRelease
        | ChannelState::OpeningClientRelease => unreachable!(),
    }

    Ok(())
}
3133
/// Handles a guest `OpenReservedChannel` request.
///
/// Opens a `Closed` channel as reserved. The negotiated `version` and the
/// guest-supplied VP/SINT target are recorded in the `ReservedState`. The
/// `OpenRequest` is built with interrupts disabled (`target_vp: None`) and a
/// zero open ID and user data. Requests for revoked or re-offered channels
/// are ignored.
///
/// # Errors
/// - Unknown channel ID.
/// - `ChannelAlreadyOpen` if the channel is open or opening.
/// - `InvalidChannelState` while it is closing.
fn handle_open_reserved_channel(
    &mut self,
    input: &protocol::OpenReservedChannel,
    version: VersionInfo,
) -> Result<(), ChannelError> {
    let (offer_id, channel) = self
        .inner
        .channels
        .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

    // NOTE(review): `target_sint` comes from the guest and is truncated with
    // `as u8`; out-of-range values are not rejected here.
    let target = ConnectionTarget {
        vp: input.target_vp,
        sint: input.target_sint as u8,
    };

    let reserved_state = Some(ReservedState { version, target });

    let request = OpenRequest {
        ring_buffer_gpadl_id: input.ring_buffer_gpadl,
        target_vp: None,
        downstream_ring_buffer_page_offset: input.downstream_page_offset,
        open_id: 0,
        user_data: UserDefinedData::new_zeroed(),
        guest_specified_interrupt_info: None,
        flags: Default::default(),
    };

    match channel.state {
        ChannelState::Closed => self.open_channel(offer_id, &request, reserved_state),
        ChannelState::Revoked | ChannelState::Reoffered => {}

        ChannelState::Open { .. } | ChannelState::Opening { .. } => {
            return Err(ChannelError::ChannelAlreadyOpen);
        }

        ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
            return Err(ChannelError::InvalidChannelState);
        }

        ChannelState::ClientReleased
        | ChannelState::ClosingClientRelease
        | ChannelState::OpeningClientRelease => unreachable!(),
    }
    Ok(())
}
3182
/// Handles a guest `CloseReservedChannel` request.
///
/// For an open reserved channel, the reserved target is updated to the
/// VP/SINT given in this request, the channel moves to `Closing` (keeping
/// the reserved state), and the notifier is asked to close it. Requests for
/// revoked or re-offered channels are ignored.
///
/// # Errors
/// - Unknown channel ID.
/// - `ChannelNotReserved` for a non-reserved open channel.
/// - `ChannelNotOpen` if the channel is not open.
fn handle_close_reserved_channel(
    &mut self,
    input: &protocol::CloseReservedChannel,
) -> Result<(), ChannelError> {
    let (offer_id, channel) = self
        .inner
        .channels
        .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

    match channel.state {
        ChannelState::Open {
            params,
            reserved_state: Some(mut resvd),
            ..
        } => {
            // NOTE(review): guest-supplied `target_sint` is truncated with
            // `as u8` without range validation.
            resvd.target.vp = input.target_vp;
            resvd.target.sint = input.target_sint as u8;
            channel.state = ChannelState::Closing {
                params,
                reserved_state: Some(resvd),
            };
            self.notifier.notify(offer_id, Action::Close);
        }

        ChannelState::Open {
            reserved_state: None,
            ..
        } => return Err(ChannelError::ChannelNotReserved),

        ChannelState::Revoked | ChannelState::Reoffered => {}

        ChannelState::Closed
        | ChannelState::Opening { .. }
        | ChannelState::Closing { .. }
        | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

        ChannelState::ClientReleased
        | ChannelState::ClosingClientRelease
        | ChannelState::OpeningClientRelease => unreachable!(),
    }

    Ok(())
}
3228
/// Releases a channel from the client's (guest's) side.
///
/// First, every GPADL belonging to `offer_id` is processed:
/// - in-progress GPADLs are dropped;
/// - offered GPADLs are marked `OfferedTearingDown`;
/// - accepted GPADLs are dropped if the channel is revoked, otherwise they
///   are torn down via the notifier;
/// - GPADLs already tearing down are kept.
///
/// Then the channel state moves to a released state. For a re-offered
/// channel with `info` available, the channel is instead re-offered right
/// away. In that case the function returns early, before
/// `release_channel`.
///
/// `vm_reset` changes how an `Opening` channel is released.
///
/// Returns `true` if the caller should remove the channel from the channel
/// table. That only happens for a revoked channel.
#[must_use]
fn client_release_channel(
    mut sender: MessageSender<'_, N>,
    offer_id: OfferId,
    channel: &mut Channel,
    gpadls: &mut GpadlMap,
    incomplete_gpadls: &mut IncompleteGpadlMap,
    assigned_channels: &mut AssignedChannels,
    assigned_monitors: &mut AssignedMonitors,
    info: Option<&ConnectionInfo>,
    vm_reset: bool,
) -> bool {
    tracelimit::info_ratelimited!(?offer_id, key = %channel.offer.key(), "client released channel");
    gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
        // Only this channel's GPADLs are affected.
        if gpadl_offer_id != offer_id {
            return true;
        }
        match gpadl.state {
            GpadlState::InProgress => {
                incomplete_gpadls.remove(&gpadl_id);
                false
            }
            GpadlState::Offered => {
                // Still waiting on `gpadl_create_complete`, which handles
                // the teardown once creation finishes.
                gpadl.state = GpadlState::OfferedTearingDown;
                true
            }
            GpadlState::Accepted => {
                if channel.state.is_revoked() {
                    false
                } else {
                    gpadl.state = GpadlState::TearingDown;
                    sender.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: false,
                        },
                    );
                    true
                }
            }
            GpadlState::OfferedTearingDown | GpadlState::TearingDown => true,
        }
    });

    let remove = match &mut channel.state {
        ChannelState::Closed => {
            channel.state = ChannelState::ClientReleased;
            false
        }
        ChannelState::Reoffered => {
            if let Some(info) = info {
                // Connected: re-offer immediately and keep the channel's
                // assignment.
                channel.state = ChannelState::Closed;
                channel.restore_state = RestoreState::New;
                sender.send_offer(channel, info);
                return false;
            }
            channel.state = ChannelState::ClientReleased;
            false
        }
        ChannelState::Revoked => {
            channel.state = ChannelState::ClientReleased;
            true
        }
        ChannelState::Opening { .. } => {
            if vm_reset {
                channel.state = ChannelState::ClientReleased;
            } else {
                channel.state = ChannelState::OpeningClientRelease;
            }
            false
        }
        ChannelState::Open { .. } => {
            channel.state = ChannelState::ClosingClientRelease;
            sender.notifier.notify(offer_id, Action::Close);
            false
        }
        ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
            // A close is already in flight; just mark the channel as
            // released.
            channel.state = ChannelState::ClosingClientRelease;
            false
        }

        ChannelState::ClientReleased
        | ChannelState::ClosingClientRelease
        | ChannelState::OpeningClientRelease => false,
    };

    assert!(channel.state.is_released());

    channel.release_channel(offer_id, assigned_channels, assigned_monitors);
    remove
}
3344
/// Handles a guest `RelIdReleased` message, releasing the channel ID.
///
/// Allowed only when the channel is closed, closing, revoked, or re-offered.
/// The channel is released through `client_release_channel` and removed
/// from the channel table if that asks for it. Afterwards
/// `check_disconnected` runs.
///
/// # Errors
/// Unknown channel ID, or `InvalidChannelState` while the channel is open,
/// opening, or queued to reopen.
fn handle_rel_id_released(
    &mut self,
    input: &protocol::RelIdReleased,
) -> Result<(), ChannelError> {
    let channel_id = input.channel_id;
    let (offer_id, channel) = self
        .inner
        .channels
        .get_by_channel_id_mut(&self.inner.assigned_channels, channel_id)?;

    match channel.state {
        ChannelState::Closed
        | ChannelState::Revoked
        | ChannelState::Closing { .. }
        | ChannelState::Reoffered => {
            if Self::client_release_channel(
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused()),
                offer_id,
                channel,
                &mut self.inner.gpadls,
                &mut self.inner.incomplete_gpadls,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
                self.inner.state.get_connected_info(),
                false,
            ) {
                self.inner.channels.remove(offer_id);
            }

            self.check_disconnected();
        }

        ChannelState::Opening { .. }
        | ChannelState::Open { .. }
        | ChannelState::ClosingReopen { .. } => return Err(ChannelError::InvalidChannelState),

        ChannelState::ClientReleased
        | ChannelState::OpeningClientRelease
        | ChannelState::ClosingClientRelease => unreachable!(),
    }
    Ok(())
}
3390
3391 fn handle_tl_connect_request(&mut self, request: protocol::TlConnectRequest2) {
3394 let version = self
3395 .inner
3396 .state
3397 .get_version()
3398 .expect("must be connected")
3399 .version;
3400
3401 let hosted_silo_unaware = version < Version::Win10Rs5;
3402 self.notifier
3403 .notify_hvsock(&HvsockConnectRequest::from_message(
3404 request,
3405 hosted_silo_unaware,
3406 ));
3407 }
3408
3409 pub fn send_tl_connect_result(&mut self, result: HvsockConnectResult) {
3411 if !result.success && self.inner.state.check_version(Version::Win10Rs3_0) {
3415 self.sender().send_message(&protocol::TlConnectResult {
3419 service_id: result.service_id,
3420 endpoint_id: result.endpoint_id,
3421 status: protocol::STATUS_CONNECTION_REFUSED,
3422 })
3423 }
3424 }
3425
3426 fn handle_modify_channel(
3429 &mut self,
3430 request: &protocol::ModifyChannel,
3431 ) -> Result<(), ChannelError> {
3432 let result = self.modify_channel(request);
3433 if result.is_err() {
3434 self.send_modify_channel_response(request.channel_id, protocol::STATUS_UNSUCCESSFUL);
3435 }
3436
3437 result
3438 }
3439
/// Retargets an open channel's interrupt to `request.target_vp`.
///
/// If no modify is in flight, the notifier is told, the channel's
/// `target_vp` is updated, and the channel is marked `Modifying`. If a modify
/// is already in flight:
/// - with `check_version(Iron)`, the new request is only logged;
/// - otherwise it is stored as `pending_target_vp`, which
///   `modify_channel_complete` replays.
///
/// # Errors
/// - `InvalidTargetVp` for `VP_INDEX_DISABLE_INTERRUPT`.
/// - Unknown channel ID.
/// - `InvalidChannelState` unless the channel is open and not reserved.
/// - `InterruptsDisabled` if the channel was opened without a target VP.
fn modify_channel(&mut self, request: &protocol::ModifyChannel) -> Result<(), ChannelError> {
    if request.target_vp == protocol::VP_INDEX_DISABLE_INTERRUPT {
        return Err(ChannelError::InvalidTargetVp);
    }

    let (offer_id, channel) = self
        .inner
        .channels
        .get_by_channel_id_mut(&self.inner.assigned_channels, request.channel_id)?;

    let (open_request, modify_state) = match &mut channel.state {
        ChannelState::Open {
            params,
            modify_state,
            reserved_state: None,
        } => (params, modify_state),
        _ => return Err(ChannelError::InvalidChannelState),
    };

    if open_request.target_vp.is_none() {
        return Err(ChannelError::InterruptsDisabled);
    }

    if let ModifyState::Modifying { pending_target_vp } = modify_state {
        if self.inner.state.check_version(Version::Iron) {
            tracelimit::warn_ratelimited!(
                key = %channel.offer.key(),
                "Client sent new ModifyChannel before receiving ModifyChannelResponse."
            );
        } else {
            // Remember only the latest target; it is applied when the
            // current modify completes.
            *pending_target_vp = Some(request.target_vp);
        }
    } else {
        self.notifier.notify(
            offer_id,
            Action::Modify {
                target_vp: request.target_vp,
            },
        );

        open_request.target_vp = Some(request.target_vp);
        *modify_state = ModifyState::Modifying {
            pending_target_vp: None,
        };
    }

    Ok(())
}
3495
/// Completes an in-flight channel modify with `status`.
///
/// Does nothing unless the channel is open, not reserved, and `Modifying`.
/// Otherwise it clears the modify state, sends the response (subject to the
/// version check in `send_modify_channel_response`), and replays any pending
/// target VP recorded by `modify_channel`. A failure of that replay is only
/// logged.
pub fn modify_channel_complete(&mut self, offer_id: OfferId, status: i32) {
    let channel = &mut self.inner.channels[offer_id];

    if let ChannelState::Open {
        params,
        modify_state: ModifyState::Modifying { pending_target_vp },
        reserved_state: None,
    } = channel.state
    {
        channel.state = ChannelState::Open {
            params,
            modify_state: ModifyState::NotModifying,
            reserved_state: None,
        };

        // Capture what is needed from `channel` before `self` is borrowed
        // mutably below.
        let channel_id = channel.info.as_ref().expect("assigned").channel_id;
        let key = channel.offer.key();
        self.send_modify_channel_response(channel_id, status);

        if let Some(target_vp) = pending_target_vp {
            let request = protocol::ModifyChannel {
                channel_id,
                target_vp,
            };

            if let Err(error) = self.handle_modify_channel(&request) {
                tracelimit::warn_ratelimited!(?error, %key, "Pending ModifyChannel request failed.")
            }
        }
    }
}
3535
3536 fn send_modify_channel_response(&mut self, channel_id: ChannelId, status: i32) {
3537 if self.inner.state.check_version(Version::Iron) {
3538 self.sender()
3539 .send_message(&protocol::ModifyChannelResponse { channel_id, status });
3540 }
3541 }
3542
3543 fn handle_modify_connection(&mut self, request: protocol::ModifyConnection) {
3544 if let Err(err) = self.modify_connection(request) {
3545 tracelimit::error_ratelimited!(?err, "modifying connection failed");
3546 self.complete_modify_connection(ModifyConnectionResponse::Modified(
3547 protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
3548 ));
3549 }
3550 }
3551
3552 fn modify_connection(&mut self, request: protocol::ModifyConnection) -> anyhow::Result<()> {
3553 let ConnectionState::Connected(info) = &mut self.inner.state else {
3554 anyhow::bail!(
3555 "Invalid state for ModifyConnection request: {:?}",
3556 self.inner.state
3557 );
3558 };
3559
3560 if info.modifying {
3561 anyhow::bail!(
3562 "Duplicate ModifyConnection request, state: {:?}",
3563 self.inner.state
3564 );
3565 }
3566
3567 if matches!(
3568 info.monitor_page,
3569 Some(MonitorPageGpaInfo {
3570 server_allocated: true,
3571 ..
3572 })
3573 ) {
3574 anyhow::bail!("Cannot modify server-allocated monitor pages");
3575 }
3576
3577 if (request.child_to_parent_monitor_page_gpa == 0)
3578 != (request.parent_to_child_monitor_page_gpa == 0)
3579 {
3580 anyhow::bail!("Guest must specify either both or no monitor pages, {request:?}");
3581 }
3582
3583 let monitor_page = (request.child_to_parent_monitor_page_gpa != 0).then_some(
3584 MonitorPageGpaInfo::from_guest_gpas(MonitorPageGpas {
3585 child_to_parent: request.child_to_parent_monitor_page_gpa,
3586 parent_to_child: request.parent_to_child_monitor_page_gpa,
3587 }),
3588 );
3589
3590 info.modifying = true;
3591 info.monitor_page = monitor_page;
3592 tracing::debug!("modifying connection parameters.");
3593 self.notifier.modify_connection(request.into())?;
3594
3595 Ok(())
3596 }
3597
/// Handles the notifier's response to a modify-connection request.
///
/// The same completion serves several operations. Which one is being
/// completed is inferred from the current connection state:
/// - `Connecting`: completes the initiate-contact handshake.
/// - `Disconnecting`: completes the disconnect.
/// - `Connected`: completes a guest `ModifyConnection` and sends the result
///   to the guest.
///
/// # Panics
/// - Any other state.
/// - A non-`Modified` response while `Connected`.
/// - A response while `Connected` with no modify in flight.
pub fn complete_modify_connection(&mut self, response: ModifyConnectionResponse) {
    tracing::debug!(?response, "modifying connection parameters complete");

    match &mut self.inner.state {
        ConnectionState::Connecting { .. } => self.complete_initiate_contact(response),
        ConnectionState::Disconnecting { .. } => self.complete_disconnect(),
        ConnectionState::Connected(info) => {
            let ModifyConnectionResponse::Modified(connection_state) = response else {
                panic!(
                    "Relay should not return {:?} for a modify request with no version.",
                    response
                );
            };

            if !info.modifying {
                panic!(
                    "ModifyConnection response while not modifying, state: {:?}",
                    self.inner.state
                );
            }

            info.modifying = false;
            self.sender()
                .send_message(&protocol::ModifyConnectionResponse { connection_state });
        }
        _ => panic!(
            "Invalid state for ModifyConnection response: {:?}",
            self.inner.state
        ),
    }
}
3632
3633 fn handle_pause(&mut self) {
3634 tracelimit::info_ratelimited!("pausing sending messages");
3635 self.sender().send_message(&protocol::PauseResponse {});
3636 let ConnectionState::Connected(info) = &mut self.inner.state else {
3637 unreachable!(
3638 "in unexpected state {:?}, should be prevented by Message::parse()",
3639 self.inner.state
3640 );
3641 };
3642 info.paused = true;
3643 }
3644
/// Parses a guest SynIC message and dispatches it to its handler.
///
/// Parsing uses the currently negotiated version. On a trusted connection,
/// untrusted messages are rejected. While connected and paused, only
/// `Resume`, `Unload`, and `InitiateContact`/`InitiateContact2` are
/// accepted, and any of them un-pauses the connection.
///
/// # Errors
/// - Parse errors.
/// - `UntrustedMessage`.
/// - `Paused`.
/// - Any error returned by the individual handlers.
///
/// # Panics
/// Panics if called while resetting, or if the parsed message is one that
/// only a server sends.
pub fn handle_synic_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
    assert!(!self.is_resetting());

    let version = self.inner.state.get_version();
    let msg = Message::parse(&message.data, version)?;
    tracing::trace!(?msg, message.trusted, "received vmbus message");
    if self.inner.state.is_trusted() && !message.trusted {
        tracelimit::warn_ratelimited!(?msg, "Received untrusted message");
        return Err(ChannelError::UntrustedMessage);
    }

    match &mut self.inner.state {
        ConnectionState::Connected(info) if info.paused => {
            if !matches!(
                msg,
                Message::Resume(..)
                    | Message::Unload(..)
                    | Message::InitiateContact { .. }
                    | Message::InitiateContact2 { .. }
            ) {
                tracelimit::warn_ratelimited!(?msg, "Received message while paused");
                return Err(ChannelError::Paused);
            }
            tracelimit::info_ratelimited!("resuming sending messages");
            info.paused = false;
        }
        _ => {}
    }

    match msg {
        Message::InitiateContact2(input, ..) => {
            self.handle_initiate_contact(&input, &message, true)?
        }
        Message::InitiateContact(input, ..) => {
            self.handle_initiate_contact(&input.into(), &message, false)?
        }
        Message::Unload(..) => self.handle_unload(),
        Message::RequestOffers(..) => self.handle_request_offers()?,
        Message::GpadlHeader(input, range) => self.handle_gpadl_header(&input, range),
        Message::GpadlBody(input, range) => self.handle_gpadl_body(&input, range)?,
        Message::GpadlTeardown(input, ..) => self.handle_gpadl_teardown(&input)?,
        Message::OpenChannel(input, ..) => self.handle_open_channel(&input.into())?,
        Message::OpenChannel2(input, ..) => self.handle_open_channel(&input)?,
        Message::CloseChannel(input, ..) => self.handle_close_channel(&input)?,
        Message::RelIdReleased(input, ..) => self.handle_rel_id_released(&input)?,
        Message::TlConnectRequest(input, ..) => self.handle_tl_connect_request(input.into()),
        Message::TlConnectRequest2(input, ..) => self.handle_tl_connect_request(input),
        Message::ModifyChannel(input, ..) => self.handle_modify_channel(&input)?,
        Message::ModifyConnection(input, ..) => self.handle_modify_connection(input),
        Message::OpenReservedChannel(input, ..) => self.handle_open_reserved_channel(
            &input,
            version.expect("version validated by Message::parse"),
        )?,
        Message::CloseReservedChannel(input, ..) => {
            self.handle_close_reserved_channel(&input)?
        }
        Message::Pause(protocol::Pause, ..) => self.handle_pause(),
        // Un-pausing (if needed) already happened above.
        Message::Resume(protocol::Resume, ..) => {}
        // Messages only a server sends; a client should never send these.
        Message::OfferChannel(..)
        | Message::RescindChannelOffer(..)
        | Message::AllOffersDelivered(..)
        | Message::OpenResult(..)
        | Message::GpadlCreated(..)
        | Message::GpadlTorndown(..)
        | Message::VersionResponse(..)
        | Message::VersionResponse2(..)
        | Message::VersionResponse3(..)
        | Message::UnloadComplete(..)
        | Message::CloseReservedChannelResponse(..)
        | Message::TlConnectResult(..)
        | Message::ModifyChannelResponse(..)
        | Message::ModifyConnectionResponse(..)
        | Message::PauseResponse(..) => {
            unreachable!("Server received client message {:?}", msg);
        }
    }
    Ok(())
}
3730
/// Handles the notifier's completion of a GPADL creation with `status`
/// (non-negative means success).
///
/// - `Offered`: the result is forwarded to the guest as `GpadlCreated`. On
///   success the GPADL becomes `Accepted`.
/// - `OfferedTearingDown` (the channel was released meanwhile): on success
///   the GPADL is torn down via the notifier. Nothing is sent to the guest.
/// - Unknown GPADLs and other states are logged and ignored.
///
/// A GPADL that is not retained is removed, and `check_disconnected` runs.
pub fn gpadl_create_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId, status: i32) {
    let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
        tracelimit::error_ratelimited!(
            ?offer_id,
            key = %self.inner.channels[offer_id].offer.key(),
            ?gpadl_id,
            "invalid gpadl ID for channel"
        );
        return;
    };
    let retain = match gpadl.state {
        GpadlState::InProgress | GpadlState::TearingDown | GpadlState::Accepted => {
            tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
            return;
        }
        GpadlState::Offered => {
            let channel_id = self.inner.channels[offer_id]
                .info
                .as_ref()
                .expect("assigned")
                .channel_id;
            self.inner
                .pending_messages
                .sender(self.notifier, self.inner.state.is_paused())
                .send_gpadl_created(channel_id, gpadl_id, status);
            if status >= 0 {
                gpadl.state = GpadlState::Accepted;
                true
            } else {
                false
            }
        }
        GpadlState::OfferedTearingDown => {
            if status >= 0 {
                // Created successfully, but no longer wanted: tear it down.
                self.notifier.notify(
                    offer_id,
                    Action::TeardownGpadl {
                        gpadl_id,
                        post_restore: false,
                    },
                );
                gpadl.state = GpadlState::TearingDown;
                true
            } else {
                false
            }
        }
    };
    if !retain {
        self.inner
            .gpadls
            .remove(&(gpadl_id, offer_id))
            .expect("gpadl validated above");

        self.check_disconnected();
    }
}
3790
    /// Completes the device-side teardown of GPADL `gpadl_id` on channel
    /// `offer_id`.
    ///
    /// The GPADL is expected to be in the `TearingDown` state. In that case:
    /// - a `GpadlTorndown` message is sent to the guest, unless the channel's
    ///   state reports `is_released()`;
    /// - the GPADL is removed;
    /// - `check_disconnected` is invoked.
    ///
    /// An unknown GPADL, or one in any other state, is logged as an error and
    /// otherwise ignored. A debug trace is emitted for every known GPADL,
    /// before its state is checked.
    ///
    /// # Panics
    ///
    /// Indexes `self.inner.channels` with `offer_id`. That presumably panics
    /// for an unknown offer (depends on the `Index` impl — confirm).
    pub fn gpadl_teardown_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
        let channel = &mut self.inner.channels[offer_id];
        let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
            tracelimit::error_ratelimited!(
                ?offer_id,
                key = %channel.offer.key(),
                ?gpadl_id,
                "invalid gpadl ID for channel"
            );
            return;
        };
        tracing::debug!(
            offer_id = offer_id.0,
            key = %channel.offer.key(),
            gpadl_id = gpadl_id.0,
            "Gpadl teardown complete"
        );
        match gpadl.state {
            // Only a `TearingDown` GPADL can complete a teardown; it enters that
            // state after an `Action::TeardownGpadl` notification (e.g. in
            // `gpadl_create_complete`).
            GpadlState::InProgress
            | GpadlState::Offered
            | GpadlState::OfferedTearingDown
            | GpadlState::Accepted => {
                tracelimit::error_ratelimited!(?offer_id, key = %channel.offer.key(), ?gpadl_id, ?gpadl, "invalid gpadl state");
            }
            GpadlState::TearingDown => {
                // Skip the guest notification when the channel is in a
                // released state.
                if !channel.state.is_released() {
                    self.sender().send_gpadl_torndown(gpadl_id);
                }
                self.inner
                    .gpadls
                    .remove(&(gpadl_id, offer_id))
                    .expect("gpadl validated above");

                self.check_disconnected();
            }
        }
    }
3829
3830 fn sender(&mut self) -> MessageSender<'_, N> {
3835 self.inner
3836 .pending_messages
3837 .sender(self.notifier, self.inner.state.is_paused())
3838 }
3839}
3840
/// Revokes (rescinds) the channel identified by `offer_id` and releases the
/// GPADLs that can be released, sending any resulting guest messages through
/// `sender`.
///
/// Channel state handling:
/// - `Closed`, `Open`, `Opening`, `Closing`, `ClosingReopen`: moved to
///   `Revoked`. The guest is notified using the channel's assigned info.
/// - `Reoffered`: moved to `Revoked`. The guest is not notified.
/// - `ClientReleased`, `OpeningClientRelease`, `ClosingClientRelease`: the
///   state is left unchanged and the guest is not notified.
/// - `Revoked`: returns `true` immediately and touches nothing else.
///
/// Handling of GPADLs that belong to this channel:
/// - `InProgress` and `Accepted` GPADLs are kept.
/// - `Offered`, `OfferedTearingDown` and `TearingDown` GPADLs are removed.
/// - When the guest is notified, a removed `Offered` GPADL gets a failed
///   `GpadlCreated` (`STATUS_UNSUCCESSFUL`) and a removed `TearingDown` GPADL
///   gets a `GpadlTorndown`.
///
/// After the GPADL messages, a `RescindChannelOffer` is sent if the guest is
/// being notified. A channel whose `restore_state` is not `New` is then
/// marked `Restored`.
///
/// Returns `!channel.state.is_released()`, evaluated after the state update
/// above. NOTE(review): presumably a `false` return tells the caller to drop
/// the channel entirely — confirm against callers.
///
/// # Panics
///
/// Panics if the channel is in one of the notifying states but has no
/// assigned `info`.
fn revoke<N: Notifier>(
    mut sender: MessageSender<'_, N>,
    offer_id: OfferId,
    channel: &mut Channel,
    gpadls: &mut GpadlMap,
) -> bool {
    // `info` is `Some` exactly when the guest should be told about the
    // revocation.
    let info = match channel.state {
        ChannelState::Closed
        | ChannelState::Open { .. }
        | ChannelState::Opening { .. }
        | ChannelState::Closing { .. }
        | ChannelState::ClosingReopen { .. } => {
            channel.state = ChannelState::Revoked;
            Some(channel.info.as_ref().expect("assigned"))
        }
        ChannelState::Reoffered => {
            channel.state = ChannelState::Revoked;
            None
        }
        ChannelState::ClientReleased
        | ChannelState::OpeningClientRelease
        | ChannelState::ClosingClientRelease => None,
        // Already revoked: nothing more to do.
        ChannelState::Revoked => return true,
    };
    let retain = !channel.state.is_released();

    gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
        // Only GPADLs belonging to the revoked channel are affected.
        if gpadl_offer_id != offer_id {
            return true;
        }

        match gpadl.state {
            GpadlState::InProgress => true,
            GpadlState::Offered => {
                // Report the pending creation as failed to the guest.
                if let Some(info) = info {
                    sender.send_gpadl_created(
                        info.channel_id,
                        gpadl_id,
                        protocol::STATUS_UNSUCCESSFUL,
                    );
                }
                false
            }
            GpadlState::OfferedTearingDown => false,
            GpadlState::Accepted => true,
            GpadlState::TearingDown => {
                if info.is_some() {
                    sender.send_gpadl_torndown(gpadl_id);
                }
                false
            }
        }
    });
    // The rescind is sent after any GPADL responses queued above.
    if let Some(info) = info {
        sender.send_rescind(info);
    }
    // Revoking the channel counts as finishing its restore.
    if channel.restore_state != RestoreState::New {
        channel.restore_state = RestoreState::Restored;
    }
    retain
}
3905
3906struct PendingMessages(VecDeque<OutgoingMessage>);
3907
3908impl PendingMessages {
3909 fn sender<'a, N: Notifier>(
3911 &'a mut self,
3912 notifier: &'a mut N,
3913 is_paused: bool,
3914 ) -> MessageSender<'a, N> {
3915 MessageSender {
3916 notifier,
3917 pending_messages: self,
3918 is_paused,
3919 }
3920 }
3921}
3922
/// Sends protocol messages to the guest through a notifier, queueing them in
/// [`PendingMessages`] when they cannot be sent right away.
struct MessageSender<'a, N> {
    /// Used to attempt immediate delivery of each message.
    notifier: &'a mut N,
    /// Receives messages that were not delivered immediately.
    pending_messages: &'a mut PendingMessages,
    /// When `true`, default-target messages are queued instead of sent.
    is_paused: bool,
}
3930
3931impl<N: Notifier> MessageSender<'_, N> {
3932 fn send_message<
3934 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3935 >(
3936 &mut self,
3937 msg: &T,
3938 ) {
3939 let message = OutgoingMessage::new(msg);
3940
3941 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3942 if !self.pending_messages.0.is_empty()
3944 || self.is_paused
3945 || !self.notifier.send_message(&message, MessageTarget::Default)
3946 {
3947 tracing::trace!("message queued");
3948 self.pending_messages.0.push_back(message);
3950 }
3951 }
3952
3953 fn send_message_with_target<
3955 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3956 >(
3957 &mut self,
3958 msg: &T,
3959 target: MessageTarget,
3960 ) {
3961 if target == MessageTarget::Default {
3962 self.send_message(msg);
3963 } else {
3964 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3965 let message = OutgoingMessage::new(msg);
3968 if !self.notifier.send_message(&message, target) {
3969 tracelimit::warn_ratelimited!(?target, "failed to send message");
3970 }
3971 }
3972 }
3973
3974 fn send_offer(&mut self, channel: &mut Channel, connection_info: &ConnectionInfo) {
3976 let info = channel.info.as_ref().expect("assigned");
3977 let mut flags = channel.offer.flags;
3978 if !connection_info
3979 .version
3980 .feature_flags
3981 .confidential_channels()
3982 {
3983 flags.set_confidential_ring_buffer(false);
3984 flags.set_confidential_external_memory(false);
3985 }
3986
3987 let monitor_id = connection_info.monitor_page.and(info.monitor_id);
3994 let msg = protocol::OfferChannel {
3995 interface_id: channel.offer.interface_id,
3996 instance_id: channel.offer.instance_id,
3997 rsvd: [0; 4],
3998 flags,
3999 mmio_megabytes: channel.offer.mmio_megabytes,
4000 user_defined: channel.offer.user_defined,
4001 subchannel_index: channel.offer.subchannel_index,
4002 mmio_megabytes_optional: channel.offer.mmio_megabytes_optional,
4003 channel_id: info.channel_id,
4004 monitor_id: monitor_id.unwrap_or(MonitorId::INVALID).0,
4005 monitor_allocated: monitor_id.is_some().into(),
4006 is_dedicated: 1,
4009 connection_id: info.connection_id,
4010 };
4011 tracing::info!(
4012 channel_id = msg.channel_id.0,
4013 connection_id = msg.connection_id,
4014 key = %channel.offer.key(),
4015 "sending offer to guest"
4016 );
4017
4018 self.send_message(&msg);
4019 }
4020
4021 fn send_open_result(
4022 &mut self,
4023 channel_id: ChannelId,
4024 open_request: &OpenRequest,
4025 result: i32,
4026 target: MessageTarget,
4027 ) {
4028 self.send_message_with_target(
4029 &protocol::OpenResult {
4030 channel_id,
4031 open_id: open_request.open_id,
4032 status: result as u32,
4033 },
4034 target,
4035 );
4036 }
4037
4038 fn send_gpadl_created(&mut self, channel_id: ChannelId, gpadl_id: GpadlId, status: i32) {
4039 self.send_message(&protocol::GpadlCreated {
4040 channel_id,
4041 gpadl_id,
4042 status,
4043 });
4044 }
4045
4046 fn send_gpadl_torndown(&mut self, gpadl_id: GpadlId) {
4047 self.send_message(&protocol::GpadlTorndown { gpadl_id });
4048 }
4049
4050 fn send_rescind(&mut self, info: &OfferedInfo) {
4051 tracing::info!(
4052 channel_id = info.channel_id.0,
4053 "rescinding channel from guest"
4054 );
4055
4056 self.send_message(&protocol::RescindChannelOffer {
4057 channel_id: info.channel_id,
4058 });
4059 }
4060}
4061
4062struct VersionResponseData {
4064 version: VersionInfo,
4065 state: protocol::ConnectionState,
4066 monitor_pages: Option<MonitorPageGpas>,
4067}
4068
4069impl VersionResponseData {
4070 fn new(version: VersionInfo, state: protocol::ConnectionState) -> Self {
4072 VersionResponseData {
4073 version,
4074 state,
4075 monitor_pages: None,
4076 }
4077 }
4078
4079 fn with_monitor_pages(mut self, monitor_pages: Option<MonitorPageGpas>) -> Self {
4081 self.monitor_pages = monitor_pages;
4082 self
4083 }
4084}