1pub mod saved_state;
5#[cfg(test)]
6mod tests;
7
8use crate::Guid;
9use crate::SINT;
10use crate::SynicMessage;
11use crate::monitor::AssignedMonitors;
12use crate::protocol::Version;
13use hvdef::Vtl;
14use inspect::Inspect;
15pub use saved_state::RestoreError;
16pub use saved_state::SavedState;
17pub use saved_state::SavedStateData;
18use slab::Slab;
19use std::cmp::min;
20use std::collections::VecDeque;
21use std::collections::hash_map::Entry;
22use std::collections::hash_map::HashMap;
23use std::fmt::Display;
24use std::ops::Index;
25use std::ops::IndexMut;
26use std::task::Poll;
27use std::task::ready;
28use std::time::Duration;
29use thiserror::Error;
30use vmbus_channel::bus::ChannelType;
31use vmbus_channel::bus::GpadlRequest;
32use vmbus_channel::bus::OfferKey;
33use vmbus_channel::bus::OfferParams;
34use vmbus_channel::bus::OpenData;
35use vmbus_channel::bus::RestoredGpadl;
36use vmbus_core::HvsockConnectRequest;
37use vmbus_core::HvsockConnectResult;
38use vmbus_core::MaxVersionInfo;
39use vmbus_core::OutgoingMessage;
40use vmbus_core::VersionInfo;
41use vmbus_core::protocol;
42use vmbus_core::protocol::ChannelId;
43use vmbus_core::protocol::ConnectionId;
44use vmbus_core::protocol::FeatureFlags;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::Message;
47use vmbus_core::protocol::OfferFlags;
48use vmbus_core::protocol::UserDefinedData;
49use vmbus_ring::gparange;
50use vmcore::monitor::MonitorId;
51use vmcore::synic::MonitorInfo;
52use vmcore::synic::MonitorPageGpas;
53use zerocopy::FromZeros;
54use zerocopy::Immutable;
55use zerocopy::IntoBytes;
56use zerocopy::KnownLayout;
57
/// Errors raised while handling a request that targets a channel or a GPADL.
#[derive(Debug, Error)]
pub enum ChannelError {
    /// The channel ID is not assigned to any offer (see `ChannelList::get_by_channel_id`).
    #[error("unknown channel ID")]
    UnknownChannelId,
    #[error("unknown GPADL ID")]
    UnknownGpadlId,
    #[error("parse error")]
    ParseError(#[from] protocol::ParseError),
    /// The completed GPADL buffer failed `MultiPagedRangeBuf::validate` (see `Gpadl::append`).
    #[error("invalid gpa range")]
    InvalidGpaRange(#[source] gparange::Error),
    #[error("duplicate GPADL ID")]
    DuplicateGpadlId,
    /// Data was appended to a GPADL that is no longer `InProgress` (see `Gpadl::append`).
    #[error("GPADL is already complete")]
    GpadlAlreadyComplete,
    #[error("GPADL channel ID mismatch")]
    WrongGpadlChannelId,
    #[error("trying to open an open channel")]
    ChannelAlreadyOpen,
    #[error("trying to close a closed channel")]
    ChannelNotOpen,
    #[error("invalid GPADL state for operation")]
    InvalidGpadlState,
    #[error("invalid channel state for operation")]
    InvalidChannelState,
    /// The channel ID maps to an offer whose state reports `is_released()`.
    #[error("channel ID has already been released")]
    ChannelReleased,
    #[error("channel offers have already been sent")]
    OffersAlreadySent,
    #[error("invalid operation on reserved channel")]
    ChannelReserved,
    #[error("invalid operation on non-reserved channel")]
    ChannelNotReserved,
    #[error("received untrusted message for trusted connection")]
    UntrustedMessage,
    #[error("received a non-resuming message while paused")]
    Paused,
}
96
/// Errors returned when a channel cannot be offered.
#[derive(Debug, Error)]
pub enum OfferError {
    /// The requested channel ID is outside the assignment table (see `AssignedChannels::set`).
    #[error("the channel ID {} is not valid for this operation", (.0).0)]
    InvalidChannelId(ChannelId),
    /// The requested channel ID is already owned by another offer (see `AssignedChannels::set`).
    #[error("the channel ID {} is already in use", (.0).0)]
    ChannelIdInUse(ChannelId),
    #[error("offer {0} already exists")]
    AlreadyExists(OfferKey),
    #[error("specified resources do not match those of the existing saved or revoked offer")]
    IncompatibleResources,
    #[error("too many channels have been offered")]
    TooManyChannels,
    #[error("mismatched monitor ID from saved state; expected {0:?}, actual {1:?}")]
    MismatchedMonitorId(Option<MonitorId>, MonitorId),
}
112
/// Identifier of an offered channel: the index of its entry in the `ChannelList` slab
/// (produced by `ChannelList::offer`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OfferId(usize);
116
/// Maps a GPADL ID to the offer it belongs to.
/// NOTE(review): presumably only holds GPADLs whose data is still being received — confirm.
type IncompleteGpadlMap = HashMap<GpadlId, OfferId>;

/// All GPADLs known to the server, keyed by GPADL ID together with the owning offer.
type GpadlMap = HashMap<(GpadlId, OfferId), Gpadl>;
120
/// VMBus server state: connection state, offered channels, and their GPADLs.
pub struct Server {
    /// Current connection state with the client.
    state: ConnectionState,
    /// Every offered channel, indexed by `OfferId`.
    channels: ChannelList,
    /// Channel ID -> offer assignments.
    assigned_channels: AssignedChannels,
    /// Monitor IDs handed out to channels with `MnfUsage::Enabled`.
    assigned_monitors: AssignedMonitors,
    /// GPADLs keyed by (GPADL ID, owning offer).
    gpadls: GpadlMap,
    /// GPADL ID -> owning offer for GPADLs tracked separately as incomplete.
    incomplete_gpadls: IncompleteGpadlMap,
    /// NOTE(review): connection ID supplied at construction; usage not visible here.
    child_connection_id: u32,
    /// Version limit applied immediately by `set_compatibility_version(.., false)`.
    max_version: Option<MaxVersionInfo>,
    /// Version limit recorded by `set_compatibility_version(.., true)`.
    delayed_max_version: Option<MaxVersionInfo>,
    /// NOTE(review): presumably outgoing messages that could not be sent yet — confirm.
    pending_messages: PendingMessages,
    /// Set via `set_require_server_allocated_mnf`.
    require_server_allocated_mnf: bool,
    /// Supplied at construction; NOTE(review): affects offer ordering elsewhere — confirm.
    use_absolute_channel_order: bool,
}
141
/// A mutable borrow of a [`Server`] paired with the notifier used to act on its decisions.
///
/// Server invariants are checked when the pair is created (`Server::with_notifier`) and
/// again when it is dropped.
pub struct ServerWithNotifier<'a, T> {
    inner: &'a mut Server,
    notifier: &'a mut T,
}
146
147impl<T> Drop for ServerWithNotifier<'_, T> {
148 fn drop(&mut self) {
149 self.inner.validate();
150 }
151}
152
impl<T: Notifier> Inspect for ServerWithNotifier<'_, T> {
    /// Reports the connection state, the pending post-disconnect action (if any), the
    /// assigned-monitor bitmap, and a `channels` child holding per-channel state plus every
    /// GPADL nested under its owning channel's path.
    fn inspect(&self, req: inspect::Request<'_>) {
        let mut resp = req.respond();
        // Pick a display name for the state and extract whatever data that state carries.
        // Only `Disconnecting` reports a next action; `Connecting` has one but it is not shown.
        let (state, info, next_action) = match &self.inner.state {
            ConnectionState::Disconnected => ("disconnected", None, None),
            ConnectionState::Connecting { info, .. } => ("connecting", Some(info), None),
            // A connection whose offers have not been sent yet is shown as "negotiated".
            ConnectionState::Connected(info) => (
                if info.offers_sent {
                    "connected"
                } else {
                    "negotiated"
                },
                Some(info),
                None,
            ),
            ConnectionState::Disconnecting { next_action, .. } => {
                ("disconnecting", None, Some(next_action))
            }
        };

        resp.field("connection_info", info);
        let next_action = next_action.map(|a| match a {
            ConnectionAction::None => "disconnect",
            ConnectionAction::Reset => "reset",
            ConnectionAction::SendUnloadComplete => "unload",
            ConnectionAction::Reconnect { .. } => "reconnect",
            ConnectionAction::SendFailedVersionResponse => "send_version_response",
        });
        resp.field("state", state)
            .field("next_action", next_action)
            .field(
                "assigned_monitors_bitmap",
                format_args!("{:x}", self.inner.assigned_monitors.bitmap()),
            )
            .child("channels", |req| {
                let mut resp = req.respond();
                self.inner
                    .channels
                    .inspect(self.notifier, self.inner.get_version(), &mut resp);
                // GPADLs are reported as "<channel path>/gpadls/<id>" fields.
                for ((gpadl_id, offer_id), gpadl) in &self.inner.gpadls {
                    let channel = &self.inner.channels[*offer_id];
                    resp.field(
                        &channel_inspect_path(
                            &channel.offer,
                            format_args!("/gpadls/{}", gpadl_id.0),
                        ),
                        gpadl,
                    );
                }
            });
    }
}
205
/// Monitor page GPAs in use for a connection, together with who chose them.
#[derive(Debug, Copy, Clone, Inspect)]
struct MonitorPageGpaInfo {
    gpas: MonitorPageGpas,
    /// `true` if the server supplied the GPAs, `false` if the guest did.
    server_allocated: bool,
}
212
213impl MonitorPageGpaInfo {
214 fn from_guest_gpas(gpas: MonitorPageGpas) -> Self {
216 Self {
217 gpas,
218 server_allocated: false,
219 }
220 }
221
222 fn from_server_gpas(gpas: MonitorPageGpas) -> Self {
224 Self {
225 gpas,
226 server_allocated: true,
227 }
228 }
229}
230
/// Parameters of a connection that is connecting or connected.
#[derive(Debug, Copy, Clone, Inspect)]
struct ConnectionInfo {
    /// Negotiated protocol version and feature flags.
    version: VersionInfo,
    /// Reported by `ConnectionState::is_trusted`.
    trusted: bool,
    /// Whether offers have been sent; inspect shows "negotiated" until this is set.
    offers_sent: bool,
    interrupt_page: Option<u64>,
    monitor_page: Option<MonitorPageGpaInfo>,
    target_message_vp: u32,
    /// NOTE(review): presumably set while a connection modification is in flight — confirm.
    modifying: bool,
    client_id: Guid,
    /// Reported by `ConnectionState::is_paused`.
    paused: bool,
}
245
/// State of the server's connection with the client.
#[derive(Debug)]
enum ConnectionState {
    Disconnected,
    /// Tearing down; `next_action` says what happens afterwards.
    Disconnecting {
        next_action: ConnectionAction,
        modify_sent: bool,
    },
    /// Connection parameters are known but the connection is not yet established.
    Connecting {
        info: ConnectionInfo,
        next_action: ConnectionAction,
    },
    Connected(ConnectionInfo),
}
260
261impl ConnectionState {
262 fn check_version(&self, min_version: Version) -> bool {
264 matches!(self, ConnectionState::Connected(info) if info.version.version >= min_version)
265 }
266
267 fn check_feature_flags(&self, flags: impl Fn(FeatureFlags) -> bool) -> bool {
270 matches!(self, ConnectionState::Connected(info) if flags(info.version.feature_flags))
271 }
272
273 fn get_version(&self) -> Option<VersionInfo> {
274 if let ConnectionState::Connected(info) = self {
275 Some(info.version)
276 } else {
277 None
278 }
279 }
280
281 fn get_connected_info(&self) -> Option<&ConnectionInfo> {
283 if let ConnectionState::Connected(info) = self {
284 Some(info)
285 } else {
286 None
287 }
288 }
289
290 fn is_trusted(&self) -> bool {
291 match self {
292 ConnectionState::Connected(info) => info.trusted,
293 ConnectionState::Connecting { info, .. } => info.trusted,
294 _ => false,
295 }
296 }
297
298 fn is_paused(&self) -> bool {
299 if let ConnectionState::Connected(info) = self {
300 info.paused
301 } else {
302 false
303 }
304 }
305}
306
/// Follow-up action recorded with a connecting/disconnecting state. Inspect labels
/// each variant as noted.
#[derive(Debug, Copy, Clone)]
enum ConnectionAction {
    /// Inspect label: "disconnect".
    None,
    /// Inspect label: "reset".
    Reset,
    /// Inspect label: "unload".
    SendUnloadComplete,
    /// Inspect label: "reconnect"; carries the request to reconnect with.
    Reconnect {
        initiate_contact: InitiateContactRequest,
    },
    /// Inspect label: "send_version_response".
    SendFailedVersionResponse,
}
317
/// Monitor page information carried by an `InitiateContact` request.
/// NOTE(review): `Invalid` presumably marks a malformed pair of GPAs — confirm with the parser.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum MonitorPageRequest {
    None,
    Some(MonitorPageGpas),
    Invalid,
}
324
/// Parameters of a client's request to establish a connection.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub struct InitiateContactRequest {
    pub version_requested: u32,
    pub target_message_vp: u32,
    pub monitor_page: MonitorPageRequest,
    pub target_sint: u8,
    pub target_vtl: u8,
    pub feature_flags: u32,
    pub interrupt_page: Option<u64>,
    pub client_id: Guid,
    pub trusted: bool,
}
337
/// Parameters of a client request to open a channel.
#[derive(Debug, Copy, Clone)]
pub struct OpenRequest {
    pub open_id: u32,
    /// GPADL describing the ring buffer.
    pub ring_buffer_gpadl_id: GpadlId,
    pub target_vp: u32,
    pub downstream_ring_buffer_page_offset: u32,
    pub user_data: UserDefinedData,
    /// When present, overrides the default event flag / connection ID
    /// (see `OpenParams::from_request`).
    pub guest_specified_interrupt_info: Option<SignalInfo>,
    pub flags: protocol::OpenChannelFlags,
}
348
/// A three-way update to an optional setting: leave it, clear it, or set a new value.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Update<T: std::fmt::Debug + Copy + Clone> {
    Unchanged,
    Reset,
    Set(T),
}
355
356impl<T: std::fmt::Debug + Copy + Clone> From<Option<T>> for Update<T> {
357 fn from(value: Option<T>) -> Self {
358 match value {
359 None => Self::Reset,
360 Some(value) => Self::Set(value),
361 }
362 }
363}
364
/// A request to change connection-wide parameters; unset fields are left unchanged.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ModifyConnectionRequest {
    pub version: Option<VersionInfo>,
    pub monitor_page: Update<MonitorPageGpas>,
    pub interrupt_page: Update<u64>,
    pub target_message_vp: Option<u32>,
    /// Defaults to `true`. NOTE(review): presumably whether a relay should also be told — confirm.
    pub notify_relay: bool,
}
373
374impl Default for ModifyConnectionRequest {
376 fn default() -> Self {
377 Self {
378 version: None,
379 monitor_page: Update::Unchanged,
380 interrupt_page: Update::Unchanged,
381 target_message_vp: None,
382 notify_relay: true,
383 }
384 }
385}
386
387impl From<protocol::ModifyConnection> for ModifyConnectionRequest {
388 fn from(value: protocol::ModifyConnection) -> Self {
389 let monitor_page = if value.parent_to_child_monitor_page_gpa != 0 {
390 Update::Set(MonitorPageGpas {
391 parent_to_child: value.parent_to_child_monitor_page_gpa,
392 child_to_parent: value.child_to_parent_monitor_page_gpa,
393 })
394 } else {
395 Update::Reset
396 };
397
398 Self {
399 monitor_page,
400 ..Default::default()
401 }
402 }
403}
404
/// Outcome of a connection-modification request.
/// NOTE(review): variant semantics inferred from names only — confirm with the callers.
#[derive(Debug, Copy, Clone)]
pub enum ModifyConnectionResponse {
    /// Carries a protocol connection state, feature flags, and optional monitor page GPAs.
    Supported(
        protocol::ConnectionState,
        FeatureFlags,
        Option<MonitorPageGpas>,
    ),
    Unsupported,
    /// Carries the resulting protocol connection state.
    Modified(protocol::ConnectionState),
}
424
/// Whether a channel modification is in progress.
#[derive(Debug, Copy, Clone)]
pub enum ModifyState {
    NotModifying,
    /// NOTE(review): `pending_target_vp` presumably holds a target VP queued behind the
    /// in-flight modification — confirm.
    Modifying { pending_target_vp: Option<u32> },
}
430
431impl ModifyState {
432 pub fn is_modifying(&self) -> bool {
433 matches!(self, ModifyState::Modifying { .. })
434 }
435}
436
/// Guest-specified interrupt signalling parameters for a channel.
#[derive(Debug, Copy, Clone)]
pub struct SignalInfo {
    pub event_flag: u16,
    pub connection_id: u32,
}
442
/// Where a channel stands with respect to restoring saved state.
/// NOTE(review): per-variant meanings below are inferred from names — confirm with saved_state.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum RestoreState {
    /// Not part of a restore (inspect: "new").
    New,
    /// Restore in progress (inspect: "restoring").
    Restoring,
    /// Presumably saved state with no matching live offer (inspect: "unmatched").
    Unmatched,
    /// Restore finished (inspect: "restored").
    Restored,
}
457
/// Per-channel state machine.
///
/// `is_released` is true for `ClientReleased`, `ClosingClientRelease` and
/// `OpeningClientRelease`; `is_revoked` is true for `Revoked` and `Reoffered`; `is_reserved`
/// is true for `Opening`/`Open`/`Closing` carrying `reserved_state: Some(_)`.
#[derive(Debug, Clone)]
enum ChannelState {
    /// Released by the client.
    ClientReleased,

    /// Not open.
    Closed,

    /// An open `request` has been accepted but not yet completed.
    Opening {
        request: OpenRequest,
        reserved_state: Option<ReservedState>,
    },

    /// Open with `params`; `modify_state` tracks an in-flight modification.
    Open {
        params: OpenRequest,
        modify_state: ModifyState,
        reserved_state: Option<ReservedState>,
    },

    /// Close in progress for a channel that was opened with `params`.
    Closing {
        params: OpenRequest,
        reserved_state: Option<ReservedState>,
    },

    /// Close in progress (for `params`) with a new open `request` waiting behind it.
    ClosingReopen {
        params: OpenRequest,
        request: OpenRequest,
    },

    /// Offer revoked.
    Revoked,

    /// Revoked, then offered again. NOTE(review): presumably awaiting client release — confirm.
    Reoffered,

    /// Presumably released by the client while a close was in flight — confirm.
    ClosingClientRelease,

    /// Presumably released by the client while an open was in flight — confirm.
    OpeningClientRelease,
}
511
// Each predicate lists every variant explicitly (no `_` arm) so adding a variant forces
// these classifications to be revisited.
impl ChannelState {
    /// Returns `true` for the states in which the client has released the channel.
    /// `Server::validate` checks that exactly these channels have no assigned `info`.
    fn is_released(&self) -> bool {
        match self {
            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered => false,

            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => true,
        }
    }

    /// Returns `true` if the offer has been revoked (including revoked-then-reoffered).
    fn is_revoked(&self) -> bool {
        match self {
            ChannelState::Revoked | ChannelState::Reoffered => true,

            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => false,
        }
    }

    /// Returns `true` if the channel is opening, open or closing with a reserved state.
    fn is_reserved(&self) -> bool {
        match self {
            ChannelState::Open {
                reserved_state: Some(_),
                ..
            }
            | ChannelState::Opening {
                reserved_state: Some(_),
                ..
            }
            | ChannelState::Closing {
                reserved_state: Some(_),
                ..
            } => true,

            ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => false,
        }
    }
}
576
577impl Display for ChannelState {
578 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
579 let state = match self {
580 Self::ClientReleased => "ClientReleased",
581 Self::Closed => "Closed",
582 Self::Opening { .. } => "Opening",
583 Self::Open { .. } => "Open",
584 Self::Closing { .. } => "Closing",
585 Self::ClosingReopen { .. } => "ClosingReopen",
586 Self::Revoked => "Revoked",
587 Self::Reoffered => "Reoffered",
588 Self::ClosingClientRelease => "ClosingClientRelease",
589 Self::OpeningClientRelease => "OpeningClientRelease",
590 };
591 write!(f, "{}", state)
592 }
593}
594
/// How monitored notification (MNF) is used by a channel.
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub enum MnfUsage {
    #[default]
    /// No monitor ID is assigned.
    Disabled,
    /// This server assigns a monitor ID from `AssignedMonitors` (see `Channel::prepare_channel`).
    Enabled { latency: Duration },
    /// The offer supplies `monitor_id` directly; it is never released back to `AssignedMonitors`.
    Relayed { monitor_id: u8 },
}
607
608impl MnfUsage {
609 pub fn is_enabled(&self) -> bool {
610 matches!(self, Self::Enabled { .. })
611 }
612
613 pub fn is_relayed(&self) -> bool {
614 matches!(self, Self::Relayed { .. })
615 }
616
617 pub fn enabled_and_then<T>(&self, f: impl FnOnce(Duration) -> Option<T>) -> Option<T> {
618 if let Self::Enabled { latency } = self {
619 f(*latency)
620 } else {
621 None
622 }
623 }
624}
625
626impl From<Option<Duration>> for MnfUsage {
627 fn from(value: Option<Duration>) -> Self {
628 match value {
629 None => Self::Disabled,
630 Some(latency) => Self::Enabled { latency },
631 }
632 }
633}
634
/// Offer parameters as stored by the server, derived from `OfferParams`.
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub struct OfferParamsInternal {
    pub interface_name: String,
    /// Together with `interface_id` and `subchannel_index`, forms the `OfferKey`.
    pub instance_id: Guid,
    pub interface_id: Guid,
    pub mmio_megabytes: u16,
    pub mmio_megabytes_optional: u16,
    pub subchannel_index: u16,
    pub use_mnf: MnfUsage,
    pub offer_order: Option<u64>,
    pub flags: OfferFlags,
    pub user_defined: UserDefinedData,
}
649
650impl OfferParamsInternal {
651 pub fn key(&self) -> OfferKey {
653 OfferKey {
654 interface_id: self.interface_id,
655 instance_id: self.instance_id,
656 subchannel_index: self.subchannel_index,
657 }
658 }
659}
660
impl From<OfferParams> for OfferParamsInternal {
    /// Translates public offer parameters into the flags and user-defined data sent with
    /// the offer, according to the channel type.
    fn from(value: OfferParams) -> Self {
        // Starts zeroed; only the channel-type arms below fill it in.
        let mut user_defined = UserDefinedData::new_zeroed();

        // Every offer advertises a confidential ring buffer; confidential external memory
        // follows the caller's setting.
        let mut flags = OfferFlags::new()
            .with_confidential_ring_buffer(true)
            .with_confidential_external_memory(value.allow_confidential_external_memory);

        match value.channel_type {
            // Plain device: only pipe-packet devices get named-pipe mode, in message mode.
            ChannelType::Device { pipe_packets } => {
                if pipe_packets {
                    flags.set_named_pipe_mode(true);
                    user_defined.as_pipe_params_mut().pipe_type = protocol::PipeType::MESSAGE;
                }
            }
            // Interface: user-defined data comes straight from the caller.
            ChannelType::Interface {
                user_defined: interface_user_defined,
            } => {
                flags.set_enumerate_device_interface(true);
                user_defined = interface_user_defined;
            }
            ChannelType::Pipe { message_mode } => {
                flags.set_enumerate_device_interface(true);
                flags.set_named_pipe_mode(true);
                user_defined.as_pipe_params_mut().pipe_type = if message_mode {
                    protocol::PipeType::MESSAGE
                } else {
                    protocol::PipeType::BYTE
                };
            }
            ChannelType::HvSocket {
                is_connect,
                is_for_container,
                silo_id,
            } => {
                flags.set_enumerate_device_interface(true);
                flags.set_tlnpi_provider(true);
                flags.set_named_pipe_mode(true);
                *user_defined.as_hvsock_params_mut() = protocol::HvsockUserDefinedParameters::new(
                    is_connect,
                    is_for_container,
                    silo_id,
                );
            }
        };

        Self {
            interface_name: value.interface_name,
            instance_id: value.instance_id,
            interface_id: value.interface_id,
            mmio_megabytes: value.mmio_megabytes,
            mmio_megabytes_optional: value.mmio_megabytes_optional,
            subchannel_index: value.subchannel_index,
            // `Some(latency)` enables MNF; `None` disables it.
            use_mnf: value.mnf_interrupt_latency.into(),
            offer_order: value.offer_order,
            user_defined,
            flags,
        }
    }
}
723
/// A VP and SINT pair to which messages can be delivered.
#[derive(Debug, Copy, Clone, Inspect, PartialEq, Eq)]
pub struct ConnectionTarget {
    pub vp: u32,
    pub sint: u8,
}
729
/// Destination of an outgoing message (passed to `Notifier::send_message`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MessageTarget {
    /// The connection's default target.
    Default,
    /// The reserved target recorded for a channel's reserved state.
    ReservedChannel(OfferId, ConnectionTarget),
    Custom(ConnectionTarget),
}
736
737impl MessageTarget {
738 pub fn for_offer(offer_id: OfferId, reserved_state: &Option<ReservedState>) -> Self {
739 if let Some(state) = reserved_state {
740 Self::ReservedChannel(offer_id, state.target)
741 } else {
742 Self::Default
743 }
744 }
745}
746
/// Extra state kept for a reserved channel: a version plus the target that messages for
/// the channel are sent to (see `MessageTarget::for_offer`).
#[derive(Debug, Copy, Clone)]
pub struct ReservedState {
    version: VersionInfo,
    target: ConnectionTarget,
}
752
/// An offered channel.
#[derive(Debug)]
struct Channel {
    /// Channel/connection/monitor IDs; `Server::validate` requires this to be `Some`
    /// exactly when the state is not released.
    info: Option<OfferedInfo>,
    offer: OfferParamsInternal,
    state: ChannelState,
    restore_state: RestoreState,
}
761
/// IDs assigned to a channel by `Channel::prepare_channel`.
#[derive(Debug, Copy, Clone)]
struct OfferedInfo {
    channel_id: ChannelId,
    connection_id: u32,
    monitor_id: Option<MonitorId>,
}
768
impl Channel {
    /// Writes the channel's IDs (if assigned), state, restore state, offer details, and
    /// state-specific fields (target VP, guest-specified signal info, reserved target).
    fn inspect_state(&self, resp: &mut inspect::Response<'_>) {
        let mut target_vp = None;
        let mut event_flag = None;
        let mut connection_id = None;
        let mut reserved_target = None;
        let state = match &self.state {
            ChannelState::ClientReleased => "client_released",
            ChannelState::Closed => "closed",
            ChannelState::Opening { reserved_state, .. } => {
                reserved_target = reserved_state.map(|state| state.target);
                "opening"
            }
            ChannelState::Open {
                params,
                reserved_state,
                ..
            } => {
                target_vp = Some(params.target_vp);
                if let Some(id) = params.guest_specified_interrupt_info {
                    event_flag = Some(id.event_flag);
                    connection_id = Some(id.connection_id);
                }
                reserved_target = reserved_state.map(|state| state.target);
                "open"
            }
            ChannelState::Closing { reserved_state, .. } => {
                reserved_target = reserved_state.map(|state| state.target);
                "closing"
            }
            ChannelState::ClosingReopen { .. } => "closing_reopen",
            ChannelState::Revoked => "revoked",
            ChannelState::Reoffered => "reoffered",
            ChannelState::ClosingClientRelease => "closing_client_release",
            ChannelState::OpeningClientRelease => "opening_client_release",
        };
        let restore_state = match self.restore_state {
            RestoreState::New => "new",
            RestoreState::Restoring => "restoring",
            RestoreState::Restored => "restored",
            RestoreState::Unmatched => "unmatched",
        };
        if let Some(info) = &self.info {
            resp.field("channel_id", info.channel_id.0)
                .field("offered_connection_id", info.connection_id)
                .field("monitor_id", info.monitor_id.map(|id| id.0));
        }
        resp.field("state", state)
            .field("restore_state", restore_state)
            .field("interface_name", self.offer.interface_name.clone())
            .display("instance_id", &self.offer.instance_id)
            .display("interface_id", &self.offer.interface_id)
            .field("mmio_megabytes", self.offer.mmio_megabytes)
            .field("target_vp", target_vp)
            .field("guest_specified_event_flag", event_flag)
            .field("guest_specified_connection_id", connection_id)
            .field("reserved_connection_target", reserved_target)
            .binary("offer_flags", self.offer.flags.into_bits());
    }

    /// Returns the monitor info this server handles for the channel.
    ///
    /// `Some` only when MNF is `Enabled` (not `Relayed`/`Disabled`), the channel is not in
    /// a reserved state, and a monitor ID was actually assigned.
    fn handled_monitor_info(&self) -> Option<MonitorInfo> {
        self.offer.use_mnf.enabled_and_then(|latency| {
            if self.state.is_reserved() {
                None
            } else {
                self.info.and_then(|info| {
                    info.monitor_id.map(|monitor_id| MonitorInfo {
                        monitor_id,
                        latency,
                    })
                })
            }
        })
    }

    /// Assigns a channel ID, connection ID and (depending on `use_mnf`) a monitor ID.
    ///
    /// The channel ID comes from `AssignedChannels::allocate`, i.e. never from the reserved
    /// low range. For `Enabled` MNF a monitor ID is taken from `assigned_monitors`; if none
    /// is left a rate-limited warning is logged and the channel proceeds without one.
    ///
    /// # Panics
    /// If the channel already has `info`, or if no channel ID is free.
    fn prepare_channel(
        &mut self,
        offer_id: OfferId,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
    ) {
        assert!(self.info.is_none());

        // The caller is expected to cap offers at `allowable_channel_count`, so allocation
        // failing indicates a broken invariant.
        let entry = assigned_channels
            .allocate()
            .expect("there are enough channel IDs for everything in ChannelList");

        let channel_id = entry.id();
        entry.insert(offer_id);
        let connection_id = ConnectionId::new(channel_id.0, assigned_channels.vtl, SINT);

        let monitor_id = match self.offer.use_mnf {
            MnfUsage::Enabled { .. } => {
                let monitor_id = assigned_monitors.assign_monitor();
                if monitor_id.is_none() {
                    tracelimit::warn_ratelimited!("Out of monitor IDs.");
                }

                monitor_id
            }
            // Relayed IDs come from the offer and are not taken from `assigned_monitors`.
            MnfUsage::Relayed { monitor_id } => Some(MonitorId(monitor_id)),
            MnfUsage::Disabled => None,
        };

        self.info = Some(OfferedInfo {
            channel_id,
            connection_id: connection_id.0,
            monitor_id,
        });
    }

    /// Undoes `prepare_channel`: frees the channel ID and, for `Enabled` MNF only, returns
    /// the monitor ID to `assigned_monitors`. Does nothing if no IDs are assigned.
    fn release_channel(
        &mut self,
        offer_id: OfferId,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
    ) {
        if let Some(info) = self.info.take() {
            assigned_channels.free(info.channel_id, offer_id);

            if let Some(monitor_id) = info.monitor_id {
                // Relayed monitor IDs were never allocated here, so they are not released.
                if self.offer.use_mnf.is_enabled() {
                    assigned_monitors.release_monitor(monitor_id);
                }
            }
        }
    }
}
915
916#[derive(Debug)]
917struct AssignedChannels {
918 assignments: Vec<Option<OfferId>>,
919 vtl: Vtl,
920 reserved_offset: usize,
921 count_in_reserved_range: usize,
923}
924
925impl AssignedChannels {
926 fn new(vtl: Vtl, channel_id_offset: u16) -> Self {
927 Self {
928 assignments: vec![None; MAX_CHANNELS],
929 vtl,
930 reserved_offset: channel_id_offset as usize,
931 count_in_reserved_range: 0,
932 }
933 }
934
935 fn allowable_channel_count(&self) -> usize {
936 MAX_CHANNELS - self.reserved_offset + self.count_in_reserved_range
937 }
938
939 fn get(&self, channel_id: ChannelId) -> Option<OfferId> {
940 self.assignments
941 .get(Self::index(channel_id))
942 .copied()
943 .flatten()
944 }
945
946 fn set(&mut self, channel_id: ChannelId) -> Result<AssignmentEntry<'_>, OfferError> {
947 let index = Self::index(channel_id);
948 if self
949 .assignments
950 .get(index)
951 .ok_or(OfferError::InvalidChannelId(channel_id))?
952 .is_some()
953 {
954 return Err(OfferError::ChannelIdInUse(channel_id));
955 }
956 Ok(AssignmentEntry { list: self, index })
957 }
958
959 fn allocate(&mut self) -> Option<AssignmentEntry<'_>> {
960 let index = self.reserved_offset
961 + self.assignments[self.reserved_offset..]
962 .iter()
963 .position(|x| x.is_none())?;
964 Some(AssignmentEntry { list: self, index })
965 }
966
967 fn free(&mut self, channel_id: ChannelId, offer_id: OfferId) {
968 let index = Self::index(channel_id);
969 let slot = &mut self.assignments[index];
970 assert_eq!(slot.take(), Some(offer_id));
971 if index < self.reserved_offset {
972 self.count_in_reserved_range -= 1;
973 }
974 }
975
976 fn index(channel_id: ChannelId) -> usize {
977 channel_id.0.wrapping_sub(1) as usize
978 }
979}
980
981struct AssignmentEntry<'a> {
982 list: &'a mut AssignedChannels,
983 index: usize,
984}
985
986impl AssignmentEntry<'_> {
987 pub fn id(&self) -> ChannelId {
988 ChannelId(self.index as u32 + 1)
989 }
990
991 pub fn insert(self, offer_id: OfferId) {
992 assert!(
993 self.list.assignments[self.index]
994 .replace(offer_id)
995 .is_none()
996 );
997
998 if self.index < self.list.reserved_offset {
999 self.list.count_in_reserved_range += 1;
1000 }
1001 }
1002}
1003
/// All offered channels, stored in a slab whose keys are wrapped as [`OfferId`]s.
struct ChannelList {
    channels: Slab<Channel>,
}
1007
1008fn channel_inspect_path(offer: &OfferParamsInternal, suffix: std::fmt::Arguments<'_>) -> String {
1009 if offer.subchannel_index == 0 {
1010 format!("{}{}", offer.instance_id, suffix)
1011 } else {
1012 format!(
1013 "{}/subchannels/{}{}",
1014 offer.instance_id, offer.subchannel_index, suffix
1015 )
1016 }
1017}
1018
impl ChannelList {
    /// Adds one inspect child per channel, at the path from `channel_inspect_path`. Each
    /// child holds the channel's own state merged with whatever the notifier reports.
    fn inspect(
        &self,
        notifier: &impl Notifier,
        version: Option<VersionInfo>,
        resp: &mut inspect::Response<'_>,
    ) {
        for (offer_id, channel) in self.iter() {
            resp.child(
                &channel_inspect_path(&channel.offer, format_args!("")),
                |req| {
                    let mut resp = req.respond();
                    channel.inspect_state(&mut resp);

                    resp.merge(inspect::adhoc(|req| {
                        // The notifier is not consulted for `Revoked` channels.
                        // NOTE(review): `Reoffered` is also a revoked state per `is_revoked`
                        // but is still passed to the notifier — confirm this is intended.
                        if !matches!(channel.state, ChannelState::Revoked) {
                            notifier.inspect(version, offer_id, req);
                        }
                    }));
                },
            );
        }
    }
}
1046
/// Number of channel IDs available (IDs `1..=MAX_CHANNELS`); `AssignedChannels` keeps one
/// slot per ID.
pub const MAX_CHANNELS: usize = 2047;
1050
1051impl ChannelList {
1052 fn new() -> Self {
1053 Self {
1054 channels: Slab::new(),
1055 }
1056 }
1057
1058 fn len(&self) -> usize {
1060 self.channels.len()
1061 }
1062
1063 fn offer(&mut self, new_channel: Channel) -> OfferId {
1065 OfferId(self.channels.insert(new_channel))
1066 }
1067
1068 fn remove(&mut self, offer_id: OfferId) {
1070 let channel = self.channels.remove(offer_id.0);
1071 assert!(channel.info.is_none());
1072 }
1073
1074 fn get_by_channel_id_mut(
1076 &mut self,
1077 assigned_channels: &AssignedChannels,
1078 channel_id: ChannelId,
1079 ) -> Result<(OfferId, &mut Channel), ChannelError> {
1080 let offer_id = assigned_channels
1081 .get(channel_id)
1082 .ok_or(ChannelError::UnknownChannelId)?;
1083 let channel = &mut self[offer_id];
1084 if channel.state.is_released() {
1085 return Err(ChannelError::ChannelReleased);
1086 }
1087 assert_eq!(
1088 channel.info.as_ref().map(|info| info.channel_id),
1089 Some(channel_id)
1090 );
1091 Ok((offer_id, channel))
1092 }
1093
1094 fn get_by_channel_id(
1096 &self,
1097 assigned_channels: &AssignedChannels,
1098 channel_id: ChannelId,
1099 ) -> Result<(OfferId, &Channel), ChannelError> {
1100 let offer_id = assigned_channels
1101 .get(channel_id)
1102 .ok_or(ChannelError::UnknownChannelId)?;
1103 let channel = &self[offer_id];
1104 if channel.state.is_released() {
1105 return Err(ChannelError::ChannelReleased);
1106 }
1107 assert_eq!(
1108 channel.info.as_ref().map(|info| info.channel_id),
1109 Some(channel_id)
1110 );
1111 Ok((offer_id, channel))
1112 }
1113
1114 fn get_by_key_mut(&mut self, key: &OfferKey) -> Option<(OfferId, &mut Channel)> {
1117 for (offer_id, channel) in self.iter_mut() {
1118 if channel.offer.instance_id == key.instance_id
1119 && channel.offer.interface_id == key.interface_id
1120 && channel.offer.subchannel_index == key.subchannel_index
1121 {
1122 return Some((offer_id, channel));
1123 }
1124 }
1125 None
1126 }
1127
1128 fn iter(&self) -> impl Iterator<Item = (OfferId, &Channel)> {
1130 self.channels
1131 .iter()
1132 .map(|(id, channel)| (OfferId(id), channel))
1133 }
1134
1135 fn iter_mut(&mut self) -> impl Iterator<Item = (OfferId, &mut Channel)> {
1137 self.channels
1138 .iter_mut()
1139 .map(|(id, channel)| (OfferId(id), channel))
1140 }
1141
1142 fn retain<F>(&mut self, mut f: F)
1144 where
1145 F: FnMut(OfferId, &mut Channel) -> bool,
1146 {
1147 self.channels.retain(|id, channel| {
1148 let retain = f(OfferId(id), channel);
1149 if !retain {
1150 assert!(channel.info.is_none());
1151 }
1152 retain
1153 })
1154 }
1155}
1156
/// Direct access to a channel by `OfferId`; panics if no channel is stored under that ID.
impl Index<OfferId> for ChannelList {
    type Output = Channel;

    fn index(&self, offer_id: OfferId) -> &Self::Output {
        &self.channels[offer_id.0]
    }
}

/// Mutable access to a channel by `OfferId`; panics if no channel is stored under that ID.
impl IndexMut<OfferId> for ChannelList {
    fn index_mut(&mut self, offer_id: OfferId) -> &mut Self::Output {
        &mut self.channels[offer_id.0]
    }
}
1170
/// A GPADL being assembled or already known to the server.
#[derive(Debug, Inspect)]
struct Gpadl {
    /// Passed to `MultiPagedRangeBuf::validate` with `buf` (presumably the range count).
    count: u16,
    /// Accumulated range data; its capacity is the total expected length.
    #[inspect(skip)]
    buf: Vec<u64>,
    state: GpadlState,
}
1179
/// Lifecycle of a GPADL.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Inspect)]
enum GpadlState {
    /// Still receiving data (initial state from `Gpadl::new`).
    InProgress,
    /// All data received and validated (set by `Gpadl::append`).
    Offered,
    /// NOTE(review): presumably teardown requested before the GPADL was accepted — confirm.
    OfferedTearingDown,
    /// NOTE(review): presumably accepted by the device side — confirm.
    Accepted,
    /// NOTE(review): presumably teardown in progress after acceptance — confirm.
    TearingDown,
}
1194
impl Gpadl {
    /// Creates an `InProgress` GPADL expecting `len` `u64` words of range data in total.
    fn new(count: u16, len: usize) -> Self {
        Self {
            state: GpadlState::InProgress,
            count,
            buf: Vec::with_capacity(len),
        }
    }

    /// Appends a fragment of range data.
    ///
    /// Only whole 8-byte words are consumed, up to the remaining space; any excess bytes
    /// are dropped silently. Returns `Ok(true)` when the buffer became full and validated
    /// (state moves to `Offered`), and `Ok(false)` while more data is expected.
    ///
    /// # Errors
    /// `GpadlAlreadyComplete` if the GPADL is not `InProgress`; `InvalidGpaRange` if the
    /// completed buffer fails validation (the state then stays `InProgress`).
    fn append(&mut self, data: &[u8]) -> Result<bool, ChannelError> {
        if self.state == GpadlState::InProgress {
            let buf = &mut self.buf;
            // NOTE(review): remaining space and completion are both derived from
            // `buf.capacity()`, which assumes `Vec::with_capacity(len)` in `new` allocated
            // exactly `len` elements; std only documents "at least". Consider storing the
            // expected length explicitly.
            let len = min(data.len() & !7, (buf.capacity() - buf.len()) * 8);
            let data = &data[..len];
            let start = buf.len();
            buf.resize(buf.len() + data.len() / 8, 0);
            buf[start..].as_mut_bytes().copy_from_slice(data);
            Ok(if buf.len() == buf.capacity() {
                gparange::MultiPagedRangeBuf::<Vec<u64>>::validate(self.count as usize, buf)
                    .map_err(ChannelError::InvalidGpaRange)?;
                self.state = GpadlState::Offered;
                true
            } else {
                false
            })
        } else {
            Err(ChannelError::GpadlAlreadyComplete)
        }
    }
}
1232
/// Everything needed to open a channel, built by `OpenParams::from_request`.
#[derive(Debug, Copy, Clone)]
pub struct OpenParams {
    pub open_data: OpenData,
    /// Guest-specified connection ID if given, else the channel's offered connection ID.
    pub connection_id: u32,
    /// Guest-specified event flag if given, else the channel ID.
    pub event_flag: u16,
    pub monitor_info: Option<MonitorInfo>,
    pub flags: protocol::OpenChannelFlags,
    pub reserved_target: Option<ConnectionTarget>,
    pub channel_id: ChannelId,
}
1244
1245impl OpenParams {
1246 fn from_request(
1247 info: &OfferedInfo,
1248 request: &OpenRequest,
1249 monitor_info: Option<MonitorInfo>,
1250 reserved_target: Option<ConnectionTarget>,
1251 ) -> Self {
1252 let (event_flag, connection_id) = if let Some(id) = request.guest_specified_interrupt_info {
1255 (id.event_flag, id.connection_id)
1256 } else {
1257 (info.channel_id.0 as u16, info.connection_id)
1258 };
1259
1260 Self {
1261 open_data: OpenData {
1262 target_vp: request.target_vp,
1263 ring_offset: request.downstream_ring_buffer_page_offset,
1264 ring_gpadl_id: request.ring_buffer_gpadl_id,
1265 user_data: request.user_data,
1266 event_flag,
1267 connection_id,
1268 },
1269 connection_id,
1270 event_flag,
1271 monitor_info,
1272 flags: request.flags.with_unused(0),
1273 reserved_target,
1274 channel_id: info.channel_id,
1275 }
1276 }
1277}
1278
/// Work the server asks the notifier to perform for a channel (via `Notifier::notify`).
#[derive(Debug)]
pub enum Action {
    /// Open the channel with these parameters under the given protocol version.
    Open(OpenParams, VersionInfo),
    Close,
    /// A GPADL: its ID, count, and range data.
    Gpadl(GpadlId, u16, Vec<u64>),
    TeardownGpadl {
        gpadl_id: GpadlId,
        /// NOTE(review): presumably set when the teardown stems from restored state — confirm.
        post_restore: bool,
    },
    Modify {
        target_vp: u32,
    },
}
1293
/// Protocol versions this server supports (presumably listed oldest to newest).
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::Win7,
    Version::Win8,
    Version::Win8_1,
    Version::Win10,
    Version::Win10Rs3_0,
    Version::Win10Rs3_1,
    Version::Win10Rs4,
    Version::Win10Rs5,
    Version::Iron,
    Version::Copper,
];
1308
/// Protocol feature flags this server supports.
/// NOTE(review): presumably intersected with the client's requested flags during
/// negotiation — confirm.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true)
    .with_server_specified_monitor_pages(true);
1318
/// Callbacks through which the server state machine hands off work it does not perform
/// itself. NOTE(review): per-method semantics below are partly inferred from names.
pub trait Notifier: Send {
    /// Requests that `action` be carried out for the channel `offer_id`.
    fn notify(&mut self, offer_id: OfferId, action: Action);

    /// Presumably forwards an `InitiateContact` request this server does not handle itself.
    fn forward_unhandled(&mut self, request: InitiateContactRequest);

    /// Requests a change to connection-wide parameters.
    fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()>;

    /// Adds notifier-specific inspect data for a channel. The default reports nothing.
    /// `ChannelList::inspect` does not call this for `Revoked` channels.
    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
        let _ = (version, offer_id, req);
    }

    /// Sends `message` to `target`. The result must be checked.
    /// NOTE(review): presumably `false` means the message could not be sent now — confirm.
    #[must_use]
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool;

    /// Passes along an hvsocket connect request.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest);

    /// Signals that a reset has completed.
    fn reset_complete(&mut self);

    /// Signals that an unload has completed.
    fn unload_complete(&mut self);
}
1353
1354impl Server {
1355 pub fn new(
1357 vtl: Vtl,
1358 child_connection_id: u32,
1359 channel_id_offset: u16,
1360 use_absolute_channel_order: bool,
1361 ) -> Self {
1362 Server {
1363 state: ConnectionState::Disconnected,
1364 channels: ChannelList::new(),
1365 assigned_channels: AssignedChannels::new(vtl, channel_id_offset),
1366 assigned_monitors: AssignedMonitors::new(),
1367 gpadls: Default::default(),
1368 incomplete_gpadls: Default::default(),
1369 child_connection_id,
1370 max_version: None,
1371 delayed_max_version: None,
1372 pending_messages: PendingMessages(VecDeque::new()),
1373 require_server_allocated_mnf: false,
1374 use_absolute_channel_order,
1375 }
1376 }
1377
1378 pub fn with_notifier<'a, T: Notifier>(
1380 &'a mut self,
1381 notifier: &'a mut T,
1382 ) -> ServerWithNotifier<'a, T> {
1383 self.validate();
1384 ServerWithNotifier {
1385 inner: self,
1386 notifier,
1387 }
1388 }
1389
1390 pub fn set_require_server_allocated_mnf(&mut self, require: bool) {
1393 self.require_server_allocated_mnf = require;
1394 }
1395
1396 fn validate(&self) {
1397 #[cfg(debug_assertions)]
1398 for (_, channel) in self.channels.iter() {
1399 let should_have_info = !channel.state.is_released();
1400 if channel.info.is_some() != should_have_info {
1401 panic!("channel invariant violation: {channel:?}");
1402 }
1403 }
1404 }
1405
1406 pub fn set_compatibility_version(&mut self, version: MaxVersionInfo, delay: bool) {
1408 if delay {
1409 self.delayed_max_version = Some(version)
1410 } else {
1411 tracing::info!(?version, "Limiting VmBus connections to version");
1412 self.max_version = Some(version);
1413 }
1414 }
1415
1416 pub fn channel_gpadls(&self, offer_id: OfferId) -> Vec<RestoredGpadl> {
1417 self.gpadls
1418 .iter()
1419 .filter_map(|(&(gpadl_id, gpadl_offer_id), gpadl)| {
1420 if offer_id != gpadl_offer_id {
1421 return None;
1422 }
1423 let accepted = match gpadl.state {
1424 GpadlState::Offered | GpadlState::OfferedTearingDown => false,
1425 GpadlState::Accepted => true,
1426 GpadlState::InProgress | GpadlState::TearingDown => return None,
1427 };
1428 Some(RestoredGpadl {
1429 request: GpadlRequest {
1430 id: gpadl_id,
1431 count: gpadl.count,
1432 buf: gpadl.buf.clone(),
1433 },
1434 accepted,
1435 })
1436 })
1437 .collect()
1438 }
1439
1440 pub fn get_version(&self) -> Option<VersionInfo> {
1441 self.state.get_version()
1442 }
1443
1444 pub fn get_restore_open_params(&self, offer_id: OfferId) -> Result<OpenParams, RestoreError> {
1445 let channel = &self.channels[offer_id];
1446
1447 match channel.restore_state {
1449 RestoreState::New => {
1450 return Err(RestoreError::MissingChannel(channel.offer.key()));
1454 }
1455 RestoreState::Restoring => {}
1456 RestoreState::Unmatched => unreachable!(),
1457 RestoreState::Restored => {
1458 return Err(RestoreError::AlreadyRestored(channel.offer.key()));
1459 }
1460 }
1461
1462 let info = channel
1463 .info
1464 .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;
1465
1466 let (request, reserved_state) = match channel.state {
1467 ChannelState::Closed => {
1468 return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
1469 }
1470 ChannelState::Closing { params, .. } | ChannelState::ClosingReopen { params, .. } => {
1471 (params, None)
1472 }
1473 ChannelState::Opening {
1474 request,
1475 reserved_state,
1476 } => (request, reserved_state),
1477 ChannelState::Open {
1478 params,
1479 reserved_state,
1480 ..
1481 } => (params, reserved_state),
1482 ChannelState::ClientReleased | ChannelState::Reoffered => {
1483 return Err(RestoreError::MissingChannel(channel.offer.key()));
1484 }
1485 ChannelState::Revoked
1486 | ChannelState::ClosingClientRelease
1487 | ChannelState::OpeningClientRelease => unreachable!(),
1488 };
1489
1490 Ok(OpenParams::from_request(
1491 &info,
1492 &request,
1493 channel.handled_monitor_info(),
1494 reserved_state.map(|state| state.target),
1495 ))
1496 }
1497
1498 pub fn has_pending_messages(&self) -> bool {
1500 !self.pending_messages.0.is_empty() && !self.state.is_paused()
1501 }
1502
1503 pub fn poll_flush_pending_messages(
1505 &mut self,
1506 mut send: impl FnMut(&OutgoingMessage) -> Poll<()>,
1507 ) -> Poll<()> {
1508 if !self.state.is_paused() {
1509 while let Some(message) = self.pending_messages.0.front() {
1510 ready!(send(message));
1511 self.pending_messages.0.pop_front();
1512 }
1513 }
1514
1515 Poll::Ready(())
1516 }
1517}
1518
1519impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
    /// Completes restoring channel `offer_id` after a saved-state restore.
    ///
    /// `open` is the caller's view of whether the channel is currently open.
    /// It is reconciled with the state restored from the saved state:
    /// - An in-flight open is completed (open result sent) or re-issued.
    /// - An in-flight close is completed or re-issued.
    /// - A queued close-then-reopen is turned into a new open.
    ///
    /// On success the channel is marked `Restored`. A channel that was not in
    /// the saved state (`New`) is left untouched when `open` is false.
    ///
    /// # Errors
    ///
    /// - `MissingChannel`: `open` is set but the channel was not saved; the
    ///   channel has no assigned info; it is `ClientReleased`; or it is
    ///   `Reoffered` while `open` is set.
    /// - `AlreadyRestored`: the channel was already restored.
    /// - `DuplicateMonitorId`: the channel's monitor ID was already claimed.
    /// - `MismatchedOpenState`: `open` contradicts the saved state.
    ///
    /// NOTE(review): the monitor ID is claimed before the open-state checks,
    /// so a later `MismatchedOpenState`/`MissingChannel` error leaves it
    /// claimed — confirm this is intended.
    pub fn restore_channel(&mut self, offer_id: OfferId, open: bool) -> Result<(), RestoreError> {
        let channel = &mut self.inner.channels[offer_id];

        match channel.restore_state {
            // Not part of the saved state: only valid if the caller also
            // considers the channel closed.
            RestoreState::New => {
                if open {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                } else {
                    return Ok(());
                }
            }
            RestoreState::Restoring => {}
            // NOTE(review): assumed impossible because `offer_channel` moves a
            // matched saved channel from Unmatched to Restoring — confirm.
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        // Re-claim the channel's monitor ID; two channels may not share one.
        if let Some(monitor_info) = channel.handled_monitor_info() {
            if !self
                .inner
                .assigned_monitors
                .claim_monitor(monitor_info.monitor_id)
            {
                return Err(RestoreError::DuplicateMonitorId(monitor_info.monitor_id.0));
            }
        }

        if open {
            match channel.state {
                ChannelState::Closed => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                // A close was in flight but the caller reports the channel
                // open: ask for the close again.
                ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                    self.notifier.notify(offer_id, Action::Close);
                }
                // The open finished on the caller's side: send the success
                // result to the guest and mark the channel open.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused())
                        .send_open_result(
                            info.channel_id,
                            &request,
                            protocol::STATUS_SUCCESS,
                            MessageTarget::for_offer(offer_id, &reserved_state),
                        );
                    channel.state = ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    };
                }
                ChannelState::Open { .. } => {}
                ChannelState::ClientReleased | ChannelState::Reoffered => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            };
        } else {
            match channel.state {
                ChannelState::Closed => {}
                // Left as Reoffered. NOTE(review): presumably the guest has
                // not yet released the earlier instance — confirm.
                ChannelState::Reoffered => {}
                // The caller reports the channel closed, so the in-flight
                // close is complete.
                ChannelState::Closing { .. } => {
                    channel.state = ChannelState::Closed;
                }
                // The close completed; start the queued reopen.
                ChannelState::ClosingReopen { request, .. } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                None,
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                    channel.state = ChannelState::Opening {
                        request,
                        reserved_state: None,
                    };
                }
                // The open has not happened on the caller's side: re-issue it.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                reserved_state.map(|state| state.target),
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                }
                ChannelState::Open { .. } => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                ChannelState::ClientReleased => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            }
        }

        channel.restore_state = RestoreState::Restored;
        Ok(())
    }
1658
    /// Finalizes a saved-state restore after devices have called
    /// `restore_channel`. Each channel is handled by its restore state:
    ///
    /// - `Restored`: nothing to do.
    /// - `New` (not in the saved state): offered to the guest now, if the
    ///   guest is connected, has already received offers, and this channel
    ///   has not been offered yet.
    /// - `Restoring` (matched by `offer_channel` but never passed to
    ///   `restore_channel`): revoked and marked `Reoffered`.
    /// - `Unmatched` (in the saved state but never re-offered): revoked.
    ///
    /// Afterwards, notifications are re-issued for GPADLs whose creation
    /// (`Offered`) or teardown (`TearingDown`) was still pending. Finally, a
    /// pending disconnect is re-evaluated.
    pub fn revoke_unclaimed_channels(&mut self) {
        for (offer_id, channel) in self.inner.channels.iter_mut() {
            match channel.restore_state {
                RestoreState::Restored => {
                    // Fully restored; nothing more to do.
                }
                RestoreState::New => {
                    // Only offer the channel if offers were already sent and
                    // it is still unassigned (`ClientReleased`).
                    if let ConnectionState::Connected(info) = &self.inner.state {
                        if info.offers_sent && matches!(channel.state, ChannelState::ClientReleased)
                        {
                            channel.prepare_channel(
                                offer_id,
                                &mut self.inner.assigned_channels,
                                &mut self.inner.assigned_monitors,
                            );
                            channel.state = ChannelState::Closed;
                            self.inner
                                .pending_messages
                                .sender(self.notifier, self.inner.state.is_paused())
                                .send_offer(channel, info);
                        }
                    }
                }
                RestoreState::Restoring => {
                    // The device offered this channel but never restored it:
                    // revoke the guest's copy and reoffer it afterwards.
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    // `revoke` must keep the entry because the guest has not
                    // released the channel yet.
                    assert!(retain, "channel has not been released");
                    channel.state = ChannelState::Reoffered;
                }
                RestoreState::Unmatched => {
                    // No device claimed this saved channel: revoke it.
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                }
            }
        }

        // Re-issue notifications for GPADL operations that were pending when
        // the state was saved.
        for (&(gpadl_id, offer_id), gpadl) in self.inner.gpadls.iter_mut() {
            match gpadl.state {
                GpadlState::InProgress | GpadlState::Accepted => {}
                GpadlState::Offered => {
                    self.notifier.notify(
                        offer_id,
                        Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
                    );
                }
                GpadlState::TearingDown => {
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: true,
                        },
                    );
                }
                GpadlState::OfferedTearingDown => unreachable!(),
            }
        }

        self.check_disconnected();
    }
1744
1745 pub fn reset(&mut self) {
1750 assert!(!self.is_resetting());
1751 if self.request_disconnect(ConnectionAction::Reset) {
1752 self.complete_reset();
1753 }
1754 }
1755
1756 fn complete_reset(&mut self) {
1757 for (_, channel) in self.inner.channels.iter_mut() {
1759 channel.restore_state = RestoreState::New;
1760 }
1761 self.inner.pending_messages.0.clear();
1762 self.notifier.reset_complete();
1763 }
1764
    /// Registers a channel offer from a device and returns its `OfferId`.
    ///
    /// If a channel with the same key already exists, the call succeeds only
    /// in two cases, and both replace the stored offer and return the
    /// existing `OfferId`:
    /// - It is an unmatched saved-state channel. It is matched up and moves
    ///   to `Restoring`.
    /// - It was revoked but not yet released by the guest. It is marked
    ///   `Reoffered`.
    ///
    /// Otherwise a new channel is created. If the guest is connected and has
    /// already received offers, the channel is assigned and offered
    /// immediately. Otherwise it starts out `ClientReleased`.
    ///
    /// # Errors
    ///
    /// - `AlreadyExists`: a channel with this key exists in any other state.
    /// - `MismatchedMonitorId`: a relayed monitor ID differs from the saved
    ///   one. In that case the channel has already been moved to `Restoring`
    ///   and keeps its previous offer.
    /// - `TooManyChannels`: the channel limit has been reached (checked for
    ///   new channels only).
    pub fn offer_channel(&mut self, offer: OfferParamsInternal) -> Result<OfferId, OfferError> {
        if let Some((offer_id, channel)) = self.inner.channels.get_by_key_mut(&offer.key()) {
            // Only an unmatched restored channel or a revoked-but-unreleased
            // channel can be replaced.
            if channel.restore_state != RestoreState::Unmatched
                && !matches!(channel.state, ChannelState::Revoked)
            {
                return Err(OfferError::AlreadyExists(offer.key()));
            }

            let info = channel.info.expect("assigned");
            if channel.restore_state == RestoreState::Unmatched {
                tracing::debug!(
                    offer_id = offer_id.0,
                    key = %channel.offer.key(),
                    "matched channel"
                );

                // Match the saved channel so that `restore_channel` and
                // `revoke_unclaimed_channels` treat it as claimed.
                assert!(!matches!(channel.state, ChannelState::Revoked));
                channel.restore_state = RestoreState::Restoring;

                // A relayed monitor ID must agree with the saved state.
                if let MnfUsage::Relayed { monitor_id } = offer.use_mnf {
                    if info.monitor_id != Some(MonitorId(monitor_id)) {
                        return Err(OfferError::MismatchedMonitorId(
                            info.monitor_id,
                            MonitorId(monitor_id),
                        ));
                    }
                }
            } else {
                // Revoked, but the guest still holds the old channel. Keep the
                // new offer so it can be offered again later.
                channel.state = ChannelState::Reoffered;
                tracing::info!(?offer_id, key = %channel.offer.key(), "channel marked for reoffer");
            }

            channel.offer = offer;
            return Ok(offer_id);
        }

        // Offer immediately only if the guest has already asked for offers.
        let mut connected_info = None;
        let state = match &self.inner.state {
            ConnectionState::Connected(info) => {
                if info.offers_sent {
                    connected_info = Some(info);
                    ChannelState::Closed
                } else {
                    ChannelState::ClientReleased
                }
            }
            ConnectionState::Connecting { .. }
            | ConnectionState::Disconnecting { .. }
            | ConnectionState::Disconnected => ChannelState::ClientReleased,
        };

        if self.inner.channels.len() >= self.inner.assigned_channels.allowable_channel_count() {
            return Err(OfferError::TooManyChannels);
        }

        let key = offer.key();
        let confidential_ring_buffer = offer.flags.confidential_ring_buffer();
        let confidential_external_memory = offer.flags.confidential_external_memory();
        let channel = Channel {
            info: None,
            offer,
            state,
            restore_state: RestoreState::New,
        };

        let offer_id = self.inner.channels.offer(channel);
        if let Some(info) = connected_info {
            // Assign the channel's IDs before sending the offer.
            let channel = &mut self.inner.channels[offer_id];
            channel.prepare_channel(
                offer_id,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
            );

            self.inner
                .pending_messages
                .sender(self.notifier, self.inner.state.is_paused())
                .send_offer(channel, info);
        }

        tracing::info!(?offer_id, %key, confidential_ring_buffer, confidential_external_memory, "new channel");
        Ok(offer_id)
    }
1862
1863 pub fn revoke_channel(&mut self, offer_id: OfferId) {
1865 let channel = &mut self.inner.channels[offer_id];
1866 let retain = revoke(
1867 self.inner
1868 .pending_messages
1869 .sender(self.notifier, self.inner.state.is_paused()),
1870 offer_id,
1871 channel,
1872 &mut self.inner.gpadls,
1873 );
1874 if !retain {
1875 self.inner.channels.remove(offer_id);
1876 }
1877
1878 self.check_disconnected();
1879 }
1880
    /// Handles the device's completion of an open request for `offer_id`.
    ///
    /// A non-negative `result` means success.
    /// - Normal open: the result is sent to the guest, and the channel becomes
    ///   `Open` on success or `Closed` on failure.
    /// - Guest released the channel while the open was in flight: a
    ///   successful open is immediately followed by a close request. A failed
    ///   one leaves the channel released, which may let a pending disconnect
    ///   proceed.
    ///
    /// A completion in any other state is logged as an error and ignored.
    pub fn open_complete(&mut self, offer_id: OfferId, result: i32) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::debug!(offer_id = offer_id.0, key = %channel.offer.key(), result, "open complete");

        match channel.state {
            ChannelState::Opening {
                request,
                reserved_state,
            } => {
                let channel_id = channel.info.expect("assigned").channel_id;
                if result >= 0 {
                    tracelimit::info_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "opened channel"
                    );
                } else {
                    // Failures are logged at error level.
                    tracelimit::error_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        key = %channel.offer.key(),
                        result,
                        "failed to open channel"
                    );
                }

                // Report the outcome to the guest (or the reserved target).
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_open_result(
                        channel_id,
                        &request,
                        result,
                        MessageTarget::for_offer(offer_id, &reserved_state),
                    );
                channel.state = if result >= 0 {
                    ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    }
                } else {
                    ChannelState::Closed
                };
            }
            ChannelState::OpeningClientRelease => {
                tracing::info!(
                    offer_id = offer_id.0,
                    key = %channel.offer.key(),
                    result,
                    "opened channel (client released)"
                );

                // The guest no longer wants the channel: undo a successful
                // open, or just finish the release if the open failed.
                if result >= 0 {
                    channel.state = ChannelState::ClosingClientRelease;
                    self.notifier.notify(offer_id, Action::Close);
                } else {
                    channel.state = ChannelState::ClientReleased;
                    self.check_disconnected();
                }
            }

            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::ClosingClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid open complete")
            }
        }
    }
1959
1960 fn are_channels_reset(&self, include_reserved: bool) -> bool {
1963 self.inner.gpadls.keys().all(|(_, offer_id)| {
1964 !include_reserved && self.inner.channels[*offer_id].state.is_reserved()
1965 }) && self.inner.channels.iter().all(|(_, channel)| {
1966 matches!(channel.state, ChannelState::ClientReleased)
1967 || (!include_reserved && channel.state.is_reserved())
1968 })
1969 }
1970
1971 fn check_disconnected(&mut self) {
1975 match self.inner.state {
1976 ConnectionState::Disconnecting {
1977 next_action,
1978 modify_sent: false,
1979 } => {
1980 if self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)) {
1981 self.notify_disconnect(next_action);
1982 }
1983 }
1984 ConnectionState::Disconnecting {
1985 modify_sent: true, ..
1986 }
1987 | ConnectionState::Disconnected
1988 | ConnectionState::Connected { .. }
1989 | ConnectionState::Connecting { .. } => (),
1990 }
1991 }
1992
1993 fn notify_disconnect(&mut self, next_action: ConnectionAction) {
1995 debug_assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
1997 self.inner.state = ConnectionState::Disconnecting {
1998 next_action,
1999 modify_sent: true,
2000 };
2001
2002 self.notifier
2004 .modify_connection(ModifyConnectionRequest {
2005 monitor_page: Update::Reset,
2006 interrupt_page: Update::Reset,
2007 ..Default::default()
2008 })
2009 .expect("resetting state should not fail");
2010 }
2011
2012 fn is_resetting(&self) -> bool {
2015 matches!(
2016 &self.inner.state,
2017 ConnectionState::Connecting {
2018 next_action: ConnectionAction::Reset,
2019 ..
2020 } | ConnectionState::Disconnecting {
2021 next_action: ConnectionAction::Reset,
2022 ..
2023 }
2024 )
2025 }
2026
    /// Handles the device's completion of a close request for `offer_id`.
    ///
    /// - Reserved channel, connected: the channel becomes `Closed` and a
    ///   `CloseReservedChannelResponse` is sent to the reserved target.
    /// - Reserved channel, not connected: the channel is released, and
    ///   removed if `client_release_channel` returns `true`.
    /// - Ordinary close: the channel becomes `Closed`.
    /// - Close after the guest released the channel: the channel becomes
    ///   `ClientReleased`, and a pending disconnect is re-evaluated.
    /// - Close with a queued reopen: the reopen is started.
    ///
    /// A completion in any other state is logged as an error and ignored.
    pub fn close_complete(&mut self, offer_id: OfferId) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::info!(offer_id = offer_id.0, key = %channel.offer.key(), "closed channel");
        match channel.state {
            ChannelState::Closing {
                reserved_state: Some(reserved_state),
                ..
            } => {
                channel.state = ChannelState::Closed;
                if matches!(self.inner.state, ConnectionState::Connected { .. }) {
                    let channel_id = channel.info.expect("assigned").channel_id;
                    self.send_close_reserved_channel_response(
                        channel_id,
                        offer_id,
                        reserved_state.target,
                    );
                } else {
                    // Not connected, so no response is sent. Release the
                    // reserved channel now, removing it if
                    // `client_release_channel` says it is no longer needed.
                    if Self::client_release_channel(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                        &mut self.inner.assigned_channels,
                        &mut self.inner.assigned_monitors,
                        None,
                    ) {
                        self.inner.channels.remove(offer_id);
                    }
                }
            }
            ChannelState::Closing { .. } => {
                channel.state = ChannelState::Closed;
            }
            ChannelState::ClosingClientRelease => {
                channel.state = ChannelState::ClientReleased;
                self.check_disconnected();
            }
            ChannelState::ClosingReopen { request, .. } => {
                channel.state = ChannelState::Closed;
                self.open_channel(offer_id, &request, None);
            }

            ChannelState::Closed
            | ChannelState::ClientReleased
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::OpeningClientRelease => {
                tracing::error!(?offer_id, key = %channel.offer.key(), state = ?channel.state, "invalid close complete")
            }
        }
    }
2085
2086 fn send_close_reserved_channel_response(
2087 &mut self,
2088 channel_id: ChannelId,
2089 offer_id: OfferId,
2090 target: ConnectionTarget,
2091 ) {
2092 self.sender().send_message_with_target(
2093 &protocol::CloseReservedChannelResponse { channel_id },
2094 MessageTarget::ReservedChannel(offer_id, target),
2095 );
2096 }
2097
    /// Decodes a guest InitiateContact message into an
    /// `InitiateContactRequest` and passes it to `initiate_contact`.
    ///
    /// Fields that only exist in newer protocol versions are replaced with
    /// defaults when the requested version predates them.
    /// `includes_client_id` says whether the received message was large
    /// enough to carry `client_id`.
    ///
    /// # Errors
    ///
    /// A `MessageTooSmall` parse error if the `client_id` feature flag is set
    /// but the message did not include a client ID.
    fn handle_initiate_contact(
        &mut self,
        input: &protocol::InitiateContact2,
        message: &SynicMessage,
        includes_client_id: bool,
    ) -> Result<(), ChannelError> {
        let target_info =
            protocol::TargetInfo::from(input.initiate_contact.interrupt_page_or_target_info);

        // A SINT choice is honored only for multiclient messages at
        // Win10Rs3_1 or newer; otherwise the default SINT is used.
        let target_sint = if message.multiclient
            && input.initiate_contact.version_requested >= Version::Win10Rs3_1 as u32
        {
            target_info.sint()
        } else {
            SINT
        };

        // Likewise, a target VTL is honored only for multiclient messages at
        // Win10Rs4 or newer.
        let target_vtl = if message.multiclient
            && input.initiate_contact.version_requested >= Version::Win10Rs4 as u32
        {
            target_info.vtl()
        } else {
            0
        };

        // Feature flags exist only from Copper onward.
        let feature_flags = if input.initiate_contact.version_requested >= Version::Copper as u32 {
            target_info.feature_flags()
        } else {
            0
        };

        // The target message VP is honored from Win8_1 onward; older
        // versions use VP 0.
        let target_message_vp =
            if input.initiate_contact.version_requested >= Version::Win8_1 as u32 {
                input.initiate_contact.target_message_vp
            } else {
                0
            };

        // Before Win8 this field is an interrupt page GPA (0 meaning none).
        // From Win8 onward it carries target info instead.
        let interrupt_page = (input.initiate_contact.version_requested < Version::Win8 as u32
            && input.initiate_contact.interrupt_page_or_target_info != 0)
            .then_some(input.initiate_contact.interrupt_page_or_target_info);

        // Either both monitor pages are supplied or neither; supplying only
        // one is invalid.
        let monitor_page = if (input.initiate_contact.parent_to_child_monitor_page_gpa == 0)
            != (input.initiate_contact.child_to_parent_monitor_page_gpa == 0)
        {
            MonitorPageRequest::Invalid
        } else if input.initiate_contact.parent_to_child_monitor_page_gpa != 0 {
            MonitorPageRequest::Some(MonitorPageGpas {
                parent_to_child: input.initiate_contact.parent_to_child_monitor_page_gpa,
                child_to_parent: input.initiate_contact.child_to_parent_monitor_page_gpa,
            })
        } else {
            MonitorPageRequest::None
        };

        // `client_id` is used only when the guest set the `client_id` feature
        // flag. In that case the message must have been large enough to
        // contain it.
        let client_id = if FeatureFlags::from(feature_flags).client_id() {
            if includes_client_id {
                input.client_id
            } else {
                return Err(ChannelError::ParseError(
                    protocol::ParseError::MessageTooSmall(Some(
                        protocol::MessageType::INITIATE_CONTACT,
                    )),
                ));
            }
        } else {
            Guid::ZERO
        };

        let request = InitiateContactRequest {
            version_requested: input.initiate_contact.version_requested,
            target_message_vp,
            monitor_page,
            target_sint,
            target_vtl,
            feature_flags,
            interrupt_page,
            client_id,
            trusted: message.trusted,
        };
        self.initiate_contact(request);
        Ok(())
    }
2197
    /// Processes a guest InitiateContact request.
    ///
    /// - A request for a different VTL goes to `Notifier::forward_unhandled`.
    /// - A request for a non-default SINT gets an unsupported version
    ///   response, sent to the requested VP/SINT.
    /// - Otherwise any existing connection is torn down first. If that cannot
    ///   finish now, the request is stored as the pending `Reconnect` action
    ///   and re-run later (via `do_next_action`).
    ///
    /// Once disconnected, the version and feature flags are negotiated, the
    /// server enters `Connecting`, and the notifier is asked to apply the
    /// connection settings. NOTE(review): completion appears to arrive via
    /// `complete_initiate_contact`, which requires the `Connecting` state set
    /// here — confirm with the caller.
    pub fn initiate_contact(&mut self, request: InitiateContactRequest) {
        // Requests for another VTL are handed back to the notifier.
        let vtl = self.inner.assigned_channels.vtl as u8;
        if request.target_vtl != vtl {
            self.notifier.forward_unhandled(request);
            return;
        }

        // Only the default SINT is supported. Reject others with an
        // unsupported version response sent to the requested VP and SINT.
        if request.target_sint != SINT {
            tracelimit::warn_ratelimited!(
                target_vtl = request.target_vtl,
                target_sint = request.target_sint,
                version = request.version_requested,
                "unsupported multiclient request",
            );

            self.send_version_response_with_target(
                None,
                MessageTarget::Custom(ConnectionTarget {
                    vp: request.target_message_vp,
                    sint: request.target_sint,
                }),
            );

            return;
        }

        // Tear down any existing connection first. If that cannot complete
        // now, this request runs again once the disconnect finishes.
        if !self.request_disconnect(ConnectionAction::Reconnect {
            initiate_contact: request,
        }) {
            return;
        }

        let Some(version) = self.check_version_supported(&request) else {
            tracelimit::warn_ratelimited!(
                vtl,
                version = request.version_requested,
                client_id = ?request.client_id,
                "Guest requested unsupported version"
            );

            self.send_version_response(None);
            return;
        };

        tracelimit::info_ratelimited!(
            vtl,
            ?version,
            client_id = ?request.client_id,
            trusted = request.trusted,
            "Guest negotiated version"
        );

        let monitor_page = match request.monitor_page {
            // When server-allocated monitor pages are required, pages
            // supplied by the guest are ignored.
            MonitorPageRequest::Some(mp) => {
                if self.inner.require_server_allocated_mnf {
                    if !version.feature_flags.server_specified_monitor_pages() {
                        tracelimit::warn_ratelimited!(
                            "guest-supplied monitor pages not supported; MNF will be disabled"
                        );
                    }

                    None
                } else {
                    Some(mp)
                }
            }
            MonitorPageRequest::None => None,
            // Only one of the two monitor pages was supplied (see
            // `handle_initiate_contact`): fail the connection.
            MonitorPageRequest::Invalid => {
                self.send_version_response(Some(VersionResponseData::new(
                    version,
                    protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
                )));

                return;
            }
        };

        self.inner.state = ConnectionState::Connecting {
            info: ConnectionInfo {
                version,
                trusted: request.trusted,
                interrupt_page: request.interrupt_page,
                monitor_page: monitor_page.map(MonitorPageGpaInfo::from_guest_gpas),
                target_message_vp: request.target_message_vp,
                modifying: false,
                offers_sent: false,
                client_id: request.client_id,
                paused: false,
            },
            next_action: ConnectionAction::None,
        };

        // Ask the notifier to apply the new connection settings. If it
        // refuses, go back to disconnected and report the failure.
        if let Err(err) = self.notifier.modify_connection(ModifyConnectionRequest {
            version: Some(version),
            monitor_page: monitor_page.into(),
            interrupt_page: request.interrupt_page.into(),
            target_message_vp: Some(request.target_message_vp),
            notify_relay: true,
        }) {
            tracelimit::error_ratelimited!(?err, "server failed to change state");
            self.inner.state = ConnectionState::Disconnected;
            self.send_version_response(Some(VersionResponseData::new(
                version,
                protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
            )));
        }
    }
2315
    /// Finishes a connection attempt using the notifier's response to the
    /// connection change requested by `initiate_contact`.
    ///
    /// On success:
    /// - the negotiated feature flags are limited to those in the response
    ///   plus locally handled ones;
    /// - server-specified monitor pages, if provided, replace any guest pages;
    /// - the server becomes `Connected` and a successful version response is
    ///   sent;
    /// - any follow-up action queued while connecting is then started.
    ///
    /// On failure, a failed version response (or an unsupported one, for
    /// `Unsupported`) is sent and the server returns to `Disconnected`.
    ///
    /// # Panics
    ///
    /// If the server is not `Connecting`, or the response is `Modified`.
    pub(crate) fn complete_initiate_contact(&mut self, response: ModifyConnectionResponse) {
        let ConnectionState::Connecting {
            mut info,
            next_action,
        } = self.inner.state
        else {
            panic!("Invalid state for completing InitiateContact.");
        };

        // Features implemented by this server itself. They are kept even when
        // the response does not list them.
        const LOCAL_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
            .with_client_id(true)
            .with_confidential_channels(true);

        let (relay_feature_flags, server_specified_monitor_page) = match response {
            // The connection change succeeded.
            ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                feature_flags,
                server_specified_monitor_page,
            ) => (feature_flags, server_specified_monitor_page),
            // The version is supported but the change failed. Pass the
            // failure state on to the guest.
            ModifyConnectionResponse::Supported(
                connection_state,
                feature_flags,
                server_specified_monitor_page,
            ) => {
                tracelimit::error_ratelimited!(
                    ?connection_state,
                    "initiate contact failed because relay request failed"
                );

                // The failure response still reports feature flags, so limit
                // them much like the success path does.
                info.version.feature_flags &= (feature_flags | LOCAL_FEATURE_FLAGS)
                    .with_server_specified_monitor_pages(server_specified_monitor_page.is_some());

                self.send_version_response(Some(VersionResponseData::new(
                    info.version,
                    connection_state,
                )));
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            // The requested version is not supported. Reply with an
            // unsupported version response.
            ModifyConnectionResponse::Unsupported => {
                self.send_version_response(None);
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            ModifyConnectionResponse::Modified(_) => {
                panic!("Invalid response for completing InitiateContact.");
            }
        };

        // Server-specified monitor pages may only be returned if that feature
        // was negotiated with the guest.
        assert!(
            info.version.feature_flags.server_specified_monitor_pages()
                || server_specified_monitor_page.is_none()
        );

        // Keep only the features in the response plus the local ones.
        info.version.feature_flags &= relay_feature_flags | LOCAL_FEATURE_FLAGS;

        // The server-specified monitor pages flag reflects whether pages were
        // actually provided. If so, they replace the guest-supplied pages
        // recorded by `initiate_contact`.
        if let Some(gpas) = server_specified_monitor_page {
            info.monitor_page = Some(MonitorPageGpaInfo::from_server_gpas(gpas));
            info.version
                .feature_flags
                .set_server_specified_monitor_pages(true);
        } else {
            info.version
                .feature_flags
                .set_server_specified_monitor_pages(false);
        }

        let version = info.version;
        self.inner.state = ConnectionState::Connected(info);

        self.send_version_response(Some(
            VersionResponseData::new(version, protocol::ConnectionState::SUCCESSFUL)
                .with_monitor_pages(server_specified_monitor_page),
        ));
        // Start whatever was requested while the connection was being set up.
        if !matches!(next_action, ConnectionAction::None) && self.request_disconnect(next_action) {
            self.do_next_action(next_action);
        }
    }
2410
2411 fn check_version_supported(&self, request: &InitiateContactRequest) -> Option<VersionInfo> {
2413 let version = SUPPORTED_VERSIONS
2414 .iter()
2415 .find(|v| request.version_requested == **v as u32)
2416 .copied()?;
2417
2418 if let Some(max_version) = self.inner.max_version {
2420 if version as u32 > max_version.version {
2421 return None;
2422 }
2423 }
2424
2425 let supported_flags = if version >= Version::Copper {
2426 let max_supported_flags =
2428 SUPPORTED_FEATURE_FLAGS.with_confidential_channels(request.trusted);
2429
2430 if let Some(max_version) = self.inner.max_version {
2432 max_supported_flags & max_version.feature_flags
2433 } else {
2434 max_supported_flags
2435 }
2436 } else {
2437 FeatureFlags::new()
2438 };
2439
2440 let feature_flags = supported_flags & request.feature_flags.into();
2441
2442 assert!(version >= Version::Copper || feature_flags == FeatureFlags::new());
2443 if feature_flags.into_bits() != request.feature_flags {
2444 tracelimit::warn_ratelimited!(
2445 supported = feature_flags.into_bits(),
2446 requested = request.feature_flags,
2447 "Guest requested unsupported feature flags."
2448 );
2449 }
2450
2451 Some(VersionInfo {
2452 version,
2453 feature_flags,
2454 })
2455 }
2456
2457 fn send_version_response(&mut self, data: Option<VersionResponseData>) {
2458 self.send_version_response_with_target(data, MessageTarget::Default);
2459 }
2460
    /// Sends a version response to `target`.
    ///
    /// `None` sends an all-zero response, meaning the version is not
    /// supported. Otherwise the response is marked supported, with one
    /// exception: a non-successful state with a version older than Win8 is
    /// also sent all-zero. NOTE(review): presumably this is because those
    /// versions cannot express a failure state — confirm.
    ///
    /// The selected-version field carries the child connection ID for
    /// Win10Rs3_1+ and the version number otherwise. The wire layout depends
    /// on the data:
    /// - Copper+ responses use the larger layout that includes feature flags.
    /// - If server-specified monitor pages are present, the largest layout is
    ///   used and includes them.
    fn send_version_response_with_target(
        &mut self,
        data: Option<VersionResponseData>,
        target: MessageTarget,
    ) {
        // Which of the nested response layouts to send.
        enum VersionResponseType {
            PreCopper,
            Copper,
            CopperWithServerMnf,
        }

        // Build the largest layout zeroed. The smaller layouts are nested
        // inside it and are filled in through these references.
        let mut response_copper_with_mnf = protocol::VersionResponse3::new_zeroed();
        let response_copper = &mut response_copper_with_mnf.version_response2;
        let response = &mut response_copper.version_response;
        let mut response_type = VersionResponseType::PreCopper;
        if let Some(data) = data {
            // Pre-Win8 failures are left all-zero, i.e. sent as unsupported.
            if data.state == protocol::ConnectionState::SUCCESSFUL
                || data.version.version >= Version::Win8
            {
                response.version_supported = 1;
                response.connection_state = data.state;
                response.selected_version_or_connection_id =
                    if data.version.version >= Version::Win10Rs3_1 {
                        self.inner.child_connection_id
                    } else {
                        data.version.version as u32
                    };

                if data.version.version >= Version::Copper {
                    response_copper.supported_features = data.version.feature_flags.into();
                    response_type = VersionResponseType::Copper;
                    if let Some(monitor_page) = data.monitor_pages {
                        // Pages may only be sent if the feature was negotiated.
                        assert!(data.version.feature_flags.server_specified_monitor_pages());
                        response_copper_with_mnf.child_to_parent_monitor_page_gpa =
                            monitor_page.child_to_parent;
                        response_copper_with_mnf.parent_to_child_monitor_page_gpa =
                            monitor_page.parent_to_child;
                        response_type = VersionResponseType::CopperWithServerMnf;
                    }
                }
            }
        }

        match response_type {
            VersionResponseType::PreCopper => {
                self.sender().send_message_with_target(response, target)
            }
            VersionResponseType::Copper => self
                .sender()
                .send_message_with_target(response_copper, target),
            VersionResponseType::CopperWithServerMnf => self
                .sender()
                .send_message_with_target(&response_copper_with_mnf, target),
        }
    }
2519
    /// Releases the guest's channels and begins tearing down the connection,
    /// so that `new_action` runs once teardown finishes.
    ///
    /// Reserved channels are kept unless `new_action` is a VM reset.
    ///
    /// Returns `true` if the server is `Disconnected` when this returns. In
    /// that case the caller must perform `new_action` itself. Otherwise the
    /// action is recorded in the transition state, and `complete_disconnect`
    /// later runs it.
    ///
    /// # Panics
    ///
    /// If a reset is already in progress, or if the server is disconnected,
    /// `new_action` is not a reset, and channels are still outstanding.
    fn request_disconnect(&mut self, new_action: ConnectionAction) -> bool {
        assert!(!self.is_resetting());

        // Release every channel. Reserved channels survive unless this is a
        // VM reset. A channel is dropped from the list when
        // `client_release_channel` returns true.
        let gpadls = &mut self.inner.gpadls;
        let vm_reset = matches!(new_action, ConnectionAction::Reset);
        self.inner.channels.retain(|offer_id, channel| {
            (!vm_reset && channel.state.is_reserved())
                || !Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    None,
                )
        });

        match &mut self.inner.state {
            ConnectionState::Disconnected => {
                // Already disconnected. A reset must still wait for any
                // remaining (e.g. reserved) channels and GPADLs.
                if vm_reset {
                    if !self.are_channels_reset(true) {
                        self.inner.state = ConnectionState::Disconnecting {
                            next_action: ConnectionAction::Reset,
                            modify_sent: false,
                        };
                    }
                } else {
                    assert!(self.are_channels_reset(false));
                }
            }

            ConnectionState::Connected { .. } => {
                // Disconnect now if nothing is outstanding. Otherwise wait
                // until the remaining channels and GPADLs are released.
                if self.are_channels_reset(vm_reset) {
                    self.notify_disconnect(new_action);
                } else {
                    self.inner.state = ConnectionState::Disconnecting {
                        next_action: new_action,
                        modify_sent: false,
                    };
                }
            }

            // A transition is already underway. The newest request replaces
            // the previously queued action.
            ConnectionState::Connecting { next_action, .. }
            | ConnectionState::Disconnecting { next_action, .. } => {
                *next_action = new_action;
            }
        }

        matches!(self.inner.state, ConnectionState::Disconnected)
    }
2581
    /// Finishes a disconnect that is in the `Disconnecting` state, moves the connection to
    /// `Disconnected`, and then runs the deferred `next_action`.
    ///
    /// `complete_modify_connection` calls this when a response arrives while `Disconnecting`.
    ///
    /// # Panics
    ///
    /// Panics if the connection is not `Disconnecting`, or if `are_channels_reset` does not hold
    /// for the kind of action (reset or not) that is pending.
    pub(crate) fn complete_disconnect(&mut self) {
        if let ConnectionState::Disconnecting {
            next_action,
            modify_sent,
        } = std::mem::replace(&mut self.inner.state, ConnectionState::Disconnected)
        {
            assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
            if !modify_sent {
                // A response arrived although no modify request was recorded as sent; log it
                // and complete the disconnect anyway.
                tracelimit::warn_ratelimited!("unexpected modify response");
            }

            self.inner.state = ConnectionState::Disconnected;
            self.do_next_action(next_action);
        } else {
            unreachable!("not ready for disconnect");
        }
    }
2599
2600 fn do_next_action(&mut self, action: ConnectionAction) {
2601 match action {
2602 ConnectionAction::None => {}
2603 ConnectionAction::Reset => {
2604 self.complete_reset();
2605 }
2606 ConnectionAction::SendUnloadComplete => {
2607 self.complete_unload();
2608 }
2609 ConnectionAction::Reconnect { initiate_contact } => {
2610 self.initiate_contact(initiate_contact);
2611 }
2612 ConnectionAction::SendFailedVersionResponse => {
2613 self.send_version_response(None);
2616 }
2617 }
2618 }
2619
2620 fn handle_unload(&mut self) {
2622 tracing::debug!(
2623 vtl = self.inner.assigned_channels.vtl as u8,
2624 state = ?self.inner.state,
2625 "VmBus received unload request from guest",
2626 );
2627
2628 if self.request_disconnect(ConnectionAction::SendUnloadComplete) {
2629 self.complete_unload();
2630 }
2631 }
2632
2633 fn complete_unload(&mut self) {
2634 self.notifier.unload_complete();
2635 if let Some(version) = self.inner.delayed_max_version.take() {
2636 self.inner.set_compatibility_version(version, false);
2637 }
2638
2639 self.sender().send_message(&protocol::UnloadComplete {});
2640 tracelimit::info_ratelimited!("Vmbus disconnected");
2641 }
2642
    /// Handles a `RequestOffers` message by offering every non-reserved channel to the guest,
    /// followed by `AllOffersDelivered`.
    ///
    /// Channels are offered in a deterministic order. With `use_absolute_channel_order` the sort
    /// key is `(offer_order, interface_id, instance_id)`; otherwise it is
    /// `(interface_id, offer_order, instance_id)`. A missing `offer_order` is treated as
    /// `u64::MAX`.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::OffersAlreadySent`] if offers were already sent on this
    /// connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection is not `Connected` (which `Message::parse` is expected to
    /// prevent), or if a non-reserved channel is not in the `ClientReleased` state.
    fn handle_request_offers(&mut self) -> Result<(), ChannelError> {
        let ConnectionState::Connected(info) = &mut self.inner.state else {
            unreachable!(
                "in unexpected state {:?}, should be prevented by Message::parse()",
                self.inner.state
            );
        };

        if info.offers_sent {
            return Err(ChannelError::OffersAlreadySent);
        }

        info.offers_sent = true;

        // Reserved channels are not offered.
        let mut sorted_channels: Vec<_> = self
            .inner
            .channels
            .iter_mut()
            .filter(|(_, channel)| !channel.state.is_reserved())
            .collect();

        if self.inner.use_absolute_channel_order {
            sorted_channels.sort_unstable_by_key(|(_, channel)| {
                (
                    channel.offer.offer_order.unwrap_or(u64::MAX),
                    channel.offer.interface_id,
                    channel.offer.instance_id,
                )
            });
        } else {
            sorted_channels.sort_unstable_by_key(|(_, channel)| {
                (
                    channel.offer.interface_id,
                    channel.offer.offer_order.unwrap_or(u64::MAX),
                    channel.offer.instance_id,
                )
            });
        }

        for (offer_id, channel) in sorted_channels {
            assert!(matches!(channel.state, ChannelState::ClientReleased));

            // `prepare_channel` is given the channel and monitor assignment tables
            // (presumably to allocate the channel's IDs before it is offered — confirm).
            channel.prepare_channel(
                offer_id,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
            );

            channel.state = ChannelState::Closed;
            self.inner
                .pending_messages
                .sender(self.notifier, info.paused)
                .send_offer(channel, info);
        }
        self.sender().send_message(&protocol::AllOffersDelivered {});

        Ok(())
    }
2704
    /// Reports a GPADL whose ranges have all been received.
    ///
    /// If the channel has been revoked, the guest is sent a `GpadlCreated` with
    /// `STATUS_UNSUCCESSFUL` and `false` is returned; callers then remove the GPADL from the
    /// table. Otherwise the notifier receives `Action::Gpadl` with the GPADL's count and a clone
    /// of its buffer, and `true` is returned.
    #[must_use]
    fn gpadl_updated(
        mut sender: MessageSender<'_, N>,
        offer_id: OfferId,
        channel: &Channel,
        gpadl_id: GpadlId,
        gpadl: &Gpadl,
    ) -> bool {
        if channel.state.is_revoked() {
            let channel_id = channel.info.as_ref().expect("assigned").channel_id;
            sender.send_gpadl_created(channel_id, gpadl_id, protocol::STATUS_UNSUCCESSFUL);
            false
        } else {
            sender.notifier.notify(
                offer_id,
                Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
            );
            true
        }
    }
2728
2729 fn handle_gpadl_header_core(
2731 &mut self,
2732 input: &protocol::GpadlHeader,
2733 range: &[u8],
2734 ) -> Result<(), ChannelError> {
2735 let (offer_id, channel) = self
2737 .inner
2738 .channels
2739 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
2740
2741 if channel.state.is_reserved() {
2744 return Err(ChannelError::ChannelReserved);
2745 }
2746
2747 let mut gpadl = Gpadl::new(input.count, input.len as usize / 8);
2749 let done = gpadl.append(range)?;
2750
2751 let gpadl = match self.inner.gpadls.entry((input.gpadl_id, offer_id)) {
2753 Entry::Vacant(entry) => entry.insert(gpadl),
2754 Entry::Occupied(_) => return Err(ChannelError::DuplicateGpadlId),
2755 };
2756
2757 if !done
2759 && self
2760 .inner
2761 .incomplete_gpadls
2762 .insert(input.gpadl_id, offer_id)
2763 .is_some()
2764 {
2765 unreachable!("gpadl ID validated above");
2766 }
2767
2768 if done
2769 && !Self::gpadl_updated(
2770 self.inner
2771 .pending_messages
2772 .sender(self.notifier, self.inner.state.is_paused()),
2773 offer_id,
2774 channel,
2775 input.gpadl_id,
2776 gpadl,
2777 )
2778 {
2779 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2780 }
2781 Ok(())
2782 }
2783
2784 fn handle_gpadl_header(&mut self, input: &protocol::GpadlHeader, range: &[u8]) {
2786 if let Err(err) = self.handle_gpadl_header_core(input, range) {
2787 tracelimit::warn_ratelimited!(
2788 err = &err as &dyn std::error::Error,
2789 channel_id = ?input.channel_id,
2790 key = %self.inner.channels.get_by_channel_id(&self.inner.assigned_channels, input.channel_id).map(|(_, c)| c.offer.key()).unwrap_or_default(),
2791 gpadl_id = ?input.gpadl_id,
2792 "error handling gpadl header"
2793 );
2794
2795 self.sender().send_gpadl_created(
2797 input.channel_id,
2798 input.gpadl_id,
2799 protocol::STATUS_UNSUCCESSFUL,
2800 );
2801 }
2802 }
2803
    /// Handles a `GpadlBody` message. The message supplies more page ranges for a GPADL whose
    /// `GpadlHeader` did not contain all of them.
    ///
    /// Body messages carry only the GPADL ID, so the owning offer is looked up in
    /// `incomplete_gpadls`. When the last range arrives, the GPADL is passed to
    /// `gpadl_updated`. If `Gpadl::append` rejects the range data, the GPADL is discarded and
    /// the guest is sent a failed `GpadlCreated`; that case still returns `Ok`.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::UnknownGpadlId`] if the GPADL ID is not tracked as incomplete,
    /// or if the tracked GPADL no longer exists.
    fn handle_gpadl_body(
        &mut self,
        input: &protocol::GpadlBody,
        range: &[u8],
    ) -> Result<(), ChannelError> {
        // NOTE(review): `client_release_channel` drops in-progress GPADLs without removing their
        // `incomplete_gpadls` entries, so this lookup can return a stale offer ID. Confirm that a
        // stale entry cannot route a body to a newer, already-complete GPADL with the same ID
        // (the error path below would then remove that GPADL).
        let &offer_id = self
            .inner
            .incomplete_gpadls
            .get(&input.gpadl_id)
            .ok_or(ChannelError::UnknownGpadlId)?;
        let gpadl = self
            .inner
            .gpadls
            .get_mut(&(input.gpadl_id, offer_id))
            .ok_or(ChannelError::UnknownGpadlId)?;
        let channel = &mut self.inner.channels[offer_id];

        match gpadl.append(range) {
            Ok(done) => {
                if done {
                    // All ranges received; stop routing body messages to this GPADL.
                    self.inner.incomplete_gpadls.remove(&input.gpadl_id);
                    if !Self::gpadl_updated(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        input.gpadl_id,
                        gpadl,
                    ) {
                        // `gpadl_updated` returned false (the channel is revoked), so drop the
                        // GPADL.
                        self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    }
                }
            }
            Err(err) => {
                // The range data was rejected: discard the GPADL and report the failure.
                self.inner.incomplete_gpadls.remove(&input.gpadl_id);
                self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                let channel_id = channel.info.as_ref().expect("assigned").channel_id;
                tracelimit::warn_ratelimited!(
                    err = &err as &dyn std::error::Error,
                    channel_id = channel_id.0,
                    key = %channel.offer.key(),
                    gpadl_id = input.gpadl_id.0,
                    "error handling gpadl body"
                );
                self.sender().send_gpadl_created(
                    channel_id,
                    input.gpadl_id,
                    protocol::STATUS_UNSUCCESSFUL,
                );
            }
        }

        Ok(())
    }
2867
    /// Handles a `GpadlTeardown` message from the guest.
    ///
    /// Only GPADLs in the `Accepted` state can be torn down. For a revoked channel, the GPADL is
    /// removed and `GpadlTorndown` is sent immediately. Otherwise the GPADL moves to
    /// `TearingDown` and the notifier receives `Action::TeardownGpadl`;
    /// `gpadl_teardown_complete` finishes the operation.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - the channel ID is not valid;
    /// - the GPADL is unknown for the channel ([`ChannelError::UnknownGpadlId`]);
    /// - the GPADL is not `Accepted` ([`ChannelError::InvalidGpadlState`]);
    /// - the channel's assigned ID does not match ([`ChannelError::WrongGpadlChannelId`]);
    /// - the channel is reserved ([`ChannelError::ChannelReserved`]).
    fn handle_gpadl_teardown(
        &mut self,
        input: &protocol::GpadlTeardown,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        tracing::debug!(
            channel_id = input.channel_id.0,
            key = %channel.offer.key(),
            gpadl_id = input.gpadl_id.0,
            "Received GPADL teardown request"
        );

        let gpadl = self
            .inner
            .gpadls
            .get_mut(&(input.gpadl_id, offer_id))
            .ok_or(ChannelError::UnknownGpadlId)?;

        match gpadl.state {
            // Teardown is only valid once the GPADL has been accepted and is not already being
            // torn down.
            GpadlState::InProgress
            | GpadlState::Offered
            | GpadlState::OfferedTearingDown
            | GpadlState::TearingDown => {
                return Err(ChannelError::InvalidGpadlState);
            }
            GpadlState::Accepted => {
                if channel.info.as_ref().map(|info| info.channel_id) != Some(input.channel_id) {
                    return Err(ChannelError::WrongGpadlChannelId);
                }

                // GPADLs of reserved channels cannot be torn down with this message.
                if channel.state.is_reserved() {
                    return Err(ChannelError::ChannelReserved);
                }

                if channel.state.is_revoked() {
                    tracing::trace!(
                        channel_id = input.channel_id.0,
                        key = %channel.offer.key(),
                        gpadl_id = input.gpadl_id.0,
                        "Gpadl teardown for revoked channel"
                    );

                    // Revoked channel: complete the teardown immediately without involving
                    // the notifier.
                    self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    self.sender().send_gpadl_torndown(input.gpadl_id);
                } else {
                    // Finished later by `gpadl_teardown_complete`.
                    gpadl.state = GpadlState::TearingDown;
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id: input.gpadl_id,
                            post_restore: false,
                        },
                    );
                }
            }
        }
        Ok(())
    }
2934
2935 fn open_channel(
2938 &mut self,
2939 offer_id: OfferId,
2940 input: &OpenRequest,
2941 reserved_state: Option<ReservedState>,
2942 ) {
2943 let channel = &mut self.inner.channels[offer_id];
2944 assert!(matches!(channel.state, ChannelState::Closed));
2945
2946 channel.state = ChannelState::Opening {
2947 request: *input,
2948 reserved_state,
2949 };
2950
2951 let info = channel.info.as_ref().expect("assigned");
2954 self.notifier.notify(
2955 offer_id,
2956 Action::Open(
2957 OpenParams::from_request(
2958 info,
2959 input,
2960 channel.handled_monitor_info(),
2961 reserved_state.map(|state| state.target),
2962 ),
2963 self.inner.state.get_version().expect("must be connected"),
2964 ),
2965 );
2966 }
2967
    /// Handles an `OpenChannel` or `OpenChannel2` message from the guest.
    ///
    /// Guest-specified signal parameters are only used if `guest_specified_signal_parameters`
    /// was negotiated, and the open flags only if `channel_interrupt_redirection` was
    /// negotiated. Otherwise they are `None` and the default flags, respectively.
    ///
    /// Behavior by channel state:
    /// - a closed channel is opened immediately;
    /// - a closing channel keeps the request in `ClosingReopen`;
    /// - requests for revoked or reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Fails if the channel ID is not valid, or with [`ChannelError::ChannelAlreadyOpen`] if the
    /// channel is open, opening, or already has a reopen pending.
    fn handle_open_channel(&mut self, input: &protocol::OpenChannel2) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.open_channel.channel_id)?;

        let guest_specified_interrupt_info = self
            .inner
            .state
            .check_feature_flags(|ff| ff.guest_specified_signal_parameters())
            .then_some(SignalInfo {
                event_flag: input.event_flag,
                connection_id: input.connection_id,
            });

        let flags = if self
            .inner
            .state
            .check_feature_flags(|ff| ff.channel_interrupt_redirection())
        {
            input.flags
        } else {
            Default::default()
        };

        let request = OpenRequest {
            open_id: input.open_channel.open_id,
            ring_buffer_gpadl_id: input.open_channel.ring_buffer_gpadl_id,
            target_vp: input.open_channel.target_vp,
            downstream_ring_buffer_page_offset: input
                .open_channel
                .downstream_ring_buffer_page_offset,
            user_data: input.open_channel.user_data,
            guest_specified_interrupt_info,
            flags,
        };

        match channel.state {
            ChannelState::Closed => self.open_channel(offer_id, &request, None),
            ChannelState::Closing { params, .. } => {
                // Keep the open request alongside the close that is still in progress.
                channel.state = ChannelState::ClosingReopen { params, request }
            }
            // Requests for revoked or reoffered channels are ignored.
            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Open { .. }
            | ChannelState::Opening { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelAlreadyOpen),

            // NOTE(review): presumed unreachable because released channels are not found by
            // channel ID — confirm.
            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }
        Ok(())
    }
3026
    /// Handles a `CloseChannel` message from the guest.
    ///
    /// An open, non-reserved channel moves to `Closing` and the notifier receives
    /// `Action::Close`. A modify still in progress is only logged. Requests for revoked or
    /// reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - the channel ID is not valid;
    /// - the channel is reserved ([`ChannelError::ChannelReserved`]);
    /// - the channel is not open ([`ChannelError::ChannelNotOpen`]).
    fn handle_close_channel(&mut self, input: &protocol::CloseChannel) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        match channel.state {
            ChannelState::Open {
                params,
                modify_state,
                reserved_state: None,
            } => {
                if modify_state.is_modifying() {
                    tracelimit::warn_ratelimited!(
                        key = %channel.offer.key(),
                        ?modify_state,
                        "Client is closing the channel with a modify in progress"
                    )
                }

                channel.state = ChannelState::Closing {
                    params,
                    reserved_state: None,
                };
                self.notifier.notify(offer_id, Action::Close);
            }

            // Reserved channels are closed with `CloseReservedChannel` instead.
            ChannelState::Open {
                reserved_state: Some(_),
                ..
            } => return Err(ChannelError::ChannelReserved),

            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

            // NOTE(review): presumed unreachable because released channels are not found by
            // channel ID — confirm.
            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }

        Ok(())
    }
3074
    /// Handles an `OpenReservedChannel` message.
    ///
    /// The channel is opened with `target_vp` set to `protocol::VP_INDEX_DISABLE_INTERRUPT`.
    /// The VP/SINT from the message and the negotiated `version` are recorded in a
    /// `ReservedState`. Requests for revoked or reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - the channel ID is not valid;
    /// - the channel is already open or opening ([`ChannelError::ChannelAlreadyOpen`]);
    /// - the channel is closing ([`ChannelError::InvalidChannelState`]).
    fn handle_open_reserved_channel(
        &mut self,
        input: &protocol::OpenReservedChannel,
        version: VersionInfo,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        // NOTE(review): `as u8` silently truncates the guest-provided SINT; confirm it is
        // range-checked (e.g. by `Message::parse`) before it is used as a target.
        let target = ConnectionTarget {
            vp: input.target_vp,
            sint: input.target_sint as u8,
        };

        let reserved_state = Some(ReservedState { version, target });

        let request = OpenRequest {
            ring_buffer_gpadl_id: input.ring_buffer_gpadl,
            target_vp: protocol::VP_INDEX_DISABLE_INTERRUPT,
            downstream_ring_buffer_page_offset: input.downstream_page_offset,
            open_id: 0,
            user_data: UserDefinedData::new_zeroed(),
            guest_specified_interrupt_info: None,
            flags: Default::default(),
        };

        match channel.state {
            ChannelState::Closed => self.open_channel(offer_id, &request, reserved_state),
            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Open { .. } | ChannelState::Opening { .. } => {
                return Err(ChannelError::ChannelAlreadyOpen);
            }

            ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                return Err(ChannelError::InvalidChannelState);
            }

            // NOTE(review): presumed unreachable because released channels are not found by
            // channel ID — confirm.
            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }
        Ok(())
    }
3123
    /// Handles a `CloseReservedChannel` message.
    ///
    /// For an open reserved channel, the reserved target is first updated to the VP/SINT given
    /// in the message. The channel then moves to `Closing` and the notifier receives
    /// `Action::Close`. Requests for revoked or reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - the channel ID is not valid;
    /// - the channel is open but not reserved ([`ChannelError::ChannelNotReserved`]);
    /// - the channel is not open ([`ChannelError::ChannelNotOpen`]).
    fn handle_close_reserved_channel(
        &mut self,
        input: &protocol::CloseReservedChannel,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        match channel.state {
            ChannelState::Open {
                params,
                reserved_state: Some(mut resvd),
                ..
            } => {
                // NOTE(review): `as u8` silently truncates the guest-provided SINT; confirm it
                // is range-checked before use.
                resvd.target.vp = input.target_vp;
                resvd.target.sint = input.target_sint as u8;
                channel.state = ChannelState::Closing {
                    params,
                    reserved_state: Some(resvd),
                };
                self.notifier.notify(offer_id, Action::Close);
            }

            ChannelState::Open {
                reserved_state: None,
                ..
            } => return Err(ChannelError::ChannelNotReserved),

            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

            // NOTE(review): presumed unreachable because released channels are not found by
            // channel ID — confirm.
            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }

        Ok(())
    }
3169
    /// Releases a channel on behalf of the client, for example when handling `RelIdReleased`
    /// or when disconnecting (`request_disconnect`).
    ///
    /// The channel's GPADLs are handled first:
    /// - in-progress GPADLs are dropped;
    /// - offered GPADLs become `OfferedTearingDown`;
    /// - accepted GPADLs are dropped if the channel is revoked, otherwise they move to
    ///   `TearingDown` and the notifier receives `Action::TeardownGpadl`;
    /// - GPADLs already being torn down are kept.
    ///
    /// The channel then moves to its released counterpart state and `release_channel` is
    /// called. An open channel is also closed via `Action::Close`.
    ///
    /// If the channel is `Reoffered` and `info` is `Some`, the channel is offered to the guest
    /// again instead: it returns to `Closed` with `RestoreState::New`, `false` is returned, and
    /// `release_channel` is not called.
    ///
    /// Returns `true` only for a revoked channel, meaning the caller should remove it from the
    /// channel table.
    #[must_use]
    fn client_release_channel(
        mut sender: MessageSender<'_, N>,
        offer_id: OfferId,
        channel: &mut Channel,
        gpadls: &mut GpadlMap,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
        info: Option<&ConnectionInfo>,
    ) -> bool {
        tracelimit::info_ratelimited!(?offer_id, key = %channel.offer.key(), "client released channel");
        // NOTE(review): in-progress GPADLs are dropped here, but their `incomplete_gpadls`
        // entries (keyed by GPADL ID) are not cleared by this function — confirm stale entries
        // are tolerated or cleaned up elsewhere.
        gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
            if gpadl_offer_id != offer_id {
                return true;
            }
            match gpadl.state {
                GpadlState::InProgress => false,
                GpadlState::Offered => {
                    // Wait for the device's create completion before tearing it down.
                    gpadl.state = GpadlState::OfferedTearingDown;
                    true
                }
                GpadlState::Accepted => {
                    if channel.state.is_revoked() {
                        false
                    } else {
                        gpadl.state = GpadlState::TearingDown;
                        sender.notifier.notify(
                            offer_id,
                            Action::TeardownGpadl {
                                gpadl_id,
                                post_restore: false,
                            },
                        );
                        true
                    }
                }
                GpadlState::OfferedTearingDown | GpadlState::TearingDown => true,
            }
        });

        let remove = match &mut channel.state {
            ChannelState::Closed => {
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Reoffered => {
                if let Some(info) = info {
                    // Connected: offer the channel again immediately.
                    channel.state = ChannelState::Closed;
                    channel.restore_state = RestoreState::New;
                    sender.send_offer(channel, info);
                    return false;
                }
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Revoked => {
                channel.state = ChannelState::ClientReleased;
                true
            }
            ChannelState::Opening { .. } => {
                channel.state = ChannelState::OpeningClientRelease;
                false
            }
            ChannelState::Open { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                sender.notifier.notify(offer_id, Action::Close);
                false
            }
            ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                false
            }

            ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease
            | ChannelState::ClientReleased => false,
        };

        assert!(channel.state.is_released());

        channel.release_channel(offer_id, assigned_channels, assigned_monitors);
        remove
    }
3259
    /// Handles a `RelIdReleased` message, in which the guest releases a channel ID.
    ///
    /// Allowed for `Closed`, `Revoked`, `Closing`, and `Reoffered` channels. The channel is
    /// released via `client_release_channel`, which receives the current connection info so a
    /// reoffered channel can be offered again. The channel is removed from the table if that
    /// function says so, and `check_disconnected` is then called.
    ///
    /// # Errors
    ///
    /// Fails if the channel ID is not valid, or with [`ChannelError::InvalidChannelState`] if
    /// the channel is open, opening, or has a reopen pending.
    fn handle_rel_id_released(
        &mut self,
        input: &protocol::RelIdReleased,
    ) -> Result<(), ChannelError> {
        let channel_id = input.channel_id;
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, channel_id)?;

        match channel.state {
            ChannelState::Closed
            | ChannelState::Revoked
            | ChannelState::Closing { .. }
            | ChannelState::Reoffered => {
                if Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    &mut self.inner.gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    self.inner.state.get_connected_info(),
                ) {
                    self.inner.channels.remove(offer_id);
                }

                self.check_disconnected();
            }

            ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::InvalidChannelState),

            // NOTE(review): presumed unreachable because released channels are not found by
            // channel ID — confirm.
            ChannelState::ClientReleased
            | ChannelState::OpeningClientRelease
            | ChannelState::ClosingClientRelease => unreachable!(),
        }
        Ok(())
    }
3303
3304 fn handle_tl_connect_request(&mut self, request: protocol::TlConnectRequest2) {
3307 let version = self
3308 .inner
3309 .state
3310 .get_version()
3311 .expect("must be connected")
3312 .version;
3313
3314 let hosted_silo_unaware = version < Version::Win10Rs5;
3315 self.notifier
3316 .notify_hvsock(&HvsockConnectRequest::from_message(
3317 request,
3318 hosted_silo_unaware,
3319 ));
3320 }
3321
3322 pub fn send_tl_connect_result(&mut self, result: HvsockConnectResult) {
3324 if !result.success && self.inner.state.check_version(Version::Win10Rs3_0) {
3328 self.sender().send_message(&protocol::TlConnectResult {
3332 service_id: result.service_id,
3333 endpoint_id: result.endpoint_id,
3334 status: protocol::STATUS_CONNECTION_REFUSED,
3335 })
3336 }
3337 }
3338
3339 fn handle_modify_channel(
3342 &mut self,
3343 request: &protocol::ModifyChannel,
3344 ) -> Result<(), ChannelError> {
3345 let result = self.modify_channel(request);
3346 if result.is_err() {
3347 self.send_modify_channel_response(request.channel_id, protocol::STATUS_UNSUCCESSFUL);
3348 }
3349
3350 result
3351 }
3352
    /// Requests that an open, non-reserved channel's target VP be changed to
    /// `request.target_vp`.
    ///
    /// If no modify is in progress, the notifier receives `Action::Modify`, the stored open
    /// parameters are updated, and the channel is marked `Modifying`.
    ///
    /// If a modify is already in progress, the behavior depends on the version:
    /// - when `check_version(Version::Iron)` holds, the request is logged and dropped;
    /// - otherwise it is stored as `pending_target_vp`, and `modify_channel_complete` issues it
    ///   once the current modify finishes.
    ///
    /// # Errors
    ///
    /// Fails if the channel ID is not valid, or with [`ChannelError::InvalidChannelState`] if
    /// the channel is not open or is reserved.
    fn modify_channel(&mut self, request: &protocol::ModifyChannel) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, request.channel_id)?;

        let (open_request, modify_state) = match &mut channel.state {
            ChannelState::Open {
                params,
                modify_state,
                reserved_state: None,
            } => (params, modify_state),
            _ => return Err(ChannelError::InvalidChannelState),
        };

        if let ModifyState::Modifying { pending_target_vp } = modify_state {
            if self.inner.state.check_version(Version::Iron) {
                tracelimit::warn_ratelimited!(
                    key = %channel.offer.key(),
                    "Client sent new ModifyChannel before receiving ModifyChannelResponse."
                );
            } else {
                // Only the latest request is remembered.
                *pending_target_vp = Some(request.target_vp);
            }
        } else {
            self.notifier.notify(
                offer_id,
                Action::Modify {
                    target_vp: request.target_vp,
                },
            );

            open_request.target_vp = request.target_vp;
            *modify_state = ModifyState::Modifying {
                pending_target_vp: None,
            };
        }

        Ok(())
    }
3399
    /// Completes a channel modify started by `modify_channel`.
    ///
    /// If the channel is open, non-reserved, and `Modifying`, its modify state is cleared and
    /// a `ModifyChannelResponse` with `status` is sent (subject to the version check in
    /// `send_modify_channel_response`). Then any `pending_target_vp` is issued as a new
    /// `ModifyChannel` request. In any other state this does nothing.
    ///
    /// `offer_id` is used to index the channel table directly, so it must refer to an existing
    /// channel.
    pub fn modify_channel_complete(&mut self, offer_id: OfferId, status: i32) {
        let channel = &mut self.inner.channels[offer_id];

        if let ChannelState::Open {
            params,
            modify_state: ModifyState::Modifying { pending_target_vp },
            reserved_state: None,
        } = channel.state
        {
            channel.state = ChannelState::Open {
                params,
                modify_state: ModifyState::NotModifying,
                reserved_state: None,
            };

            let channel_id = channel.info.as_ref().expect("assigned").channel_id;
            let key = channel.offer.key();
            self.send_modify_channel_response(channel_id, status);

            if let Some(target_vp) = pending_target_vp {
                // Replay the request that arrived while the previous modify was in progress.
                let request = protocol::ModifyChannel {
                    channel_id,
                    target_vp,
                };

                if let Err(error) = self.handle_modify_channel(&request) {
                    tracelimit::warn_ratelimited!(?error, %key, "Pending ModifyChannel request failed.")
                }
            }
        }
    }
3439
3440 fn send_modify_channel_response(&mut self, channel_id: ChannelId, status: i32) {
3441 if self.inner.state.check_version(Version::Iron) {
3442 self.sender()
3443 .send_message(&protocol::ModifyChannelResponse { channel_id, status });
3444 }
3445 }
3446
3447 fn handle_modify_connection(&mut self, request: protocol::ModifyConnection) {
3448 if let Err(err) = self.modify_connection(request) {
3449 tracelimit::error_ratelimited!(?err, "modifying connection failed");
3450 self.complete_modify_connection(ModifyConnectionResponse::Modified(
3451 protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
3452 ));
3453 }
3454 }
3455
3456 fn modify_connection(&mut self, request: protocol::ModifyConnection) -> anyhow::Result<()> {
3457 let ConnectionState::Connected(info) = &mut self.inner.state else {
3458 anyhow::bail!(
3459 "Invalid state for ModifyConnection request: {:?}",
3460 self.inner.state
3461 );
3462 };
3463
3464 if info.modifying {
3465 anyhow::bail!(
3466 "Duplicate ModifyConnection request, state: {:?}",
3467 self.inner.state
3468 );
3469 }
3470
3471 if matches!(
3472 info.monitor_page,
3473 Some(MonitorPageGpaInfo {
3474 server_allocated: true,
3475 ..
3476 })
3477 ) {
3478 anyhow::bail!("Cannot modify server-allocated monitor pages");
3479 }
3480
3481 if (request.child_to_parent_monitor_page_gpa == 0)
3482 != (request.parent_to_child_monitor_page_gpa == 0)
3483 {
3484 anyhow::bail!("Guest must specify either both or no monitor pages, {request:?}");
3485 }
3486
3487 let monitor_page = (request.child_to_parent_monitor_page_gpa != 0).then_some(
3488 MonitorPageGpaInfo::from_guest_gpas(MonitorPageGpas {
3489 child_to_parent: request.child_to_parent_monitor_page_gpa,
3490 parent_to_child: request.parent_to_child_monitor_page_gpa,
3491 }),
3492 );
3493
3494 info.modifying = true;
3495 info.monitor_page = monitor_page;
3496 tracing::debug!("modifying connection parameters.");
3497 self.notifier.modify_connection(request.into())?;
3498
3499 Ok(())
3500 }
3501
    /// Handles the response to a `ModifyConnection` request.
    ///
    /// The current connection state decides what the response completes:
    /// - `Connecting`: a pending connect, via `complete_initiate_contact`;
    /// - `Disconnecting`: a pending disconnect, via `complete_disconnect`;
    /// - `Connected`: a guest-requested modify. The `modifying` flag is cleared and a
    ///   `ModifyConnectionResponse` is sent to the guest.
    ///
    /// # Panics
    ///
    /// Panics if, while `Connected`, the response is not `ModifyConnectionResponse::Modified`
    /// or no modify is in progress. Also panics in any state not listed above.
    pub fn complete_modify_connection(&mut self, response: ModifyConnectionResponse) {
        tracing::debug!(?response, "modifying connection parameters complete");

        match &mut self.inner.state {
            ConnectionState::Connecting { .. } => self.complete_initiate_contact(response),
            ConnectionState::Disconnecting { .. } => self.complete_disconnect(),
            ConnectionState::Connected(info) => {
                let ModifyConnectionResponse::Modified(connection_state) = response else {
                    panic!(
                        "Relay should not return {:?} for a modify request with no version.",
                        response
                    );
                };

                if !info.modifying {
                    panic!(
                        "ModifyConnection response while not modifying, state: {:?}",
                        self.inner.state
                    );
                }

                info.modifying = false;
                self.sender()
                    .send_message(&protocol::ModifyConnectionResponse { connection_state });
            }
            _ => panic!(
                "Invalid state for ModifyConnection response: {:?}",
                self.inner.state
            ),
        }
    }
3536
3537 fn handle_pause(&mut self) {
3538 tracelimit::info_ratelimited!("pausing sending messages");
3539 self.sender().send_message(&protocol::PauseResponse {});
3540 let ConnectionState::Connected(info) = &mut self.inner.state else {
3541 unreachable!(
3542 "in unexpected state {:?}, should be prevented by Message::parse()",
3543 self.inner.state
3544 );
3545 };
3546 info.paused = true;
3547 }
3548
    /// Parses a SynIC message from the guest and dispatches it to the matching handler.
    ///
    /// On a trusted connection, untrusted messages are rejected. While the connection is
    /// paused, only `Resume`, `Unload`, `InitiateContact`, and `InitiateContact2` are
    /// accepted, and accepting any of them clears the paused flag. Message types that the
    /// server itself sends are treated as unreachable.
    ///
    /// # Errors
    ///
    /// Returns any of the following:
    /// - parse errors;
    /// - [`ChannelError::UntrustedMessage`];
    /// - [`ChannelError::Paused`];
    /// - any error returned by the individual handlers.
    ///
    /// # Panics
    ///
    /// Panics if a reset is in progress.
    pub fn handle_synic_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
        assert!(!self.is_resetting());

        let version = self.inner.state.get_version();
        let msg = Message::parse(&message.data, version)?;
        tracing::trace!(?msg, message.trusted, "received vmbus message");
        if self.inner.state.is_trusted() && !message.trusted {
            tracelimit::warn_ratelimited!(?msg, "Received untrusted message");
            return Err(ChannelError::UntrustedMessage);
        }

        match &mut self.inner.state {
            ConnectionState::Connected(info) if info.paused => {
                if !matches!(
                    msg,
                    Message::Resume(..)
                        | Message::Unload(..)
                        | Message::InitiateContact { .. }
                        | Message::InitiateContact2 { .. }
                ) {
                    tracelimit::warn_ratelimited!(?msg, "Received message while paused");
                    return Err(ChannelError::Paused);
                }
                tracelimit::info_ratelimited!("resuming sending messages");
                info.paused = false;
            }
            _ => {}
        }

        match msg {
            Message::InitiateContact2(input, ..) => {
                self.handle_initiate_contact(&input, &message, true)?
            }
            Message::InitiateContact(input, ..) => {
                self.handle_initiate_contact(&input.into(), &message, false)?
            }
            Message::Unload(..) => self.handle_unload(),
            Message::RequestOffers(..) => self.handle_request_offers()?,
            Message::GpadlHeader(input, range) => self.handle_gpadl_header(&input, range),
            Message::GpadlBody(input, range) => self.handle_gpadl_body(&input, range)?,
            Message::GpadlTeardown(input, ..) => self.handle_gpadl_teardown(&input)?,
            Message::OpenChannel(input, ..) => self.handle_open_channel(&input.into())?,
            Message::OpenChannel2(input, ..) => self.handle_open_channel(&input)?,
            Message::CloseChannel(input, ..) => self.handle_close_channel(&input)?,
            Message::RelIdReleased(input, ..) => self.handle_rel_id_released(&input)?,
            Message::TlConnectRequest(input, ..) => self.handle_tl_connect_request(input.into()),
            Message::TlConnectRequest2(input, ..) => self.handle_tl_connect_request(input),
            Message::ModifyChannel(input, ..) => self.handle_modify_channel(&input)?,
            Message::ModifyConnection(input, ..) => self.handle_modify_connection(input),
            Message::OpenReservedChannel(input, ..) => self.handle_open_reserved_channel(
                &input,
                version.expect("version validated by Message::parse"),
            )?,
            Message::CloseReservedChannel(input, ..) => {
                self.handle_close_reserved_channel(&input)?
            }
            Message::Pause(protocol::Pause, ..) => self.handle_pause(),
            // The paused flag was already cleared above.
            Message::Resume(protocol::Resume, ..) => {}
            Message::OfferChannel(..)
            | Message::RescindChannelOffer(..)
            | Message::AllOffersDelivered(..)
            | Message::OpenResult(..)
            | Message::GpadlCreated(..)
            | Message::GpadlTorndown(..)
            | Message::VersionResponse(..)
            | Message::VersionResponse2(..)
            | Message::VersionResponse3(..)
            | Message::UnloadComplete(..)
            | Message::CloseReservedChannelResponse(..)
            | Message::TlConnectResult(..)
            | Message::ModifyChannelResponse(..)
            | Message::ModifyConnectionResponse(..)
            | Message::PauseResponse(..) => {
                unreachable!("Server received client message {:?}", msg);
            }
        }
        Ok(())
    }
3634
    /// Completes the creation of a GPADL that was handed to the notifier. A `status` of 0 or
    /// more means success.
    ///
    /// Behavior by GPADL state:
    /// - `Offered`: the guest is sent `GpadlCreated` with `status`. The GPADL becomes
    ///   `Accepted` on success and is removed on failure.
    /// - `OfferedTearingDown` (the client released the channel in the meantime): on success
    ///   the notifier receives `Action::TeardownGpadl` and the GPADL becomes `TearingDown`; on
    ///   failure it is removed.
    /// - Any other state, or an unknown GPADL ID: the error is logged and nothing changes.
    ///
    /// Whenever a GPADL is removed, `check_disconnected` is called.
    pub fn gpadl_create_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId, status: i32) {
        let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
            tracelimit::error_ratelimited!(
                ?offer_id,
                key = %self.inner.channels[offer_id].offer.key(),
                ?gpadl_id,
                "invalid gpadl ID for channel"
            );
            return;
        };
        let retain = match gpadl.state {
            GpadlState::InProgress | GpadlState::TearingDown | GpadlState::Accepted => {
                tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
                return;
            }
            GpadlState::Offered => {
                let channel_id = self.inner.channels[offer_id]
                    .info
                    .as_ref()
                    .expect("assigned")
                    .channel_id;
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_gpadl_created(channel_id, gpadl_id, status);
                if status >= 0 {
                    gpadl.state = GpadlState::Accepted;
                    true
                } else {
                    false
                }
            }
            GpadlState::OfferedTearingDown => {
                if status >= 0 {
                    // Now that the device has the GPADL, tear it down as the client requested.
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: false,
                        },
                    );
                    gpadl.state = GpadlState::TearingDown;
                    true
                } else {
                    false
                }
            }
        };
        if !retain {
            self.inner
                .gpadls
                .remove(&(gpadl_id, offer_id))
                .expect("gpadl validated above");

            self.check_disconnected();
        }
    }
3694
    /// Completes the teardown of a GPADL that was in the `TearingDown` state.
    ///
    /// `GpadlTorndown` is sent to the guest unless the channel has been released by the
    /// client. The GPADL is then removed and `check_disconnected` is called. An unknown GPADL
    /// ID, or a GPADL in any other state, is only logged.
    ///
    /// `offer_id` is used to index the channel table directly, so it must refer to an existing
    /// channel.
    pub fn gpadl_teardown_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
        let channel = &mut self.inner.channels[offer_id];
        let Some(gpadl) = self.inner.gpadls.get_mut(&(gpadl_id, offer_id)) else {
            tracelimit::error_ratelimited!(
                ?offer_id,
                key = %channel.offer.key(),
                ?gpadl_id,
                "invalid gpadl ID for channel"
            );
            return;
        };
        tracing::debug!(
            offer_id = offer_id.0,
            key = %channel.offer.key(),
            gpadl_id = gpadl_id.0,
            "Gpadl teardown complete"
        );
        match gpadl.state {
            GpadlState::InProgress
            | GpadlState::Offered
            | GpadlState::OfferedTearingDown
            | GpadlState::Accepted => {
                tracelimit::error_ratelimited!(?offer_id, key = %channel.offer.key(), ?gpadl_id, ?gpadl, "invalid gpadl state");
            }
            GpadlState::TearingDown => {
                // A released channel's guest no longer expects a `GpadlTorndown`.
                if !channel.state.is_released() {
                    self.sender().send_gpadl_torndown(gpadl_id);
                }
                self.inner
                    .gpadls
                    .remove(&(gpadl_id, offer_id))
                    .expect("gpadl validated above");

                self.check_disconnected();
            }
        }
    }
3733
3734 fn sender(&mut self) -> MessageSender<'_, N> {
3739 self.inner
3740 .pending_messages
3741 .sender(self.notifier, self.inner.state.is_paused())
3742 }
3743}
3744
3745fn revoke<N: Notifier>(
3746 mut sender: MessageSender<'_, N>,
3747 offer_id: OfferId,
3748 channel: &mut Channel,
3749 gpadls: &mut GpadlMap,
3750) -> bool {
3751 let info = match channel.state {
3752 ChannelState::Closed
3753 | ChannelState::Open { .. }
3754 | ChannelState::Opening { .. }
3755 | ChannelState::Closing { .. }
3756 | ChannelState::ClosingReopen { .. } => {
3757 channel.state = ChannelState::Revoked;
3758 Some(channel.info.as_ref().expect("assigned"))
3759 }
3760 ChannelState::Reoffered => {
3761 channel.state = ChannelState::Revoked;
3762 None
3763 }
3764 ChannelState::ClientReleased
3765 | ChannelState::OpeningClientRelease
3766 | ChannelState::ClosingClientRelease => None,
3767 ChannelState::Revoked => return true,
3769 };
3770 let retain = !channel.state.is_released();
3771
3772 gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
3774 if gpadl_offer_id != offer_id {
3775 return true;
3776 }
3777
3778 match gpadl.state {
3779 GpadlState::InProgress => true,
3780 GpadlState::Offered => {
3781 if let Some(info) = info {
3782 sender.send_gpadl_created(
3783 info.channel_id,
3784 gpadl_id,
3785 protocol::STATUS_UNSUCCESSFUL,
3786 );
3787 }
3788 false
3789 }
3790 GpadlState::OfferedTearingDown => false,
3791 GpadlState::Accepted => true,
3792 GpadlState::TearingDown => {
3793 if info.is_some() {
3794 sender.send_gpadl_torndown(gpadl_id);
3795 }
3796 false
3797 }
3798 }
3799 });
3800 if let Some(info) = info {
3801 sender.send_rescind(info);
3802 }
3803 if channel.restore_state != RestoreState::New {
3805 channel.restore_state = RestoreState::Restored;
3806 }
3807 retain
3808}
3809
/// Queue of outgoing messages that were not delivered immediately.
///
/// `MessageSender::send_message` appends to the back of this queue in three
/// cases: the queue is already non-empty, the sender is paused, or the
/// notifier fails to deliver the message.
struct PendingMessages(VecDeque<OutgoingMessage>);
3811
3812impl PendingMessages {
3813 fn sender<'a, N: Notifier>(
3815 &'a mut self,
3816 notifier: &'a mut N,
3817 is_paused: bool,
3818 ) -> MessageSender<'a, N> {
3819 MessageSender {
3820 notifier,
3821 pending_messages: self,
3822 is_paused,
3823 }
3824 }
3825}
3826
/// Sends protocol messages through a notifier, falling back to a
/// [`PendingMessages`] queue when a default-target message cannot be sent
/// right away.
struct MessageSender<'a, N> {
    /// Delivers messages; its `send_message` result decides whether a
    /// default-target message must be queued.
    notifier: &'a mut N,
    /// Queue of undelivered messages. While it is non-empty, new
    /// default-target messages are appended here rather than sent.
    pending_messages: &'a mut PendingMessages,
    /// When true, default-target messages are queued instead of sent.
    is_paused: bool,
}
3834
3835impl<N: Notifier> MessageSender<'_, N> {
3836 fn send_message<
3838 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3839 >(
3840 &mut self,
3841 msg: &T,
3842 ) {
3843 let message = OutgoingMessage::new(msg);
3844
3845 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3846 if !self.pending_messages.0.is_empty()
3848 || self.is_paused
3849 || !self.notifier.send_message(&message, MessageTarget::Default)
3850 {
3851 tracing::trace!("message queued");
3852 self.pending_messages.0.push_back(message);
3854 }
3855 }
3856
3857 fn send_message_with_target<
3859 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3860 >(
3861 &mut self,
3862 msg: &T,
3863 target: MessageTarget,
3864 ) {
3865 if target == MessageTarget::Default {
3866 self.send_message(msg);
3867 } else {
3868 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3869 let message = OutgoingMessage::new(msg);
3872 if !self.notifier.send_message(&message, target) {
3873 tracelimit::warn_ratelimited!(?target, "failed to send message");
3874 }
3875 }
3876 }
3877
3878 fn send_offer(&mut self, channel: &mut Channel, connection_info: &ConnectionInfo) {
3880 let info = channel.info.as_ref().expect("assigned");
3881 let mut flags = channel.offer.flags;
3882 if !connection_info
3883 .version
3884 .feature_flags
3885 .confidential_channels()
3886 {
3887 flags.set_confidential_ring_buffer(false);
3888 flags.set_confidential_external_memory(false);
3889 }
3890
3891 let monitor_id = connection_info.monitor_page.and(info.monitor_id);
3898 let msg = protocol::OfferChannel {
3899 interface_id: channel.offer.interface_id,
3900 instance_id: channel.offer.instance_id,
3901 rsvd: [0; 4],
3902 flags,
3903 mmio_megabytes: channel.offer.mmio_megabytes,
3904 user_defined: channel.offer.user_defined,
3905 subchannel_index: channel.offer.subchannel_index,
3906 mmio_megabytes_optional: channel.offer.mmio_megabytes_optional,
3907 channel_id: info.channel_id,
3908 monitor_id: monitor_id.unwrap_or(MonitorId::INVALID).0,
3909 monitor_allocated: monitor_id.is_some().into(),
3910 is_dedicated: 1,
3913 connection_id: info.connection_id,
3914 };
3915 tracing::info!(
3916 channel_id = msg.channel_id.0,
3917 connection_id = msg.connection_id,
3918 key = %channel.offer.key(),
3919 "sending offer to guest"
3920 );
3921
3922 self.send_message(&msg);
3923 }
3924
3925 fn send_open_result(
3926 &mut self,
3927 channel_id: ChannelId,
3928 open_request: &OpenRequest,
3929 result: i32,
3930 target: MessageTarget,
3931 ) {
3932 self.send_message_with_target(
3933 &protocol::OpenResult {
3934 channel_id,
3935 open_id: open_request.open_id,
3936 status: result as u32,
3937 },
3938 target,
3939 );
3940 }
3941
3942 fn send_gpadl_created(&mut self, channel_id: ChannelId, gpadl_id: GpadlId, status: i32) {
3943 self.send_message(&protocol::GpadlCreated {
3944 channel_id,
3945 gpadl_id,
3946 status,
3947 });
3948 }
3949
3950 fn send_gpadl_torndown(&mut self, gpadl_id: GpadlId) {
3951 self.send_message(&protocol::GpadlTorndown { gpadl_id });
3952 }
3953
3954 fn send_rescind(&mut self, info: &OfferedInfo) {
3955 tracing::info!(
3956 channel_id = info.channel_id.0,
3957 "rescinding channel from guest"
3958 );
3959
3960 self.send_message(&protocol::RescindChannelOffer {
3961 channel_id: info.channel_id,
3962 });
3963 }
3964}
3965
/// Bundles a version, a connection state, and optional monitor page GPAs.
///
/// NOTE(review): judging by the name, this carries the data for a version
/// response message. Confirm against the code that consumes it.
struct VersionResponseData {
    /// The version information to report.
    version: VersionInfo,
    /// The connection state to report.
    state: protocol::ConnectionState,
    /// Monitor page GPAs, if any. `None` unless set via `with_monitor_pages`.
    monitor_pages: Option<MonitorPageGpas>,
}
3972
3973impl VersionResponseData {
3974 fn new(version: VersionInfo, state: protocol::ConnectionState) -> Self {
3976 VersionResponseData {
3977 version,
3978 state,
3979 monitor_pages: None,
3980 }
3981 }
3982
3983 fn with_monitor_pages(mut self, monitor_pages: Option<MonitorPageGpas>) -> Self {
3985 self.monitor_pages = monitor_pages;
3986 self
3987 }
3988}