1mod saved_state;
5#[cfg(test)]
6mod tests;
7
8use crate::Guid;
9use crate::SINT;
10use crate::SynicMessage;
11use crate::monitor::AssignedMonitors;
12use crate::protocol::Version;
13use hvdef::Vtl;
14use inspect::Inspect;
15pub use saved_state::RestoreError;
16pub use saved_state::SavedState;
17pub use saved_state::SavedStateData;
18use slab::Slab;
19use std::cmp::min;
20use std::collections::VecDeque;
21use std::collections::hash_map::Entry;
22use std::collections::hash_map::HashMap;
23use std::fmt::Display;
24use std::ops::Index;
25use std::ops::IndexMut;
26use std::task::Poll;
27use std::task::ready;
28use std::time::Duration;
29use thiserror::Error;
30use vmbus_channel::bus::ChannelType;
31use vmbus_channel::bus::GpadlRequest;
32use vmbus_channel::bus::OfferKey;
33use vmbus_channel::bus::OfferParams;
34use vmbus_channel::bus::OpenData;
35use vmbus_channel::bus::RestoredGpadl;
36use vmbus_core::HvsockConnectRequest;
37use vmbus_core::HvsockConnectResult;
38use vmbus_core::MaxVersionInfo;
39use vmbus_core::OutgoingMessage;
40use vmbus_core::VersionInfo;
41use vmbus_core::protocol;
42use vmbus_core::protocol::ChannelId;
43use vmbus_core::protocol::ConnectionId;
44use vmbus_core::protocol::FeatureFlags;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::Message;
47use vmbus_core::protocol::OfferFlags;
48use vmbus_core::protocol::UserDefinedData;
49use vmbus_ring::gparange;
50use vmcore::monitor::MonitorId;
51use vmcore::synic::MonitorInfo;
52use vmcore::synic::MonitorPageGpas;
53use zerocopy::FromZeros;
54use zerocopy::Immutable;
55use zerocopy::IntoBytes;
56use zerocopy::KnownLayout;
57
/// Errors reported while handling channel-related requests.
///
/// The display text of each variant comes from its `#[error]` attribute.
#[derive(Debug, Error)]
pub enum ChannelError {
    /// No offer is assigned to the requested channel ID.
    #[error("unknown channel ID")]
    UnknownChannelId,
    /// The requested GPADL ID is not known.
    #[error("unknown GPADL ID")]
    UnknownGpadlId,
    /// A protocol message could not be parsed; wraps the underlying
    /// [`protocol::ParseError`].
    #[error("parse error")]
    ParseError(#[from] protocol::ParseError),
    /// The assembled GPADL range data failed validation; wraps the
    /// underlying [`gparange::Error`].
    #[error("invalid gpa range")]
    InvalidGpaRange(#[source] gparange::Error),
    /// The GPADL ID is already in use.
    #[error("duplicate GPADL ID")]
    DuplicateGpadlId,
    /// More data was supplied for a GPADL that is no longer in progress.
    #[error("GPADL is already complete")]
    GpadlAlreadyComplete,
    /// The GPADL does not belong to the channel named in the request.
    #[error("GPADL channel ID mismatch")]
    WrongGpadlChannelId,
    /// An open was requested for a channel that is already open.
    #[error("trying to open an open channel")]
    ChannelAlreadyOpen,
    /// A close was requested for a channel that is not open.
    #[error("trying to close a closed channel")]
    ChannelNotOpen,
    /// The GPADL is not in a state that allows the operation.
    #[error("invalid GPADL state for operation")]
    InvalidGpadlState,
    /// The channel is not in a state that allows the operation.
    #[error("invalid channel state for operation")]
    InvalidChannelState,
    /// The channel was found but is in one of the released states.
    #[error("channel ID has already been released")]
    ChannelReleased,
    /// Offers were requested after they had already been sent.
    #[error("channel offers have already been sent")]
    OffersAlreadySent,
    /// The operation is not valid on a reserved channel.
    #[error("invalid operation on reserved channel")]
    ChannelReserved,
    /// The operation requires a reserved channel.
    #[error("invalid operation on non-reserved channel")]
    ChannelNotReserved,
    /// An untrusted message arrived on a trusted connection.
    #[error("received untrusted message for trusted connection")]
    UntrustedMessage,
    /// A message other than a resuming one arrived while paused.
    #[error("received a non-resuming message while paused")]
    Paused,
}
96
/// Errors reported while offering a channel or assigning it a channel ID.
#[derive(Debug, Error)]
pub enum OfferError {
    /// The channel ID falls outside the channel ID assignment table.
    #[error("the channel ID {} is not valid for this operation", (.0).0)]
    InvalidChannelId(ChannelId),
    /// The channel ID is already assigned to another offer.
    #[error("the channel ID {} is already in use", (.0).0)]
    ChannelIdInUse(ChannelId),
    /// An offer with the same key already exists.
    #[error("offer {0} already exists")]
    AlreadyExists(OfferKey),
    /// The offer's resources differ from those of the saved or revoked offer
    /// it would replace.
    #[error("specified resources do not match those of the existing saved or revoked offer")]
    IncompatibleResources,
    /// No more channels can be offered.
    #[error("too many channels have been offered")]
    TooManyChannels,
    /// The monitor ID does not match the one recorded in the saved state.
    /// The fields are the expected and the actual ID, in that order.
    #[error("mismatched monitor ID from saved state; expected {0:?}, actual {1:?}")]
    MismatchedMonitorId(Option<MonitorId>, MonitorId),
}
112
/// Identifies an offered channel.
///
/// The wrapped value is the key under which the channel is stored in the
/// `ChannelList` slab.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OfferId(usize);
116
/// Maps a GPADL ID to the offer it belongs to.
// NOTE(review): presumably only holds GPADLs that are still being assembled;
// confirm against the code that populates it.
type IncompleteGpadlMap = HashMap<GpadlId, OfferId>;
118
/// GPADLs keyed by the pair of GPADL ID and owning offer.
type GpadlMap = HashMap<(GpadlId, OfferId), Gpadl>;
120
/// Server-side channel state: the connection state, the offered channels and
/// the channel/monitor IDs assigned to them, and the GPADLs.
///
/// Use [`Server::with_notifier`] to pair it with a [`Notifier`].
pub struct Server {
    /// Current connection state; starts out as `Disconnected`.
    state: ConnectionState,
    /// All offered channels, indexed by [`OfferId`].
    channels: ChannelList,
    /// Channel ID to offer assignments.
    assigned_channels: AssignedChannels,
    /// Monitor IDs handed out to channels that have MNF enabled.
    assigned_monitors: AssignedMonitors,
    /// GPADLs keyed by GPADL ID and owning offer.
    gpadls: GpadlMap,
    incomplete_gpadls: IncompleteGpadlMap,
    /// Value supplied to [`Server::new`].
    child_connection_id: u32,
    /// Set by `set_compatibility_version` when no delay is requested.
    max_version: Option<MaxVersionInfo>,
    /// Set by `set_compatibility_version` when a delay is requested.
    delayed_max_version: Option<MaxVersionInfo>,
    pending_messages: PendingMessages,
}
136
/// A [`Server`] borrowed together with the notifier it should use.
///
/// Created by [`Server::with_notifier`].
pub struct ServerWithNotifier<'a, T> {
    /// The borrowed server state.
    inner: &'a mut Server,
    /// The borrowed notifier.
    notifier: &'a mut T,
}
141
142impl<T> Drop for ServerWithNotifier<'_, T> {
143 fn drop(&mut self) {
144 self.inner.validate();
145 }
146}
147
148impl<T: Notifier> Inspect for ServerWithNotifier<'_, T> {
149 fn inspect(&self, req: inspect::Request<'_>) {
150 let mut resp = req.respond();
151 let (state, info, next_action) = match &self.inner.state {
152 ConnectionState::Disconnected => ("disconnected", None, None),
153 ConnectionState::Connecting { info, .. } => ("connecting", Some(info), None),
154 ConnectionState::Connected(info) => (
155 if info.offers_sent {
156 "connected"
157 } else {
158 "negotiated"
159 },
160 Some(info),
161 None,
162 ),
163 ConnectionState::Disconnecting { next_action, .. } => {
164 ("disconnecting", None, Some(next_action))
165 }
166 };
167
168 resp.field("connection_info", info);
169 let next_action = next_action.map(|a| match a {
170 ConnectionAction::None => "disconnect",
171 ConnectionAction::Reset => "reset",
172 ConnectionAction::SendUnloadComplete => "unload",
173 ConnectionAction::Reconnect { .. } => "reconnect",
174 ConnectionAction::SendFailedVersionResponse => "send_version_response",
175 });
176 resp.field("state", state)
177 .field("next_action", next_action)
178 .field(
179 "assigned_monitors_bitmap",
180 format_args!("{:x}", self.inner.assigned_monitors.bitmap()),
181 )
182 .child("channels", |req| {
183 let mut resp = req.respond();
184 self.inner
185 .channels
186 .inspect(self.notifier, self.inner.get_version(), &mut resp);
187 for ((gpadl_id, offer_id), gpadl) in &self.inner.gpadls {
188 let channel = &self.inner.channels[*offer_id];
189 resp.field(
190 &channel_inspect_path(
191 &channel.offer,
192 format_args!("/gpadls/{}", gpadl_id.0),
193 ),
194 gpadl,
195 );
196 }
197 });
198 }
199}
200
/// Information about a connection that is being established or is
/// established.
#[derive(Debug, Copy, Clone, Inspect)]
struct ConnectionInfo {
    /// The version information of the connection.
    version: VersionInfo,
    /// Reported by `ConnectionState::is_trusted`.
    trusted: bool,
    /// While false, a connected server is inspected as "negotiated" rather
    /// than "connected".
    offers_sent: bool,
    interrupt_page: Option<u64>,
    monitor_page: Option<MonitorPageGpas>,
    target_message_vp: u32,
    modifying: bool,
    client_id: Guid,
    /// Reported by `ConnectionState::is_paused`.
    paused: bool,
}
215
/// The state of the connection tracked by the server.
#[derive(Debug)]
enum ConnectionState {
    /// No connection exists; this is the initial state of a new `Server`.
    Disconnected,
    /// The connection is being torn down.
    Disconnecting {
        /// The action recorded to run next; shown as `next_action` when
        /// inspecting.
        next_action: ConnectionAction,
        modify_sent: bool,
    },
    /// A connection is being established.
    Connecting {
        /// The information of the connection being established.
        info: ConnectionInfo,
        next_action: ConnectionAction,
    },
    /// A connection is established.
    Connected(ConnectionInfo),
}
230
231impl ConnectionState {
232 fn check_version(&self, min_version: Version) -> bool {
234 matches!(self, ConnectionState::Connected(info) if info.version.version >= min_version)
235 }
236
237 fn check_feature_flags(&self, flags: impl Fn(FeatureFlags) -> bool) -> bool {
240 matches!(self, ConnectionState::Connected(info) if flags(info.version.feature_flags))
241 }
242
243 fn get_version(&self) -> Option<VersionInfo> {
244 if let ConnectionState::Connected(info) = self {
245 Some(info.version)
246 } else {
247 None
248 }
249 }
250
251 fn is_trusted(&self) -> bool {
252 match self {
253 ConnectionState::Connected(info) => info.trusted,
254 ConnectionState::Connecting { info, .. } => info.trusted,
255 _ => false,
256 }
257 }
258
259 fn is_paused(&self) -> bool {
260 if let ConnectionState::Connected(info) = self {
261 info.paused
262 } else {
263 false
264 }
265 }
266}
267
/// The action recorded alongside a connecting or disconnecting state.
///
/// The quoted names below are the labels used when inspecting a
/// disconnecting server.
#[derive(Debug, Copy, Clone)]
enum ConnectionAction {
    /// Inspected as "disconnect".
    None,
    /// Inspected as "reset".
    Reset,
    /// Inspected as "unload".
    SendUnloadComplete,
    /// Inspected as "reconnect".
    Reconnect {
        /// The initiate-contact request carried by the reconnect action.
        initiate_contact: InitiateContactRequest,
    },
    /// Inspected as "send_version_response".
    SendFailedVersionResponse,
}
278
/// The monitor page part of an [`InitiateContactRequest`].
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum MonitorPageRequest {
    /// No monitor pages were requested.
    None,
    /// The monitor pages to use.
    Some(MonitorPageGpas),
    /// The monitor page information in the request was invalid.
    Invalid,
}
285
/// The parameters of a request to initiate contact with the server.
///
/// Field values are carried as raw integers here.
// NOTE(review): validation of these values presumably happens where the
// request is handled; confirm.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub struct InitiateContactRequest {
    /// The requested protocol version, as a raw value.
    pub version_requested: u32,
    pub target_message_vp: u32,
    /// The requested monitor pages, if any.
    pub monitor_page: MonitorPageRequest,
    pub target_sint: u8,
    pub target_vtl: u8,
    /// The requested feature flags, as raw bits.
    pub feature_flags: u32,
    pub interrupt_page: Option<u64>,
    pub client_id: Guid,
    pub trusted: bool,
}
298
/// The parameters of a request to open a channel.
///
/// `OpenParams::from_request` turns these into the parameters carried by
/// `Action::Open`.
#[derive(Debug, Copy, Clone)]
pub struct OpenRequest {
    pub open_id: u32,
    /// The GPADL that holds the ring buffer.
    pub ring_buffer_gpadl_id: GpadlId,
    pub target_vp: u32,
    /// Becomes `OpenData::ring_offset` when the channel is opened.
    pub downstream_ring_buffer_page_offset: u32,
    pub user_data: UserDefinedData,
    /// When present, overrides the event flag and connection ID that would
    /// otherwise come from the channel's offered info.
    pub guest_specified_interrupt_info: Option<SignalInfo>,
    pub flags: protocol::OpenChannelFlags,
}
309
/// A requested change to an optional value.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Update<T: std::fmt::Debug + Copy + Clone> {
    /// Leave the value as it is.
    Unchanged,
    /// Clear the value.
    Reset,
    /// Replace the value with the given one.
    Set(T),
}

/// Converts an `Option` into an update: `None` becomes [`Update::Reset`] and
/// `Some(v)` becomes [`Update::Set`]. [`Update::Unchanged`] is never produced.
impl<T: std::fmt::Debug + Copy + Clone> From<Option<T>> for Update<T> {
    fn from(value: Option<T>) -> Self {
        value.map_or(Self::Reset, Self::Set)
    }
}
325
/// A request to modify the connection, passed to
/// `Notifier::modify_connection`.
///
/// The `Default` value requests no changes and has `notify_relay` set.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ModifyConnectionRequest {
    /// The new version as a raw value, or `None` for no change.
    pub version: Option<u32>,
    /// The requested change to the monitor pages.
    pub monitor_page: Update<MonitorPageGpas>,
    /// The requested change to the interrupt page.
    pub interrupt_page: Update<u64>,
    /// The new target message VP, or `None` for no change.
    pub target_message_vp: Option<u32>,
    /// Defaults to true.
    // NOTE(review): presumably controls whether a relay is told about the
    // change; confirm against the notifier implementation.
    pub notify_relay: bool,
}
334
335impl Default for ModifyConnectionRequest {
337 fn default() -> Self {
338 Self {
339 version: None,
340 monitor_page: Update::Unchanged,
341 interrupt_page: Update::Unchanged,
342 target_message_vp: None,
343 notify_relay: true,
344 }
345 }
346}
347
348impl From<protocol::ModifyConnection> for ModifyConnectionRequest {
349 fn from(value: protocol::ModifyConnection) -> Self {
350 let monitor_page = if value.parent_to_child_monitor_page_gpa != 0 {
351 Update::Set(MonitorPageGpas {
352 parent_to_child: value.parent_to_child_monitor_page_gpa,
353 child_to_parent: value.child_to_parent_monitor_page_gpa,
354 })
355 } else {
356 Update::Reset
357 };
358
359 Self {
360 monitor_page,
361 ..Default::default()
362 }
363 }
364}
365
/// The response to a [`ModifyConnectionRequest`].
#[derive(Debug, Copy, Clone)]
pub enum ModifyConnectionResponse {
    /// The request was handled; carries a protocol connection state and a set
    /// of feature flags.
    // NOTE(review): presumably the feature flags are those the responder
    // supports; confirm.
    Supported(protocol::ConnectionState, FeatureFlags),
    /// The request is not supported.
    Unsupported,
}
378
/// Whether an open channel has a modify operation in progress.
#[derive(Debug, Copy, Clone)]
pub enum ModifyState {
    /// No modify operation is in progress.
    NotModifying,
    /// A modify operation is in progress.
    Modifying {
        /// A target VP that is still pending, if any.
        pending_target_vp: Option<u32>,
    },
}

impl ModifyState {
    /// Returns true when a modify operation is in progress.
    pub fn is_modifying(&self) -> bool {
        match self {
            Self::Modifying { .. } => true,
            Self::NotModifying => false,
        }
    }
}
390
/// Interrupt signalling parameters specified by the guest in an open request.
///
/// When present they are used in place of the values derived from the
/// channel's offered info.
#[derive(Debug, Copy, Clone)]
pub struct SignalInfo {
    /// The event flag to use.
    pub event_flag: u16,
    /// The connection ID to use.
    pub connection_id: u32,
}
396
/// Where a channel stands with respect to restoring from saved state.
///
/// The quoted names are the labels used when inspecting a channel.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum RestoreState {
    /// Inspected as "new".
    New,
    /// Inspected as "restoring".
    Restoring,
    /// Inspected as "unmatched".
    Unmatched,
    /// Inspected as "restored".
    Restored,
}
411
/// The state of an offered channel.
///
/// `is_released`, `is_revoked` and `is_reserved` classify these states.
#[derive(Debug, Clone)]
enum ChannelState {
    /// A released state: the channel no longer has offered info.
    ClientReleased,

    /// The channel is not open.
    Closed,

    /// An open is in progress.
    Opening {
        /// The open request being processed.
        request: OpenRequest,
        /// Present when the channel is reserved.
        reserved_state: Option<ReservedState>,
    },

    /// The channel is open.
    Open {
        /// The request the channel was opened with.
        params: OpenRequest,
        /// Whether a modify operation is in progress.
        modify_state: ModifyState,
        /// Present when the channel is reserved.
        reserved_state: Option<ReservedState>,
    },

    /// A close is in progress.
    Closing {
        /// The request the channel was opened with.
        params: OpenRequest,
        /// Present when the channel is reserved.
        reserved_state: Option<ReservedState>,
    },

    /// A close is in progress and a new open request is already queued.
    // NOTE(review): meaning inferred from the variant name and fields; confirm.
    ClosingReopen {
        /// The request the channel was opened with.
        params: OpenRequest,
        /// The queued open request.
        request: OpenRequest,
    },

    /// A revoked state.
    Revoked,

    /// Also counted as a revoked state by `is_revoked`.
    Reoffered,

    /// A released state.
    // NOTE(review): presumably a close is still in flight; confirm.
    ClosingClientRelease,

    /// A released state.
    // NOTE(review): presumably an open is still in flight; confirm.
    OpeningClientRelease,
}
465
466impl ChannelState {
467 fn is_released(&self) -> bool {
470 match self {
471 ChannelState::Closed
472 | ChannelState::Opening { .. }
473 | ChannelState::Open { .. }
474 | ChannelState::Closing { .. }
475 | ChannelState::ClosingReopen { .. }
476 | ChannelState::Revoked
477 | ChannelState::Reoffered => false,
478
479 ChannelState::ClientReleased
480 | ChannelState::ClosingClientRelease
481 | ChannelState::OpeningClientRelease => true,
482 }
483 }
484
485 fn is_revoked(&self) -> bool {
487 match self {
488 ChannelState::Revoked | ChannelState::Reoffered => true,
489
490 ChannelState::ClientReleased
491 | ChannelState::Closed
492 | ChannelState::Opening { .. }
493 | ChannelState::Open { .. }
494 | ChannelState::Closing { .. }
495 | ChannelState::ClosingReopen { .. }
496 | ChannelState::ClosingClientRelease
497 | ChannelState::OpeningClientRelease => false,
498 }
499 }
500
501 fn is_reserved(&self) -> bool {
502 match self {
503 ChannelState::Open {
505 reserved_state: Some(_),
506 ..
507 }
508 | ChannelState::Opening {
509 reserved_state: Some(_),
510 ..
511 }
512 | ChannelState::Closing {
513 reserved_state: Some(_),
514 ..
515 } => true,
516
517 ChannelState::Opening { .. }
518 | ChannelState::Open { .. }
519 | ChannelState::Closing { .. }
520 | ChannelState::ClientReleased
521 | ChannelState::Closed
522 | ChannelState::ClosingReopen { .. }
523 | ChannelState::Revoked
524 | ChannelState::Reoffered
525 | ChannelState::ClosingClientRelease
526 | ChannelState::OpeningClientRelease => false,
527 }
528 }
529}
530
531impl Display for ChannelState {
532 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
533 let state = match self {
534 Self::ClientReleased => "ClientReleased",
535 Self::Closed => "Closed",
536 Self::Opening { .. } => "Opening",
537 Self::Open { .. } => "Open",
538 Self::Closing { .. } => "Closing",
539 Self::ClosingReopen { .. } => "ClosingReopen",
540 Self::Revoked => "Revoked",
541 Self::Reoffered => "Reoffered",
542 Self::ClosingClientRelease => "ClosingClientRelease",
543 Self::OpeningClientRelease => "OpeningClientRelease",
544 };
545 write!(f, "{}", state)
546 }
547}
548
/// How a channel uses MNF (monitor IDs).
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub enum MnfUsage {
    /// No monitor ID is used. This is the default.
    #[default]
    Disabled,
    /// A monitor ID is assigned by this server when the channel is prepared;
    /// `latency` is reported in the channel's `MonitorInfo`.
    Enabled { latency: Duration },
    /// The given monitor ID is used as is. It is not released back to this
    /// server's monitor assignments when the channel is released.
    Relayed { monitor_id: u8 },
}
561
562impl MnfUsage {
563 pub fn is_enabled(&self) -> bool {
564 matches!(self, Self::Enabled { .. })
565 }
566
567 pub fn is_relayed(&self) -> bool {
568 matches!(self, Self::Relayed { .. })
569 }
570
571 pub fn enabled_and_then<T>(&self, f: impl FnOnce(Duration) -> Option<T>) -> Option<T> {
572 if let Self::Enabled { latency } = self {
573 f(*latency)
574 } else {
575 None
576 }
577 }
578}
579
580impl From<Option<Duration>> for MnfUsage {
581 fn from(value: Option<Duration>) -> Self {
582 match value {
583 None => Self::Disabled,
584 Some(latency) => Self::Enabled { latency },
585 }
586 }
587}
588
/// The parameters of an offer as stored with each channel.
///
/// Can be built from `OfferParams` via `From`, which derives `flags` and
/// `user_defined` from the channel type.
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub struct OfferParamsInternal {
    /// Name of the interface; reported when inspecting the channel.
    pub interface_name: String,
    /// Part of the offer key.
    pub instance_id: Guid,
    /// Part of the offer key.
    pub interface_id: Guid,
    pub mmio_megabytes: u16,
    pub mmio_megabytes_optional: u16,
    /// Part of the offer key. Zero is inspected at the instance path itself;
    /// any other value is inspected under `subchannels/<index>`.
    pub subchannel_index: u16,
    /// How the channel uses MNF.
    pub use_mnf: MnfUsage,
    pub offer_order: Option<u32>,
    /// The offer flags.
    pub flags: OfferFlags,
    /// The user-defined data of the offer.
    pub user_defined: UserDefinedData,
}
603
604impl OfferParamsInternal {
605 pub fn key(&self) -> OfferKey {
607 OfferKey {
608 interface_id: self.interface_id,
609 instance_id: self.instance_id,
610 subchannel_index: self.subchannel_index,
611 }
612 }
613}
614
615impl From<OfferParams> for OfferParamsInternal {
616 fn from(value: OfferParams) -> Self {
617 let mut user_defined = UserDefinedData::new_zeroed();
618
619 let mut flags = OfferFlags::new()
622 .with_confidential_ring_buffer(true)
623 .with_confidential_external_memory(value.allow_confidential_external_memory);
624
625 match value.channel_type {
626 ChannelType::Device { pipe_packets } => {
627 if pipe_packets {
628 flags.set_named_pipe_mode(true);
629 user_defined.as_pipe_params_mut().pipe_type = protocol::PipeType::MESSAGE;
630 }
631 }
632 ChannelType::Interface {
633 user_defined: interface_user_defined,
634 } => {
635 flags.set_enumerate_device_interface(true);
636 user_defined = interface_user_defined;
637 }
638 ChannelType::Pipe { message_mode } => {
639 flags.set_enumerate_device_interface(true);
640 flags.set_named_pipe_mode(true);
641 user_defined.as_pipe_params_mut().pipe_type = if message_mode {
642 protocol::PipeType::MESSAGE
643 } else {
644 protocol::PipeType::BYTE
645 };
646 }
647 ChannelType::HvSocket {
648 is_connect,
649 is_for_container,
650 silo_id,
651 } => {
652 flags.set_enumerate_device_interface(true);
653 flags.set_tlnpi_provider(true);
654 flags.set_named_pipe_mode(true);
655 *user_defined.as_hvsock_params_mut() = protocol::HvsockUserDefinedParameters::new(
656 is_connect,
657 is_for_container,
658 silo_id,
659 );
660 }
661 };
662
663 Self {
664 interface_name: value.interface_name,
665 instance_id: value.instance_id,
666 interface_id: value.interface_id,
667 mmio_megabytes: value.mmio_megabytes,
668 mmio_megabytes_optional: value.mmio_megabytes_optional,
669 subchannel_index: value.subchannel_index,
670 use_mnf: value.mnf_interrupt_latency.into(),
671 offer_order: value.offer_order,
672 user_defined,
673 flags,
674 }
675 }
676}
677
/// A message destination given as a virtual processor and a SINT.
#[derive(Debug, Copy, Clone, Inspect, PartialEq, Eq)]
pub struct ConnectionTarget {
    /// The virtual processor index.
    pub vp: u32,
    /// The SINT index.
    pub sint: u8,
}
683
/// Where an outgoing message should be sent; passed to
/// `Notifier::send_message`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MessageTarget {
    /// The default target.
    Default,
    /// The target of a reserved channel, together with the channel's offer.
    ReservedChannel(OfferId, ConnectionTarget),
    /// An explicitly specified target.
    Custom(ConnectionTarget),
}
690
691impl MessageTarget {
692 pub fn for_offer(offer_id: OfferId, reserved_state: &Option<ReservedState>) -> Self {
693 if let Some(state) = reserved_state {
694 Self::ReservedChannel(offer_id, state.target)
695 } else {
696 Self::Default
697 }
698 }
699}
700
/// The extra state carried by a reserved channel.
#[derive(Debug, Copy, Clone)]
pub struct ReservedState {
    /// Version information recorded with the reservation.
    // NOTE(review): presumably the connection's version when the channel was
    // reserved; confirm.
    version: VersionInfo,
    /// The target for messages about this channel.
    target: ConnectionTarget,
}
706
/// An offered channel.
#[derive(Debug)]
struct Channel {
    /// The IDs assigned to the channel. `Server::validate` expects this to be
    /// present exactly when the state is not one of the released states.
    info: Option<OfferedInfo>,
    /// The parameters the channel was offered with.
    offer: OfferParamsInternal,
    /// The channel's current state.
    state: ChannelState,
    /// Where the channel stands with respect to restoring from saved state.
    restore_state: RestoreState,
}
715
/// The IDs assigned to a channel by `Channel::prepare_channel`.
#[derive(Debug, Copy, Clone)]
struct OfferedInfo {
    /// The channel ID taken from the channel ID assignment table.
    channel_id: ChannelId,
    /// The raw connection ID built from the channel ID, the VTL and `SINT`.
    connection_id: u32,
    /// The monitor ID, if the channel uses one.
    monitor_id: Option<MonitorId>,
}
722
723impl Channel {
724 fn inspect_state(&self, resp: &mut inspect::Response<'_>) {
725 let mut target_vp = None;
726 let mut event_flag = None;
727 let mut connection_id = None;
728 let mut reserved_target = None;
729 let state = match &self.state {
730 ChannelState::ClientReleased => "client_released",
731 ChannelState::Closed => "closed",
732 ChannelState::Opening { reserved_state, .. } => {
733 reserved_target = reserved_state.map(|state| state.target);
734 "opening"
735 }
736 ChannelState::Open {
737 params,
738 reserved_state,
739 ..
740 } => {
741 target_vp = Some(params.target_vp);
742 if let Some(id) = params.guest_specified_interrupt_info {
743 event_flag = Some(id.event_flag);
744 connection_id = Some(id.connection_id);
745 }
746 reserved_target = reserved_state.map(|state| state.target);
747 "open"
748 }
749 ChannelState::Closing { reserved_state, .. } => {
750 reserved_target = reserved_state.map(|state| state.target);
751 "closing"
752 }
753 ChannelState::ClosingReopen { .. } => "closing_reopen",
754 ChannelState::Revoked => "revoked",
755 ChannelState::Reoffered => "reoffered",
756 ChannelState::ClosingClientRelease => "closing_client_release",
757 ChannelState::OpeningClientRelease => "opening_client_release",
758 };
759 let restore_state = match self.restore_state {
760 RestoreState::New => "new",
761 RestoreState::Restoring => "restoring",
762 RestoreState::Restored => "restored",
763 RestoreState::Unmatched => "unmatched",
764 };
765 if let Some(info) = &self.info {
766 resp.field("channel_id", info.channel_id.0)
767 .field("offered_connection_id", info.connection_id)
768 .field("monitor_id", info.monitor_id.map(|id| id.0));
769 }
770 resp.field("state", state)
771 .field("restore_state", restore_state)
772 .field("interface_name", self.offer.interface_name.clone())
773 .display("instance_id", &self.offer.instance_id)
774 .display("interface_id", &self.offer.interface_id)
775 .field("mmio_megabytes", self.offer.mmio_megabytes)
776 .field("target_vp", target_vp)
777 .field("guest_specified_event_flag", event_flag)
778 .field("guest_specified_connection_id", connection_id)
779 .field("reserved_connection_target", reserved_target)
780 .binary("offer_flags", self.offer.flags.into_bits());
781 }
782
783 fn handled_monitor_info(&self) -> Option<MonitorInfo> {
793 self.offer.use_mnf.enabled_and_then(|latency| {
794 if self.state.is_reserved() {
795 None
796 } else {
797 self.info.and_then(|info| {
798 info.monitor_id.map(|monitor_id| MonitorInfo {
799 monitor_id,
800 latency,
801 })
802 })
803 }
804 })
805 }
806
807 fn prepare_channel(
810 &mut self,
811 offer_id: OfferId,
812 assigned_channels: &mut AssignedChannels,
813 assigned_monitors: &mut AssignedMonitors,
814 ) {
815 assert!(self.info.is_none());
816
817 let entry = assigned_channels
819 .allocate()
820 .expect("there are enough channel IDs for everything in ChannelList");
821
822 let channel_id = entry.id();
823 entry.insert(offer_id);
824 let connection_id = ConnectionId::new(channel_id.0, assigned_channels.vtl, SINT);
825
826 let monitor_id = match self.offer.use_mnf {
831 MnfUsage::Enabled { .. } => {
832 let monitor_id = assigned_monitors.assign_monitor();
833 if monitor_id.is_none() {
834 tracelimit::warn_ratelimited!("Out of monitor IDs.");
835 }
836
837 monitor_id
838 }
839 MnfUsage::Relayed { monitor_id } => Some(MonitorId(monitor_id)),
840 MnfUsage::Disabled => None,
841 };
842
843 self.info = Some(OfferedInfo {
844 channel_id,
845 connection_id: connection_id.0,
846 monitor_id,
847 });
848 }
849
850 fn release_channel(
852 &mut self,
853 offer_id: OfferId,
854 assigned_channels: &mut AssignedChannels,
855 assigned_monitors: &mut AssignedMonitors,
856 ) {
857 if let Some(info) = self.info.take() {
858 assigned_channels.free(info.channel_id, offer_id);
859
860 if let Some(monitor_id) = info.monitor_id {
862 if self.offer.use_mnf.is_enabled() {
863 assigned_monitors.release_monitor(monitor_id);
864 }
865 }
866 }
867 }
868}
869
/// The table that maps channel IDs to the offers they are assigned to.
#[derive(Debug)]
struct AssignedChannels {
    /// One slot per channel ID; the slot for channel ID `n` is at index
    /// `n - 1`.
    assignments: Vec<Option<OfferId>>,
    /// The VTL used when building connection IDs for channels.
    vtl: Vtl,
    /// Slots below this index are never handed out by `allocate`; they can
    /// only be claimed explicitly through `set`.
    reserved_offset: usize,
    /// The number of occupied slots below `reserved_offset`.
    count_in_reserved_range: usize,
}
878
879impl AssignedChannels {
880 fn new(vtl: Vtl, channel_id_offset: u16) -> Self {
881 Self {
882 assignments: vec![None; MAX_CHANNELS],
883 vtl,
884 reserved_offset: channel_id_offset as usize,
885 count_in_reserved_range: 0,
886 }
887 }
888
889 fn allowable_channel_count(&self) -> usize {
890 MAX_CHANNELS - self.reserved_offset + self.count_in_reserved_range
891 }
892
893 fn get(&self, channel_id: ChannelId) -> Option<OfferId> {
894 self.assignments
895 .get(Self::index(channel_id))
896 .copied()
897 .flatten()
898 }
899
900 fn set(&mut self, channel_id: ChannelId) -> Result<AssignmentEntry<'_>, OfferError> {
901 let index = Self::index(channel_id);
902 if self
903 .assignments
904 .get(index)
905 .ok_or(OfferError::InvalidChannelId(channel_id))?
906 .is_some()
907 {
908 return Err(OfferError::ChannelIdInUse(channel_id));
909 }
910 Ok(AssignmentEntry { list: self, index })
911 }
912
913 fn allocate(&mut self) -> Option<AssignmentEntry<'_>> {
914 let index = self.reserved_offset
915 + self.assignments[self.reserved_offset..]
916 .iter()
917 .position(|x| x.is_none())?;
918 Some(AssignmentEntry { list: self, index })
919 }
920
921 fn free(&mut self, channel_id: ChannelId, offer_id: OfferId) {
922 let index = Self::index(channel_id);
923 let slot = &mut self.assignments[index];
924 assert_eq!(slot.take(), Some(offer_id));
925 if index < self.reserved_offset {
926 self.count_in_reserved_range -= 1;
927 }
928 }
929
930 fn index(channel_id: ChannelId) -> usize {
931 channel_id.0.wrapping_sub(1) as usize
932 }
933}
934
/// A free slot in [`AssignedChannels`] that has been located but not yet
/// filled.
struct AssignmentEntry<'a> {
    /// The table the slot belongs to.
    list: &'a mut AssignedChannels,
    /// The index of the slot in the table.
    index: usize,
}
939
940impl AssignmentEntry<'_> {
941 pub fn id(&self) -> ChannelId {
942 ChannelId(self.index as u32 + 1)
943 }
944
945 pub fn insert(self, offer_id: OfferId) {
946 assert!(
947 self.list.assignments[self.index]
948 .replace(offer_id)
949 .is_none()
950 );
951
952 if self.index < self.list.reserved_offset {
953 self.list.count_in_reserved_range += 1;
954 }
955 }
956}
957
/// The collection of offered channels, indexed by [`OfferId`].
struct ChannelList {
    /// The channels; the slab key is the value wrapped by the `OfferId`.
    channels: Slab<Channel>,
}
961
962fn channel_inspect_path(offer: &OfferParamsInternal, suffix: std::fmt::Arguments<'_>) -> String {
963 if offer.subchannel_index == 0 {
964 format!("{}{}", offer.instance_id, suffix)
965 } else {
966 format!(
967 "{}/subchannels/{}{}",
968 offer.instance_id, offer.subchannel_index, suffix
969 )
970 }
971}
972
973impl ChannelList {
974 fn inspect(
975 &self,
976 notifier: &impl Notifier,
977 version: Option<VersionInfo>,
978 resp: &mut inspect::Response<'_>,
979 ) {
980 for (offer_id, channel) in self.iter() {
981 resp.child(
982 &channel_inspect_path(&channel.offer, format_args!("")),
983 |req| {
984 let mut resp = req.respond();
985 channel.inspect_state(&mut resp);
986
987 resp.merge(inspect::adhoc(|req| {
991 if !matches!(channel.state, ChannelState::Revoked) {
992 notifier.inspect(version, offer_id, req);
993 }
994 }));
995 },
996 );
997 }
998 }
999}
1000
/// The number of slots in the channel ID assignment table, and therefore the
/// largest channel ID that can be assigned.
pub const MAX_CHANNELS: usize = 2047;
1004
1005impl ChannelList {
1006 fn new() -> Self {
1007 Self {
1008 channels: Slab::new(),
1009 }
1010 }
1011
1012 fn len(&self) -> usize {
1014 self.channels.len()
1015 }
1016
1017 fn offer(&mut self, new_channel: Channel) -> OfferId {
1019 OfferId(self.channels.insert(new_channel))
1020 }
1021
1022 fn remove(&mut self, offer_id: OfferId) {
1024 let channel = self.channels.remove(offer_id.0);
1025 assert!(channel.info.is_none());
1026 }
1027
1028 fn get_by_channel_id_mut(
1030 &mut self,
1031 assigned_channels: &AssignedChannels,
1032 channel_id: ChannelId,
1033 ) -> Result<(OfferId, &mut Channel), ChannelError> {
1034 let offer_id = assigned_channels
1035 .get(channel_id)
1036 .ok_or(ChannelError::UnknownChannelId)?;
1037 let channel = &mut self[offer_id];
1038 if channel.state.is_released() {
1039 return Err(ChannelError::ChannelReleased);
1040 }
1041 assert_eq!(
1042 channel.info.as_ref().map(|info| info.channel_id),
1043 Some(channel_id)
1044 );
1045 Ok((offer_id, channel))
1046 }
1047
1048 fn get_by_channel_id(
1050 &self,
1051 assigned_channels: &AssignedChannels,
1052 channel_id: ChannelId,
1053 ) -> Result<(OfferId, &Channel), ChannelError> {
1054 let offer_id = assigned_channels
1055 .get(channel_id)
1056 .ok_or(ChannelError::UnknownChannelId)?;
1057 let channel = &self[offer_id];
1058 if channel.state.is_released() {
1059 return Err(ChannelError::ChannelReleased);
1060 }
1061 assert_eq!(
1062 channel.info.as_ref().map(|info| info.channel_id),
1063 Some(channel_id)
1064 );
1065 Ok((offer_id, channel))
1066 }
1067
1068 fn get_by_key_mut(&mut self, key: &OfferKey) -> Option<(OfferId, &mut Channel)> {
1071 for (offer_id, channel) in self.iter_mut() {
1072 if channel.offer.instance_id == key.instance_id
1073 && channel.offer.interface_id == key.interface_id
1074 && channel.offer.subchannel_index == key.subchannel_index
1075 {
1076 return Some((offer_id, channel));
1077 }
1078 }
1079 None
1080 }
1081
1082 fn iter(&self) -> impl Iterator<Item = (OfferId, &Channel)> {
1084 self.channels
1085 .iter()
1086 .map(|(id, channel)| (OfferId(id), channel))
1087 }
1088
1089 fn iter_mut(&mut self) -> impl Iterator<Item = (OfferId, &mut Channel)> {
1091 self.channels
1092 .iter_mut()
1093 .map(|(id, channel)| (OfferId(id), channel))
1094 }
1095
1096 fn retain<F>(&mut self, mut f: F)
1098 where
1099 F: FnMut(OfferId, &mut Channel) -> bool,
1100 {
1101 self.channels.retain(|id, channel| {
1102 let retain = f(OfferId(id), channel);
1103 if !retain {
1104 assert!(channel.info.is_none());
1105 }
1106 retain
1107 })
1108 }
1109}
1110
1111impl Index<OfferId> for ChannelList {
1112 type Output = Channel;
1113
1114 fn index(&self, offer_id: OfferId) -> &Self::Output {
1115 &self.channels[offer_id.0]
1116 }
1117}
1118
1119impl IndexMut<OfferId> for ChannelList {
1120 fn index_mut(&mut self, offer_id: OfferId) -> &mut Self::Output {
1121 &mut self.channels[offer_id.0]
1122 }
1123}
1124
/// A GPADL and its range data.
#[derive(Debug, Inspect)]
struct Gpadl {
    /// The range count passed to range validation once the buffer is full.
    count: u16,
    /// The range data as 64-bit words. Its capacity is the expected final
    /// length; `append` fills it incrementally.
    #[inspect(skip)]
    buf: Vec<u64>,
    /// The current state of the GPADL.
    state: GpadlState,
}
1133
/// The lifecycle state of a GPADL.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Inspect)]
enum GpadlState {
    /// The range data is still being received.
    InProgress,
    /// The range data is complete and validated. `channel_gpadls` reports
    /// such a GPADL as not accepted.
    Offered,
    /// Also reported by `channel_gpadls` as not accepted.
    // NOTE(review): presumably a teardown arrived before the offer was
    // answered; confirm.
    OfferedTearingDown,
    /// Reported by `channel_gpadls` as accepted.
    Accepted,
    /// The GPADL is being torn down; `channel_gpadls` omits it.
    TearingDown,
}
1148
1149impl Gpadl {
1150 fn new(count: u16, len: usize) -> Self {
1153 Self {
1154 state: GpadlState::InProgress,
1155 count,
1156 buf: Vec::with_capacity(len),
1157 }
1158 }
1159
1160 fn append(&mut self, data: &[u8]) -> Result<bool, ChannelError> {
1162 if self.state == GpadlState::InProgress {
1163 let buf = &mut self.buf;
1164 let len = min(data.len() & !7, (buf.capacity() - buf.len()) * 8);
1169 let data = &data[..len];
1170 let start = buf.len();
1171 buf.resize(buf.len() + data.len() / 8, 0);
1172 buf[start..].as_mut_bytes().copy_from_slice(data);
1173 Ok(if buf.len() == buf.capacity() {
1174 gparange::MultiPagedRangeBuf::<Vec<u64>>::validate(self.count as usize, buf)
1175 .map_err(ChannelError::InvalidGpaRange)?;
1176 self.state = GpadlState::Offered;
1177 true
1178 } else {
1179 false
1180 })
1181 } else {
1182 Err(ChannelError::GpadlAlreadyComplete)
1183 }
1184 }
1185}
1186
/// The parameters carried by `Action::Open`.
#[derive(Debug, Copy, Clone)]
pub struct OpenParams {
    /// The open data assembled from the guest's open request.
    pub open_data: OpenData,
    /// The connection ID for the channel; also stored in `open_data`.
    pub connection_id: u32,
    /// The event flag for the channel; also stored in `open_data`.
    pub event_flag: u16,
    /// The monitor ID and latency, if any.
    pub monitor_info: Option<MonitorInfo>,
    /// The flags from the open request with the unused bits cleared.
    pub flags: protocol::OpenChannelFlags,
    /// The reserved channel's message target, if any.
    pub reserved_target: Option<ConnectionTarget>,
    /// The channel ID assigned to the channel.
    pub channel_id: ChannelId,
}
1198
1199impl OpenParams {
1200 fn from_request(
1201 info: &OfferedInfo,
1202 request: &OpenRequest,
1203 monitor_info: Option<MonitorInfo>,
1204 reserved_target: Option<ConnectionTarget>,
1205 ) -> Self {
1206 let (event_flag, connection_id) = if let Some(id) = request.guest_specified_interrupt_info {
1209 (id.event_flag, id.connection_id)
1210 } else {
1211 (info.channel_id.0 as u16, info.connection_id)
1212 };
1213
1214 Self {
1215 open_data: OpenData {
1216 target_vp: request.target_vp,
1217 ring_offset: request.downstream_ring_buffer_page_offset,
1218 ring_gpadl_id: request.ring_buffer_gpadl_id,
1219 user_data: request.user_data,
1220 event_flag,
1221 connection_id,
1222 },
1223 connection_id,
1224 event_flag,
1225 monitor_info,
1226 flags: request.flags.with_unused(0),
1227 reserved_target,
1228 channel_id: info.channel_id,
1229 }
1230 }
1231}
1232
/// An action on a channel, delivered through `Notifier::notify`.
#[derive(Debug)]
pub enum Action {
    /// Open the channel with the given parameters and version information.
    Open(OpenParams, VersionInfo),
    /// Close the channel.
    Close,
    /// A GPADL: its ID, range count and range data.
    Gpadl(GpadlId, u16, Vec<u64>),
    /// Tear down a GPADL.
    TeardownGpadl {
        /// The GPADL to tear down.
        gpadl_id: GpadlId,
        // NOTE(review): presumably set when the teardown is being replayed
        // after a restore; confirm.
        post_restore: bool,
    },
    /// Modify the channel.
    Modify {
        /// The new target VP.
        target_vp: u32,
    },
}
1247
/// The protocol versions this server supports.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::Win7,
    Version::Win8,
    Version::Win8_1,
    Version::Win10,
    Version::Win10Rs3_0,
    Version::Win10Rs3_1,
    Version::Win10Rs4,
    Version::Win10Rs5,
    Version::Iron,
    Version::Copper,
];
1262
/// The feature flags this server supports.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true);
1271
/// The callbacks the server uses to act on channels and to send messages.
///
/// Paired with a [`Server`] through [`Server::with_notifier`].
pub trait Notifier: Send {
    /// Performs `action` on the channel identified by `offer_id`.
    fn notify(&mut self, offer_id: OfferId, action: Action);

    /// Hands over an initiate-contact request that this server does not
    /// handle itself.
    fn forward_unhandled(&mut self, request: InitiateContactRequest);

    /// Applies a change to the connection.
    fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()>;

    /// Adds notifier-side state for a channel to an inspect request.
    ///
    /// The default implementation ignores its arguments and reports nothing.
    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
        let _ = (version, offer_id, req);
    }

    /// Sends `message` to `target`.
    // NOTE(review): the returned flag presumably reports whether the message
    // was sent; confirm against the implementations.
    #[must_use]
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool;

    /// Passes on an hvsocket connect request.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest);

    /// Signals that a reset has completed.
    fn reset_complete(&mut self);

    /// Signals that an unload has completed.
    fn unload_complete(&mut self);
}
1306
1307impl Server {
1308 pub fn new(vtl: Vtl, child_connection_id: u32, channel_id_offset: u16) -> Self {
1310 Server {
1311 state: ConnectionState::Disconnected,
1312 channels: ChannelList::new(),
1313 assigned_channels: AssignedChannels::new(vtl, channel_id_offset),
1314 assigned_monitors: AssignedMonitors::new(),
1315 gpadls: Default::default(),
1316 incomplete_gpadls: Default::default(),
1317 child_connection_id,
1318 max_version: None,
1319 delayed_max_version: None,
1320 pending_messages: PendingMessages(VecDeque::new()),
1321 }
1322 }
1323
1324 pub fn with_notifier<'a, T: Notifier>(
1326 &'a mut self,
1327 notifier: &'a mut T,
1328 ) -> ServerWithNotifier<'a, T> {
1329 self.validate();
1330 ServerWithNotifier {
1331 inner: self,
1332 notifier,
1333 }
1334 }
1335
1336 fn validate(&self) {
1337 #[cfg(debug_assertions)]
1338 for (_, channel) in self.channels.iter() {
1339 let should_have_info = !channel.state.is_released();
1340 if channel.info.is_some() != should_have_info {
1341 panic!("channel invariant violation: {channel:?}");
1342 }
1343 }
1344 }
1345
1346 pub fn set_compatibility_version(&mut self, version: MaxVersionInfo, delay: bool) {
1348 if delay {
1349 self.delayed_max_version = Some(version)
1350 } else {
1351 tracing::info!(?version, "Limiting VmBus connections to version");
1352 self.max_version = Some(version);
1353 }
1354 }
1355
1356 pub fn channel_gpadls(&self, offer_id: OfferId) -> Vec<RestoredGpadl> {
1357 self.gpadls
1358 .iter()
1359 .filter_map(|(&(gpadl_id, gpadl_offer_id), gpadl)| {
1360 if offer_id != gpadl_offer_id {
1361 return None;
1362 }
1363 let accepted = match gpadl.state {
1364 GpadlState::Offered | GpadlState::OfferedTearingDown => false,
1365 GpadlState::Accepted => true,
1366 GpadlState::InProgress | GpadlState::TearingDown => return None,
1367 };
1368 Some(RestoredGpadl {
1369 request: GpadlRequest {
1370 id: gpadl_id,
1371 count: gpadl.count,
1372 buf: gpadl.buf.clone(),
1373 },
1374 accepted,
1375 })
1376 })
1377 .collect()
1378 }
1379
    /// Returns the version information reported by the current connection
    /// state, if any.
    pub fn get_version(&self) -> Option<VersionInfo> {
        self.state.get_version()
    }
1383
    /// Builds the `OpenParams` for a channel that is being restored, from the
    /// open request recorded in its saved state.
    ///
    /// # Errors
    ///
    /// * `MissingChannel` if the channel's restore state is `New`, it has no
    ///   assigned `info`, or its state is `ClientReleased`/`Reoffered`.
    /// * `AlreadyRestored` if the channel's restore state is `Restored`.
    /// * `MismatchedOpenState` if the channel's state is `Closed`.
    ///
    /// # Panics
    ///
    /// Panics if `offer_id` is not a valid index, if the restore state is
    /// `Unmatched`, or if the channel is in one of the `Revoked`,
    /// `ClosingClientRelease` or `OpeningClientRelease` states.
    pub fn get_restore_open_params(&self, offer_id: OfferId) -> Result<OpenParams, RestoreError> {
        let channel = &self.channels[offer_id];

        // Only a channel that is mid-restore has open parameters to hand out.
        match channel.restore_state {
            RestoreState::New => {
                return Err(RestoreError::MissingChannel(channel.offer.key()));
            }
            RestoreState::Restoring => {}
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        // Pick the open request out of whichever state still carries one. The
        // closing states carry no reserved state, so `None` is used for them.
        let (request, reserved_state) = match channel.state {
            ChannelState::Closed => {
                return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
            }
            ChannelState::Closing { params, .. } | ChannelState::ClosingReopen { params, .. } => {
                (params, None)
            }
            ChannelState::Opening {
                request,
                reserved_state,
            } => (request, reserved_state),
            ChannelState::Open {
                params,
                reserved_state,
                ..
            } => (params, reserved_state),
            ChannelState::ClientReleased | ChannelState::Reoffered => {
                return Err(RestoreError::MissingChannel(channel.offer.key()));
            }
            ChannelState::Revoked
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        };

        Ok(OpenParams::from_request(
            &info,
            &request,
            channel.handled_monitor_info(),
            reserved_state.map(|state| state.target),
        ))
    }
1437
1438 pub fn has_pending_messages(&self) -> bool {
1440 !self.pending_messages.0.is_empty() && !self.state.is_paused()
1441 }
1442
1443 pub fn poll_flush_pending_messages(
1445 &mut self,
1446 mut send: impl FnMut(&OutgoingMessage) -> Poll<()>,
1447 ) -> Poll<()> {
1448 if !self.state.is_paused() {
1449 while let Some(message) = self.pending_messages.0.front() {
1450 ready!(send(message));
1451 self.pending_messages.0.pop_front();
1452 }
1453 }
1454
1455 Poll::Ready(())
1456 }
1457}
1458
1459impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
    /// Finishes restoring one channel, reconciling the channel state from the
    /// saved state with whether the device considers the channel `open`.
    ///
    /// On success the channel's restore state becomes `Restored`. The one
    /// exception is a channel whose restore state is `New` with `open ==
    /// false`, which is left untouched.
    ///
    /// # Errors
    ///
    /// * `MissingChannel` if the restore state is `New` while `open` is set,
    ///   if the channel has no assigned `info`, or if its state is
    ///   `ClientReleased` (or `Reoffered` while `open` is set).
    /// * `AlreadyRestored` if the restore state is already `Restored`.
    /// * `DuplicateMonitorId` if the channel's monitor ID cannot be claimed.
    /// * `MismatchedOpenState` if `open` disagrees with the saved state
    ///   (`Closed` while `open`, or `Open` while not `open`).
    ///
    /// # Panics
    ///
    /// Panics if the restore state is `Unmatched`, or if the channel is in one
    /// of the `Revoked`, `ClosingClientRelease` or `OpeningClientRelease`
    /// states.
    pub fn restore_channel(&mut self, offer_id: OfferId, open: bool) -> Result<(), RestoreError> {
        let channel = &mut self.inner.channels[offer_id];

        // Same restore-state gate as `Server::get_restore_open_params`, except
        // that a `New` channel is tolerated when the device does not consider
        // it open.
        match channel.restore_state {
            RestoreState::New => {
                if open {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                } else {
                    return Ok(());
                }
            }
            RestoreState::Restoring => {}
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        // Re-claim the monitor ID the channel was using; fail if it is
        // already claimed.
        if let Some(monitor_info) = channel.handled_monitor_info() {
            if !self
                .inner
                .assigned_monitors
                .claim_monitor(monitor_info.monitor_id)
            {
                return Err(RestoreError::DuplicateMonitorId(monitor_info.monitor_id.0));
            }
        }

        if open {
            // The device reports the channel as open.
            match channel.state {
                ChannelState::Closed => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                    // A close was in flight when the state was saved; ask the
                    // device to close again.
                    self.notifier.notify(offer_id, Action::Close);
                }
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    // The device is already open, so the pending open request
                    // is answered with success and the channel moves to `Open`.
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused())
                        .send_open_result(
                            info.channel_id,
                            &request,
                            protocol::STATUS_SUCCESS,
                            MessageTarget::for_offer(offer_id, &reserved_state),
                        );
                    channel.state = ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    };
                }
                ChannelState::Open { .. } => {}
                ChannelState::ClientReleased | ChannelState::Reoffered => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            };
        } else {
            // The device reports the channel as closed.
            match channel.state {
                ChannelState::Closed => {}
                // Left as-is: the channel stays marked for reoffer.
                ChannelState::Reoffered => {}
                ChannelState::Closing { .. } => {
                    // The device is already closed, so the close is complete.
                    channel.state = ChannelState::Closed;
                }
                ChannelState::ClosingReopen { request, .. } => {
                    // The close half is done; start the queued reopen now.
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                None,
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                    channel.state = ChannelState::Opening {
                        request,
                        reserved_state: None,
                    };
                }
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    // Re-issue the open to the device; the state stays
                    // `Opening` until `open_complete` is called.
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                reserved_state.map(|state| state.target),
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                }
                ChannelState::Open { .. } => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                ChannelState::ClientReleased => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            }
        }

        channel.restore_state = RestoreState::Restored;
        Ok(())
    }
1598
    /// Cleans up after a restore: revokes channels from the saved state that
    /// were not fully restored, offers fresh channels to a guest that has
    /// already requested offers, and re-notifies devices about GPADLs that
    /// were in flight.
    ///
    /// # Panics
    ///
    /// Panics if revoking a `Restoring` or `Unmatched` channel reports that
    /// the channel should not be retained, or if any GPADL is in the
    /// `OfferedTearingDown` state.
    pub fn revoke_unclaimed_channels(&mut self) {
        for (offer_id, channel) in self.inner.channels.iter_mut() {
            match channel.restore_state {
                RestoreState::Restored => {
                    // Fully restored by `restore_channel`; nothing to do.
                }
                RestoreState::New => {
                    // Not part of the saved state. If the guest has already
                    // been sent offers and this channel has not been offered
                    // yet (it is still `ClientReleased`), offer it now.
                    if let ConnectionState::Connected(ConnectionInfo {
                        offers_sent: true,
                        version,
                        ..
                    }) = &self.inner.state
                    {
                        if matches!(channel.state, ChannelState::ClientReleased) {
                            channel.prepare_channel(
                                offer_id,
                                &mut self.inner.assigned_channels,
                                &mut self.inner.assigned_monitors,
                            );
                            channel.state = ChannelState::Closed;
                            self.inner
                                .pending_messages
                                .sender(self.notifier, self.inner.state.is_paused())
                                .send_offer(channel, *version);
                        }
                    }
                }
                RestoreState::Restoring => {
                    // Matched by `offer_channel` but `restore_channel` never
                    // moved it to `Restored`: revoke it and mark it for
                    // reoffer.
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                    channel.state = ChannelState::Reoffered;
                }
                RestoreState::Unmatched => {
                    // In the saved state but never matched by `offer_channel`:
                    // revoke it.
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                }
            }
        }

        // Re-issue device notifications for GPADLs whose create or teardown
        // was still pending.
        for (&(gpadl_id, offer_id), gpadl) in self.inner.gpadls.iter_mut() {
            match gpadl.state {
                GpadlState::InProgress | GpadlState::Accepted => {}
                GpadlState::Offered => {
                    self.notifier.notify(
                        offer_id,
                        Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
                    );
                }
                GpadlState::TearingDown => {
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: true,
                        },
                    );
                }
                GpadlState::OfferedTearingDown => unreachable!(),
            }
        }

        // Revoking may have released the last channel or GPADL a pending
        // disconnect was waiting on.
        self.check_disconnected();
    }
1688
1689 pub fn reset(&mut self) {
1694 assert!(!self.is_resetting());
1695 if self.request_disconnect(ConnectionAction::Reset) {
1696 self.complete_reset();
1697 }
1698 }
1699
1700 fn complete_reset(&mut self) {
1701 for (_, channel) in self.inner.channels.iter_mut() {
1703 channel.restore_state = RestoreState::New;
1704 }
1705 self.inner.pending_messages.0.clear();
1706 self.notifier.reset_complete();
1707 }
1708
    /// Registers a channel offer and returns its offer ID.
    ///
    /// If a channel with the same key already exists, the offer replaces it
    /// only when that channel is an unmatched restored channel or has been
    /// revoked; the existing offer ID is returned in that case. Otherwise a
    /// new channel is created, and it is offered to the guest immediately if
    /// the guest is connected and has already been sent offers.
    ///
    /// # Errors
    ///
    /// * `AlreadyExists` if a channel with the same key exists and is neither
    ///   `Unmatched` nor `Revoked`.
    /// * `MismatchedMonitorId` if a relayed offer specifies a monitor ID that
    ///   differs from the one stored for the matched channel.
    /// * `TooManyChannels` if the allowable channel count is already reached.
    ///
    /// # Panics
    ///
    /// Panics if a matching existing channel has no assigned `info`, or if an
    /// `Unmatched` channel is in the `Revoked` state.
    pub fn offer_channel(&mut self, offer: OfferParamsInternal) -> Result<OfferId, OfferError> {
        if let Some((offer_id, channel)) = self.inner.channels.get_by_key_mut(&offer.key()) {
            // An existing channel may only be replaced if it is waiting to be
            // matched after restore, or has been revoked.
            if channel.restore_state != RestoreState::Unmatched
                && !matches!(channel.state, ChannelState::Revoked)
            {
                return Err(OfferError::AlreadyExists(offer.key()));
            }

            let info = channel.info.expect("assigned");
            if channel.restore_state == RestoreState::Unmatched {
                tracing::debug!(
                    offer_id = offer_id.0,
                    key = %channel.offer.key(),
                    "matched channel"
                );

                assert!(!matches!(channel.state, ChannelState::Revoked));
                // Matched against the saved state; `restore_channel` and
                // `revoke_unclaimed_channels` act on `Restoring` channels.
                channel.restore_state = RestoreState::Restoring;

                // A relayed offer carries its own monitor ID, which must agree
                // with the one already recorded for this channel.
                // NOTE(review): on mismatch this returns after `restore_state`
                // was already changed to `Restoring` — confirm callers treat
                // this error as fatal for the restore.
                if let MnfUsage::Relayed { monitor_id } = offer.use_mnf {
                    if info.monitor_id != Some(MonitorId(monitor_id)) {
                        return Err(OfferError::MismatchedMonitorId(
                            info.monitor_id,
                            MonitorId(monitor_id),
                        ));
                    }
                }
            } else {
                // The existing channel is `Revoked`; remember the new offer by
                // marking the channel for reoffer.
                channel.state = ChannelState::Reoffered;
                tracing::info!(?offer_id, key = %channel.offer.key(), "channel marked for reoffer");
            }

            channel.offer = offer;
            return Ok(offer_id);
        }

        // A brand new channel starts `Closed` (and is offered below) only if
        // offers were already sent to a connected guest. Otherwise it starts
        // `ClientReleased`.
        let mut connected_version = None;
        let state = match self.inner.state {
            ConnectionState::Connected(ConnectionInfo {
                offers_sent: true,
                version,
                ..
            }) => {
                connected_version = Some(version);
                ChannelState::Closed
            }
            ConnectionState::Connected(ConnectionInfo {
                offers_sent: false, ..
            })
            | ConnectionState::Connecting { .. }
            | ConnectionState::Disconnecting { .. }
            | ConnectionState::Disconnected => ChannelState::ClientReleased,
        };

        // Refuse the offer if the channel count has reached the allowable
        // limit.
        if self.inner.channels.len() >= self.inner.assigned_channels.allowable_channel_count() {
            return Err(OfferError::TooManyChannels);
        }

        // Captured before `offer` is moved into the channel, for logging.
        let key = offer.key();
        let confidential_ring_buffer = offer.flags.confidential_ring_buffer();
        let confidential_external_memory = offer.flags.confidential_external_memory();
        let channel = Channel {
            info: None,
            offer,
            state,
            restore_state: RestoreState::New,
        };

        let offer_id = self.inner.channels.offer(channel);
        if let Some(version) = connected_version {
            let channel = &mut self.inner.channels[offer_id];
            channel.prepare_channel(
                offer_id,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
            );

            self.inner
                .pending_messages
                .sender(self.notifier, self.inner.state.is_paused())
                .send_offer(channel, version);
        }

        tracing::info!(?offer_id, %key, confidential_ring_buffer, confidential_external_memory, "new channel");
        Ok(offer_id)
    }
1809
1810 pub fn revoke_channel(&mut self, offer_id: OfferId) {
1812 let channel = &mut self.inner.channels[offer_id];
1813 let retain = revoke(
1814 self.inner
1815 .pending_messages
1816 .sender(self.notifier, self.inner.state.is_paused()),
1817 offer_id,
1818 channel,
1819 &mut self.inner.gpadls,
1820 );
1821 if !retain {
1822 self.inner.channels.remove(offer_id);
1823 }
1824
1825 self.check_disconnected();
1826 }
1827
    /// Handles the device's completion of a channel open request.
    ///
    /// A non-negative `result` is treated as success. For a channel in the
    /// `Opening` state the result is sent to the guest and the channel moves
    /// to `Open` (success) or `Closed` (failure). For a channel the guest
    /// released while it was opening, a successful open is immediately
    /// followed by a close request to the device. In any other state the call
    /// is logged as an error and ignored.
    ///
    /// # Panics
    ///
    /// Panics if an `Opening` channel has no assigned `info`.
    pub fn open_complete(&mut self, offer_id: OfferId, result: i32) {
        tracing::debug!(offer_id = offer_id.0, result, "open complete");

        let channel = &mut self.inner.channels[offer_id];
        match channel.state {
            ChannelState::Opening {
                request,
                reserved_state,
            } => {
                let channel_id = channel.info.expect("assigned").channel_id;
                if result >= 0 {
                    tracelimit::info_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        result,
                        "opened channel"
                    );
                } else {
                    // Failures are logged at error level.
                    tracelimit::error_ratelimited!(
                        offer_id = offer_id.0,
                        channel_id = channel_id.0,
                        result,
                        "failed to open channel"
                    );
                }

                // Report the result to the guest, success or failure.
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused())
                    .send_open_result(
                        channel_id,
                        &request,
                        result,
                        MessageTarget::for_offer(offer_id, &reserved_state),
                    );
                channel.state = if result >= 0 {
                    ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    }
                } else {
                    ChannelState::Closed
                };
            }
            ChannelState::OpeningClientRelease => {
                tracing::info!(
                    offer_id = offer_id.0,
                    result,
                    "opened channel (client released)"
                );

                if result >= 0 {
                    // The device opened a channel the guest no longer holds;
                    // close it again before the release can finish.
                    channel.state = ChannelState::ClosingClientRelease;
                    self.notifier.notify(offer_id, Action::Close);
                } else {
                    // The open failed, so the channel is fully released now.
                    channel.state = ChannelState::ClientReleased;
                    self.check_disconnected();
                }
            }

            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::ClosingClientRelease => {
                tracing::error!(?offer_id, state = ?channel.state, "invalid open complete")
            }
        }
    }
1903
1904 fn are_channels_reset(&self, include_reserved: bool) -> bool {
1907 self.inner.gpadls.keys().all(|(_, offer_id)| {
1908 !include_reserved && self.inner.channels[*offer_id].state.is_reserved()
1909 }) && self.inner.channels.iter().all(|(_, channel)| {
1910 matches!(channel.state, ChannelState::ClientReleased)
1911 || (!include_reserved && channel.state.is_reserved())
1912 })
1913 }
1914
1915 fn check_disconnected(&mut self) {
1919 match self.inner.state {
1920 ConnectionState::Disconnecting {
1921 next_action,
1922 modify_sent: false,
1923 } => {
1924 if self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)) {
1925 self.notify_disconnect(next_action);
1926 }
1927 }
1928 ConnectionState::Disconnecting {
1929 modify_sent: true, ..
1930 }
1931 | ConnectionState::Disconnected
1932 | ConnectionState::Connected { .. }
1933 | ConnectionState::Connecting { .. } => (),
1934 }
1935 }
1936
1937 fn notify_disconnect(&mut self, next_action: ConnectionAction) {
1939 debug_assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
1941 self.inner.state = ConnectionState::Disconnecting {
1942 next_action,
1943 modify_sent: true,
1944 };
1945
1946 self.notifier
1948 .modify_connection(ModifyConnectionRequest {
1949 monitor_page: Update::Reset,
1950 interrupt_page: Update::Reset,
1951 ..Default::default()
1952 })
1953 .expect("resetting state should not fail");
1954 }
1955
1956 fn is_resetting(&self) -> bool {
1959 matches!(
1960 &self.inner.state,
1961 ConnectionState::Connecting {
1962 next_action: ConnectionAction::Reset,
1963 ..
1964 } | ConnectionState::Disconnecting {
1965 next_action: ConnectionAction::Reset,
1966 ..
1967 }
1968 )
1969 }
1970
    /// Handles the device's completion of a channel close request.
    ///
    /// * `Closing` with reserved state: the channel becomes `Closed`. While
    ///   connected, a close-reserved-channel response is sent; otherwise the
    ///   channel is released on the guest's behalf and possibly removed.
    /// * `Closing` without reserved state: the channel becomes `Closed`.
    /// * `ClosingClientRelease`: the channel becomes `ClientReleased` and any
    ///   pending disconnect is re-evaluated.
    /// * `ClosingReopen`: the channel becomes `Closed` and the queued open is
    ///   started.
    ///
    /// In any other state the call is logged as an error and ignored.
    pub fn close_complete(&mut self, offer_id: OfferId) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::info!(offer_id = offer_id.0, "closed channel");
        match channel.state {
            ChannelState::Closing {
                reserved_state: Some(reserved_state),
                ..
            } => {
                channel.state = ChannelState::Closed;
                if matches!(self.inner.state, ConnectionState::Connected { .. }) {
                    let channel_id = channel.info.expect("assigned").channel_id;
                    self.send_close_reserved_channel_response(
                        channel_id,
                        offer_id,
                        reserved_state.target,
                    );
                } else {
                    // Not connected: release the reserved channel on the
                    // guest's behalf, and drop it if the release says so.
                    if Self::client_release_channel(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                        &mut self.inner.assigned_channels,
                        &mut self.inner.assigned_monitors,
                        None,
                    ) {
                        self.inner.channels.remove(offer_id);
                    }
                }
            }
            ChannelState::Closing { .. } => {
                channel.state = ChannelState::Closed;
            }
            ChannelState::ClosingClientRelease => {
                // The guest already released the channel; a pending disconnect
                // may have been waiting for this close.
                channel.state = ChannelState::ClientReleased;
                self.check_disconnected();
            }
            ChannelState::ClosingReopen { request, .. } => {
                // A new open was requested while the close was in flight.
                channel.state = ChannelState::Closed;
                self.open_channel(offer_id, &request, None);
            }

            ChannelState::Closed
            | ChannelState::ClientReleased
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::OpeningClientRelease => {
                tracing::error!(?offer_id, state = ?channel.state, "invalid close complete")
            }
        }
    }
2029
2030 fn send_close_reserved_channel_response(
2031 &mut self,
2032 channel_id: ChannelId,
2033 offer_id: OfferId,
2034 target: ConnectionTarget,
2035 ) {
2036 self.sender().send_message_with_target(
2037 &protocol::CloseReservedChannelResponse { channel_id },
2038 MessageTarget::ReservedChannel(offer_id, target),
2039 );
2040 }
2041
2042 fn handle_initiate_contact(
2045 &mut self,
2046 input: &protocol::InitiateContact2,
2047 message: &SynicMessage,
2048 includes_client_id: bool,
2049 ) -> Result<(), ChannelError> {
2050 let target_info =
2051 protocol::TargetInfo::from(input.initiate_contact.interrupt_page_or_target_info);
2052
2053 let target_sint = if message.multiclient
2054 && input.initiate_contact.version_requested >= Version::Win10Rs3_1 as u32
2055 {
2056 target_info.sint()
2057 } else {
2058 SINT
2059 };
2060
2061 let target_vtl = if message.multiclient
2062 && input.initiate_contact.version_requested >= Version::Win10Rs4 as u32
2063 {
2064 target_info.vtl()
2065 } else {
2066 0
2067 };
2068
2069 let feature_flags = if input.initiate_contact.version_requested >= Version::Copper as u32 {
2070 target_info.feature_flags()
2071 } else {
2072 0
2073 };
2074
2075 let target_message_vp =
2080 if input.initiate_contact.version_requested >= Version::Win8_1 as u32 {
2081 input.initiate_contact.target_message_vp
2082 } else {
2083 0
2084 };
2085
2086 let interrupt_page = (input.initiate_contact.version_requested < Version::Win8 as u32
2093 && input.initiate_contact.interrupt_page_or_target_info != 0)
2094 .then_some(input.initiate_contact.interrupt_page_or_target_info);
2095
2096 let monitor_page = if (input.initiate_contact.parent_to_child_monitor_page_gpa == 0)
2099 != (input.initiate_contact.child_to_parent_monitor_page_gpa == 0)
2100 {
2101 MonitorPageRequest::Invalid
2102 } else if input.initiate_contact.parent_to_child_monitor_page_gpa != 0 {
2103 MonitorPageRequest::Some(MonitorPageGpas {
2104 parent_to_child: input.initiate_contact.parent_to_child_monitor_page_gpa,
2105 child_to_parent: input.initiate_contact.child_to_parent_monitor_page_gpa,
2106 })
2107 } else {
2108 MonitorPageRequest::None
2109 };
2110
2111 let client_id = if FeatureFlags::from(feature_flags).client_id() {
2114 if includes_client_id {
2115 input.client_id
2116 } else {
2117 return Err(ChannelError::ParseError(
2118 protocol::ParseError::MessageTooSmall(Some(
2119 protocol::MessageType::INITIATE_CONTACT,
2120 )),
2121 ));
2122 }
2123 } else {
2124 Guid::ZERO
2125 };
2126
2127 let request = InitiateContactRequest {
2128 version_requested: input.initiate_contact.version_requested,
2129 target_message_vp,
2130 monitor_page,
2131 target_sint,
2132 target_vtl,
2133 feature_flags,
2134 interrupt_page,
2135 client_id,
2136 trusted: message.trusted,
2137 };
2138 self.initiate_contact(request);
2139 Ok(())
2140 }
2141
    /// Processes an initiate contact request from the guest.
    ///
    /// Steps, in order:
    ///
    /// 1. A request for a different VTL is forwarded to the notifier.
    /// 2. A request for an unsupported SINT is rejected with a version
    ///    response sent to the VP and SINT named in the request.
    /// 3. Any existing connection is disconnected first. If that cannot
    ///    complete synchronously, the request is queued as the next action
    ///    and handled again later.
    /// 4. The version is checked and the monitor page request is validated.
    /// 5. The server enters `Connecting` and asks the notifier to apply the
    ///    connection change.
    pub fn initiate_contact(&mut self, request: InitiateContactRequest) {
        let vtl = self.inner.assigned_channels.vtl as u8;
        // Not for this server's VTL; let the notifier deal with it.
        if request.target_vtl != vtl {
            self.notifier.forward_unhandled(request);
            return;
        }

        if request.target_sint != SINT {
            tracelimit::warn_ratelimited!(
                "unsupported multiclient request for VTL {} SINT {}, version {:#x}",
                request.target_vtl,
                request.target_sint,
                request.version_requested,
            );

            // Reply with "unsupported" on the VP/SINT the request asked for.
            self.send_version_response_with_target(
                None,
                MessageTarget::Custom(ConnectionTarget {
                    vp: request.target_message_vp,
                    sint: request.target_sint,
                }),
            );

            return;
        }

        // Tear down any existing connection first. If that is still pending,
        // the request has been stored as the next action, and
        // `do_next_action` calls this function again afterwards.
        if !self.request_disconnect(ConnectionAction::Reconnect {
            initiate_contact: request,
        }) {
            return;
        }

        let Some(version) = self.check_version_supported(&request) else {
            tracelimit::warn_ratelimited!(
                vtl,
                version = request.version_requested,
                client_id = ?request.client_id,
                "Guest requested unsupported version"
            );

            // Rejected locally; `modify_connection` is not called.
            self.send_version_response(None);
            return;
        };

        tracelimit::info_ratelimited!(
            vtl,
            ?version,
            client_id = ?request.client_id,
            trusted = request.trusted,
            "Guest negotiated version"
        );

        let monitor_page = match request.monitor_page {
            MonitorPageRequest::Some(mp) => Some(mp),
            MonitorPageRequest::None => None,
            MonitorPageRequest::Invalid => {
                // Only one of the two monitor pages was given: report failure
                // for a version that is otherwise supported.
                self.send_version_response(Some((
                    version,
                    protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
                )));

                return;
            }
        };

        self.inner.state = ConnectionState::Connecting {
            info: ConnectionInfo {
                version,
                trusted: request.trusted,
                interrupt_page: request.interrupt_page,
                monitor_page,
                target_message_vp: request.target_message_vp,
                modifying: false,
                offers_sent: false,
                client_id: request.client_id,
                paused: false,
            },
            next_action: ConnectionAction::None,
        };

        // Ask the notifier to apply the new connection parameters. On
        // success, the version response is sent from
        // `complete_initiate_contact`.
        if let Err(err) = self.notifier.modify_connection(ModifyConnectionRequest {
            version: Some(request.version_requested),
            monitor_page: monitor_page.into(),
            interrupt_page: request.interrupt_page.into(),
            target_message_vp: Some(request.target_message_vp),
            notify_relay: true,
        }) {
            tracelimit::error_ratelimited!(?err, "server failed to change state");
            self.inner.state = ConnectionState::Disconnected;
            self.send_version_response(Some((
                version,
                protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
            )));
        }
    }
2247
    /// Completes a connection attempt using the response to the connection
    /// change requested by `initiate_contact`.
    ///
    /// On success the server enters `Connected`, the version response is sent
    /// and any action queued while connecting is started. On failure, or if
    /// the version is unsupported, the server returns to `Disconnected` and
    /// the guest is told so.
    ///
    /// # Panics
    ///
    /// Panics if the server is not in the `Connecting` state.
    pub(crate) fn complete_initiate_contact(&mut self, response: ModifyConnectionResponse) {
        let ConnectionState::Connecting {
            mut info,
            next_action,
        } = self.inner.state
        else {
            panic!("Invalid state for completing InitiateContact.");
        };

        // Flags that stay enabled whether or not the response reports them.
        const LOCAL_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
            .with_client_id(true)
            .with_confidential_channels(true);

        let relay_feature_flags = match response {
            // The request was processed successfully.
            ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                feature_flags,
            ) => feature_flags,
            // The version is supported but the request failed; pass the
            // failure state on to the guest.
            ModifyConnectionResponse::Supported(connection_state, feature_flags) => {
                tracelimit::error_ratelimited!(
                    ?connection_state,
                    "initiate contact failed because relay request failed"
                );

                // The failure response still reports feature flags, so limit
                // them the same way as on success.
                info.version.feature_flags &= feature_flags | LOCAL_FEATURE_FLAGS;

                self.send_version_response(Some((info.version, connection_state)));
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
            // The version is not supported; the guest gets an "unsupported"
            // version response.
            ModifyConnectionResponse::Unsupported => {
                self.send_version_response(None);
                self.inner.state = ConnectionState::Disconnected;
                return;
            }
        };

        // Report only flags that are in the response or handled locally.
        info.version.feature_flags &= relay_feature_flags | LOCAL_FEATURE_FLAGS;
        self.inner.state = ConnectionState::Connected(info);

        self.send_version_response(Some((info.version, protocol::ConnectionState::SUCCESSFUL)));
        // Start an action that was queued while connecting. It runs now only
        // if the disconnect it requires completes synchronously.
        if !matches!(next_action, ConnectionAction::None) && self.request_disconnect(next_action) {
            self.do_next_action(next_action);
        }
    }
2303
2304 fn check_version_supported(&self, request: &InitiateContactRequest) -> Option<VersionInfo> {
2306 let version = SUPPORTED_VERSIONS
2307 .iter()
2308 .find(|v| request.version_requested == **v as u32)
2309 .copied()?;
2310
2311 if let Some(max_version) = self.inner.max_version {
2313 if version as u32 > max_version.version {
2314 return None;
2315 }
2316 }
2317
2318 let supported_flags = if version >= Version::Copper {
2319 let max_supported_flags =
2321 SUPPORTED_FEATURE_FLAGS.with_confidential_channels(request.trusted);
2322
2323 if let Some(max_version) = self.inner.max_version {
2325 max_supported_flags & max_version.feature_flags
2326 } else {
2327 max_supported_flags
2328 }
2329 } else {
2330 FeatureFlags::new()
2331 };
2332
2333 let feature_flags = supported_flags & request.feature_flags.into();
2334
2335 assert!(version >= Version::Copper || feature_flags == FeatureFlags::new());
2336 if feature_flags.into_bits() != request.feature_flags {
2337 tracelimit::warn_ratelimited!(
2338 supported = feature_flags.into_bits(),
2339 requested = request.feature_flags,
2340 "Guest requested unsupported feature flags."
2341 );
2342 }
2343
2344 Some(VersionInfo {
2345 version,
2346 feature_flags,
2347 })
2348 }
2349
    /// Sends a version response to the default message target. `None` means
    /// the requested version is not supported.
    fn send_version_response(&mut self, data: Option<(VersionInfo, protocol::ConnectionState)>) {
        self.send_version_response_with_target(data, MessageTarget::Default);
    }
2353
    /// Builds and sends a version response to `target`.
    ///
    /// With `data == None` an all-zero response ("version not supported") is
    /// sent. The same happens for a failure state on a version older than
    /// `Win8`. Otherwise the response reports the version as supported,
    /// together with the connection state. For `Copper` and newer the
    /// extended response carrying the supported feature flags is sent.
    fn send_version_response_with_target(
        &mut self,
        data: Option<(VersionInfo, protocol::ConnectionState)>,
        target: MessageTarget,
    ) {
        // Start from the larger, zeroed structure. `response` refers to the
        // basic response embedded in it.
        let mut response2 = protocol::VersionResponse2::new_zeroed();
        let response = &mut response2.version_response;
        let mut send_response2 = false;
        if let Some((version, state)) = data {
            // A non-successful state is only reported for `Win8` and newer;
            // for older versions the response stays zeroed.
            if state == protocol::ConnectionState::SUCCESSFUL || version.version >= Version::Win8 {
                response.version_supported = 1;
                response.connection_state = state;
                // This field carries the child connection ID from
                // `Win10Rs3_1` on, and the selected version before that.
                response.selected_version_or_connection_id =
                    if version.version >= Version::Win10Rs3_1 {
                        self.inner.child_connection_id
                    } else {
                        version.version as u32
                    };

                if version.version >= Version::Copper {
                    response2.supported_features = version.feature_flags.into();
                    send_response2 = true;
                }
            }
        }

        if send_response2 {
            self.sender().send_message_with_target(&response2, target);
        } else {
            self.sender().send_message_with_target(response, target);
        }
    }
2388
    /// Begins disconnecting the guest, with `new_action` to be performed once
    /// the disconnect is complete.
    ///
    /// Returns true if the server is in the `Disconnected` state when this
    /// returns, in which case the caller performs the action itself. Returns
    /// false if the disconnect is still pending; the action is then stored as
    /// the connection's `next_action`.
    ///
    /// # Panics
    ///
    /// Panics if a reset is already in progress, or if the server is
    /// `Disconnected`, the action is not a reset and some non-reserved
    /// channel or GPADL is still in use.
    fn request_disconnect(&mut self, new_action: ConnectionAction) -> bool {
        assert!(!self.is_resetting());

        // Release every channel on the guest's behalf, dropping those for
        // which the release reports removal.
        let gpadls = &mut self.inner.gpadls;
        let vm_reset = matches!(new_action, ConnectionAction::Reset);
        self.inner.channels.retain(|offer_id, channel| {
            // Reserved channels are kept as they are unless the VM is
            // resetting.
            (!vm_reset && channel.state.is_reserved())
                || !Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    None,
                )
        });

        // Choose the next connection state, depending on whether anything is
        // still in use.
        match &mut self.inner.state {
            ConnectionState::Disconnected => {
                if vm_reset {
                    // A reset while disconnected may still have to wait, since
                    // reserved channels count for a reset.
                    if !self.are_channels_reset(true) {
                        self.inner.state = ConnectionState::Disconnecting {
                            next_action: ConnectionAction::Reset,
                            modify_sent: false,
                        };
                    }
                } else {
                    assert!(self.are_channels_reset(false));
                }
            }

            ConnectionState::Connected { .. } => {
                if self.are_channels_reset(vm_reset) {
                    // Nothing to wait for; issue the modify request now.
                    self.notify_disconnect(new_action);
                } else {
                    // `check_disconnected` picks this up once the remaining
                    // channels and GPADLs are released.
                    self.inner.state = ConnectionState::Disconnecting {
                        next_action: new_action,
                        modify_sent: false,
                    };
                }
            }

            // A connect or disconnect is already in flight: only replace the
            // action to perform when it finishes.
            ConnectionState::Connecting { next_action, .. }
            | ConnectionState::Disconnecting { next_action, .. } => {
                *next_action = new_action;
            }
        }

        matches!(self.inner.state, ConnectionState::Disconnected)
    }
2450
2451 pub(crate) fn complete_disconnect(&mut self) {
2452 if let ConnectionState::Disconnecting {
2453 next_action,
2454 modify_sent,
2455 } = std::mem::replace(&mut self.inner.state, ConnectionState::Disconnected)
2456 {
2457 assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
2458 if !modify_sent {
2459 tracelimit::warn_ratelimited!("unexpected modify response");
2460 }
2461
2462 self.inner.state = ConnectionState::Disconnected;
2463 self.do_next_action(next_action);
2464 } else {
2465 unreachable!("not ready for disconnect");
2466 }
2467 }
2468
2469 fn do_next_action(&mut self, action: ConnectionAction) {
2470 match action {
2471 ConnectionAction::None => {}
2472 ConnectionAction::Reset => {
2473 self.complete_reset();
2474 }
2475 ConnectionAction::SendUnloadComplete => {
2476 self.complete_unload();
2477 }
2478 ConnectionAction::Reconnect { initiate_contact } => {
2479 self.initiate_contact(initiate_contact);
2480 }
2481 ConnectionAction::SendFailedVersionResponse => {
2482 self.send_version_response(None);
2485 }
2486 }
2487 }
2488
2489 fn handle_unload(&mut self) {
2491 tracing::debug!(
2492 vtl = self.inner.assigned_channels.vtl as u8,
2493 state = ?self.inner.state,
2494 "VmBus received unload request from guest",
2495 );
2496
2497 if self.request_disconnect(ConnectionAction::SendUnloadComplete) {
2498 self.complete_unload();
2499 }
2500 }
2501
2502 fn complete_unload(&mut self) {
2503 self.notifier.unload_complete();
2504 if let Some(version) = self.inner.delayed_max_version.take() {
2505 self.inner.set_compatibility_version(version, false);
2506 }
2507
2508 self.sender().send_message(&protocol::UnloadComplete {});
2509 tracelimit::info_ratelimited!("Vmbus disconnected");
2510 }
2511
    /// Handles the guest's request for channel offers: offers every
    /// non-reserved channel, in a deterministic order, then sends
    /// `AllOffersDelivered`.
    ///
    /// # Errors
    ///
    /// Returns `OffersAlreadySent` if offers were already sent on this
    /// connection.
    ///
    /// # Panics
    ///
    /// Panics if the server is not connected, or if a non-reserved channel is
    /// not `ClientReleased` or already has `info` assigned.
    fn handle_request_offers(&mut self) -> Result<(), ChannelError> {
        let ConnectionState::Connected(info) = &mut self.inner.state else {
            unreachable!(
                "in unexpected state {:?}, should be prevented by Message::parse()",
                self.inner.state
            );
        };

        if info.offers_sent {
            return Err(ChannelError::OffersAlreadySent);
        }

        info.offers_sent = true;

        // Sort before preparing the channels, so they are prepared and
        // offered in the same order for the same set of offers.
        // NOTE(review): presumably this keeps channel IDs stable across
        // connections — confirm against `prepare_channel`.
        let mut sorted_channels: Vec<_> = self
            .inner
            .channels
            .iter_mut()
            .filter(|(_, channel)| !channel.state.is_reserved())
            .collect();

        // Within an interface, an explicit offer order comes first; offers
        // without one sort last, ordered by instance ID.
        sorted_channels.sort_unstable_by_key(|(_, channel)| {
            (
                channel.offer.interface_id,
                channel.offer.offer_order.unwrap_or(u32::MAX),
                channel.offer.instance_id,
            )
        });

        for (offer_id, channel) in sorted_channels {
            assert!(matches!(channel.state, ChannelState::ClientReleased));
            assert!(channel.info.is_none());

            channel.prepare_channel(
                offer_id,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
            );

            channel.state = ChannelState::Closed;
            self.inner
                .pending_messages
                .sender(self.notifier, info.paused)
                .send_offer(channel, info.version);
        }
        self.sender().send_message(&protocol::AllOffersDelivered {});

        Ok(())
    }
2564
2565 #[must_use]
2568 fn gpadl_updated(
2569 mut sender: MessageSender<'_, N>,
2570 offer_id: OfferId,
2571 channel: &Channel,
2572 gpadl_id: GpadlId,
2573 gpadl: &Gpadl,
2574 ) -> bool {
2575 if channel.state.is_revoked() {
2576 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
2577 sender.send_gpadl_created(channel_id, gpadl_id, protocol::STATUS_UNSUCCESSFUL);
2578 false
2579 } else {
2580 sender.notifier.notify(
2582 offer_id,
2583 Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
2584 );
2585 true
2586 }
2587 }
2588
2589 fn handle_gpadl_header_core(
2591 &mut self,
2592 input: &protocol::GpadlHeader,
2593 range: &[u8],
2594 ) -> Result<(), ChannelError> {
2595 let (offer_id, channel) = self
2597 .inner
2598 .channels
2599 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
2600
2601 if channel.state.is_reserved() {
2604 return Err(ChannelError::ChannelReserved);
2605 }
2606
2607 let mut gpadl = Gpadl::new(input.count, input.len as usize / 8);
2609 let done = gpadl.append(range)?;
2610
2611 let gpadl = match self.inner.gpadls.entry((input.gpadl_id, offer_id)) {
2613 Entry::Vacant(entry) => entry.insert(gpadl),
2614 Entry::Occupied(_) => return Err(ChannelError::DuplicateGpadlId),
2615 };
2616
2617 if !done
2619 && self
2620 .inner
2621 .incomplete_gpadls
2622 .insert(input.gpadl_id, offer_id)
2623 .is_some()
2624 {
2625 unreachable!("gpadl ID validated above");
2626 }
2627
2628 if done
2629 && !Self::gpadl_updated(
2630 self.inner
2631 .pending_messages
2632 .sender(self.notifier, self.inner.state.is_paused()),
2633 offer_id,
2634 channel,
2635 input.gpadl_id,
2636 gpadl,
2637 )
2638 {
2639 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2640 }
2641 Ok(())
2642 }
2643
2644 fn handle_gpadl_header(&mut self, input: &protocol::GpadlHeader, range: &[u8]) {
2646 if let Err(err) = self.handle_gpadl_header_core(input, range) {
2647 tracelimit::warn_ratelimited!(
2648 err = &err as &dyn std::error::Error,
2649 channel_id = ?input.channel_id,
2650 gpadl_id = ?input.gpadl_id,
2651 "error handling gpadl header"
2652 );
2653
2654 self.sender().send_gpadl_created(
2656 input.channel_id,
2657 input.gpadl_id,
2658 protocol::STATUS_UNSUCCESSFUL,
2659 );
2660 }
2661 }
2662
2663 fn handle_gpadl_body(
2669 &mut self,
2670 input: &protocol::GpadlBody,
2671 range: &[u8],
2672 ) -> Result<(), ChannelError> {
2673 let &offer_id = self
2677 .inner
2678 .incomplete_gpadls
2679 .get(&input.gpadl_id)
2680 .ok_or(ChannelError::UnknownGpadlId)?;
2681 let gpadl = self
2682 .inner
2683 .gpadls
2684 .get_mut(&(input.gpadl_id, offer_id))
2685 .ok_or(ChannelError::UnknownGpadlId)?;
2686 let channel = &mut self.inner.channels[offer_id];
2687
2688 match gpadl.append(range) {
2689 Ok(done) => {
2690 if done {
2691 self.inner.incomplete_gpadls.remove(&input.gpadl_id);
2692 if !Self::gpadl_updated(
2693 self.inner
2694 .pending_messages
2695 .sender(self.notifier, self.inner.state.is_paused()),
2696 offer_id,
2697 channel,
2698 input.gpadl_id,
2699 gpadl,
2700 ) {
2701 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2702 }
2703 }
2704 }
2705 Err(err) => {
2706 self.inner.incomplete_gpadls.remove(&input.gpadl_id);
2707 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2708 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
2709 tracelimit::warn_ratelimited!(
2710 err = &err as &dyn std::error::Error,
2711 channel_id = channel_id.0,
2712 gpadl_id = input.gpadl_id.0,
2713 "error handling gpadl body"
2714 );
2715 self.sender().send_gpadl_created(
2716 channel_id,
2717 input.gpadl_id,
2718 protocol::STATUS_UNSUCCESSFUL,
2719 );
2720 }
2721 }
2722
2723 Ok(())
2724 }
2725
2726 fn handle_gpadl_teardown(
2728 &mut self,
2729 input: &protocol::GpadlTeardown,
2730 ) -> Result<(), ChannelError> {
2731 tracing::debug!(
2732 channel_id = input.channel_id.0,
2733 gpadl_id = input.gpadl_id.0,
2734 "Received GPADL teardown request"
2735 );
2736
2737 let (offer_id, channel) = self
2738 .inner
2739 .channels
2740 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
2741
2742 let gpadl = self
2743 .inner
2744 .gpadls
2745 .get_mut(&(input.gpadl_id, offer_id))
2746 .ok_or(ChannelError::UnknownGpadlId)?;
2747
2748 match gpadl.state {
2749 GpadlState::InProgress
2750 | GpadlState::Offered
2751 | GpadlState::OfferedTearingDown
2752 | GpadlState::TearingDown => {
2753 return Err(ChannelError::InvalidGpadlState);
2754 }
2755 GpadlState::Accepted => {
2756 if channel.info.as_ref().map(|info| info.channel_id) != Some(input.channel_id) {
2757 return Err(ChannelError::WrongGpadlChannelId);
2758 }
2759
2760 if channel.state.is_reserved() {
2764 return Err(ChannelError::ChannelReserved);
2765 }
2766
2767 if channel.state.is_revoked() {
2768 tracing::trace!(
2769 channel_id = input.channel_id.0,
2770 gpadl_id = input.gpadl_id.0,
2771 "Gpadl teardown for revoked channel"
2772 );
2773
2774 self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2775 self.sender().send_gpadl_torndown(input.gpadl_id);
2776 } else {
2777 gpadl.state = GpadlState::TearingDown;
2778 self.notifier.notify(
2779 offer_id,
2780 Action::TeardownGpadl {
2781 gpadl_id: input.gpadl_id,
2782 post_restore: false,
2783 },
2784 );
2785 }
2786 }
2787 }
2788 Ok(())
2789 }
2790
2791 fn open_channel(
2794 &mut self,
2795 offer_id: OfferId,
2796 input: &OpenRequest,
2797 reserved_state: Option<ReservedState>,
2798 ) {
2799 let channel = &mut self.inner.channels[offer_id];
2800 assert!(matches!(channel.state, ChannelState::Closed));
2801
2802 channel.state = ChannelState::Opening {
2803 request: *input,
2804 reserved_state,
2805 };
2806
2807 let info = channel.info.as_ref().expect("assigned");
2810 self.notifier.notify(
2811 offer_id,
2812 Action::Open(
2813 OpenParams::from_request(
2814 info,
2815 input,
2816 channel.handled_monitor_info(),
2817 reserved_state.map(|state| state.target),
2818 ),
2819 self.inner.state.get_version().expect("must be connected"),
2820 ),
2821 );
2822 }
2823
2824 fn handle_open_channel(&mut self, input: &protocol::OpenChannel2) -> Result<(), ChannelError> {
2826 let (offer_id, channel) = self
2827 .inner
2828 .channels
2829 .get_by_channel_id_mut(&self.inner.assigned_channels, input.open_channel.channel_id)?;
2830
2831 let guest_specified_interrupt_info = self
2832 .inner
2833 .state
2834 .check_feature_flags(|ff| ff.guest_specified_signal_parameters())
2835 .then_some(SignalInfo {
2836 event_flag: input.event_flag,
2837 connection_id: input.connection_id,
2838 });
2839
2840 let flags = if self
2841 .inner
2842 .state
2843 .check_feature_flags(|ff| ff.channel_interrupt_redirection())
2844 {
2845 input.flags
2846 } else {
2847 Default::default()
2848 };
2849
2850 let request = OpenRequest {
2851 open_id: input.open_channel.open_id,
2852 ring_buffer_gpadl_id: input.open_channel.ring_buffer_gpadl_id,
2853 target_vp: input.open_channel.target_vp,
2854 downstream_ring_buffer_page_offset: input
2855 .open_channel
2856 .downstream_ring_buffer_page_offset,
2857 user_data: input.open_channel.user_data,
2858 guest_specified_interrupt_info,
2859 flags,
2860 };
2861
2862 match channel.state {
2863 ChannelState::Closed => self.open_channel(offer_id, &request, None),
2864 ChannelState::Closing { params, .. } => {
2865 channel.state = ChannelState::ClosingReopen { params, request }
2869 }
2870 ChannelState::Revoked | ChannelState::Reoffered => {}
2871
2872 ChannelState::Open { .. }
2873 | ChannelState::Opening { .. }
2874 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelAlreadyOpen),
2875
2876 ChannelState::ClientReleased
2877 | ChannelState::ClosingClientRelease
2878 | ChannelState::OpeningClientRelease => unreachable!(),
2879 }
2880 Ok(())
2881 }
2882
2883 fn handle_close_channel(&mut self, input: &protocol::CloseChannel) -> Result<(), ChannelError> {
2885 let (offer_id, channel) = self
2886 .inner
2887 .channels
2888 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
2889
2890 match channel.state {
2891 ChannelState::Open {
2892 params,
2893 modify_state,
2894 reserved_state: None,
2895 } => {
2896 if modify_state.is_modifying() {
2897 tracelimit::warn_ratelimited!(
2898 ?modify_state,
2899 "Client is closing the channel with a modify in progress"
2900 )
2901 }
2902
2903 channel.state = ChannelState::Closing {
2904 params,
2905 reserved_state: None,
2906 };
2907 self.notifier.notify(offer_id, Action::Close);
2908 }
2909
2910 ChannelState::Open {
2911 reserved_state: Some(_),
2912 ..
2913 } => return Err(ChannelError::ChannelReserved),
2914
2915 ChannelState::Revoked | ChannelState::Reoffered => {}
2916
2917 ChannelState::Closed
2918 | ChannelState::Opening { .. }
2919 | ChannelState::Closing { .. }
2920 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),
2921
2922 ChannelState::ClientReleased
2923 | ChannelState::ClosingClientRelease
2924 | ChannelState::OpeningClientRelease => unreachable!(),
2925 }
2926
2927 Ok(())
2928 }
2929
2930 fn handle_open_reserved_channel(
2933 &mut self,
2934 input: &protocol::OpenReservedChannel,
2935 version: VersionInfo,
2936 ) -> Result<(), ChannelError> {
2937 let (offer_id, channel) = self
2938 .inner
2939 .channels
2940 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
2941
2942 let target = ConnectionTarget {
2943 vp: input.target_vp,
2944 sint: input.target_sint as u8,
2945 };
2946
2947 let reserved_state = Some(ReservedState { version, target });
2948
2949 let request = OpenRequest {
2950 ring_buffer_gpadl_id: input.ring_buffer_gpadl,
2951 target_vp: protocol::VP_INDEX_DISABLE_INTERRUPT,
2953 downstream_ring_buffer_page_offset: input.downstream_page_offset,
2954 open_id: 0,
2955 user_data: UserDefinedData::new_zeroed(),
2956 guest_specified_interrupt_info: None,
2957 flags: Default::default(),
2958 };
2959
2960 match channel.state {
2961 ChannelState::Closed => self.open_channel(offer_id, &request, reserved_state),
2962 ChannelState::Revoked | ChannelState::Reoffered => {}
2963
2964 ChannelState::Open { .. } | ChannelState::Opening { .. } => {
2965 return Err(ChannelError::ChannelAlreadyOpen);
2966 }
2967
2968 ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
2969 return Err(ChannelError::InvalidChannelState);
2970 }
2971
2972 ChannelState::ClientReleased
2973 | ChannelState::ClosingClientRelease
2974 | ChannelState::OpeningClientRelease => unreachable!(),
2975 }
2976 Ok(())
2977 }
2978
2979 fn handle_close_reserved_channel(
2982 &mut self,
2983 input: &protocol::CloseReservedChannel,
2984 ) -> Result<(), ChannelError> {
2985 let (offer_id, channel) = self
2986 .inner
2987 .channels
2988 .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
2989
2990 match channel.state {
2991 ChannelState::Open {
2992 params,
2993 reserved_state: Some(mut resvd),
2994 ..
2995 } => {
2996 resvd.target.vp = input.target_vp;
2997 resvd.target.sint = input.target_sint as u8;
2998 channel.state = ChannelState::Closing {
2999 params,
3000 reserved_state: Some(resvd),
3001 };
3002 self.notifier.notify(offer_id, Action::Close);
3003 }
3004
3005 ChannelState::Open {
3006 reserved_state: None,
3007 ..
3008 } => return Err(ChannelError::ChannelNotReserved),
3009
3010 ChannelState::Revoked | ChannelState::Reoffered => {}
3011
3012 ChannelState::Closed
3013 | ChannelState::Opening { .. }
3014 | ChannelState::Closing { .. }
3015 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),
3016
3017 ChannelState::ClientReleased
3018 | ChannelState::ClosingClientRelease
3019 | ChannelState::OpeningClientRelease => unreachable!(),
3020 }
3021
3022 Ok(())
3023 }
3024
    #[must_use]
    /// Releases a channel on behalf of the client, cleaning up its GPADLs and
    /// advancing the channel state to one of the released states.
    ///
    /// Returns `true` only when the channel was `Revoked`; the caller visible in
    /// this file (`handle_rel_id_released`) then removes the channel from the
    /// channel list. A `Reoffered` channel with a known `version` is instead
    /// offered to the guest again and stays assigned.
    ///
    /// # Panics
    ///
    /// Panics if the channel is not in a released state after the transition.
    fn client_release_channel(
        mut sender: MessageSender<'_, N>,
        offer_id: OfferId,
        channel: &mut Channel,
        gpadls: &mut GpadlMap,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
        version: Option<VersionInfo>,
    ) -> bool {
        // Walk every GPADL and decide whether the ones owned by this channel
        // stay in the table.
        gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
            // GPADLs of other channels are untouched.
            if gpadl_offer_id != offer_id {
                return true;
            }
            match gpadl.state {
                // Still being assembled: simply dropped.
                GpadlState::InProgress => false,
                // Creation is still outstanding; mark it so the teardown is
                // requested when `gpadl_create_complete` reports success.
                GpadlState::Offered => {
                    gpadl.state = GpadlState::OfferedTearingDown;
                    true
                }
                GpadlState::Accepted => {
                    if channel.state.is_revoked() {
                        // No teardown is requested for a revoked channel; the
                        // entry is dropped directly.
                        false
                    } else {
                        gpadl.state = GpadlState::TearingDown;
                        sender.notifier.notify(
                            offer_id,
                            Action::TeardownGpadl {
                                gpadl_id,
                                post_restore: false,
                            },
                        );
                        true
                    }
                }
                // A teardown is already pending; keep the entry until it
                // completes.
                GpadlState::OfferedTearingDown | GpadlState::TearingDown => true,
            }
        });

        let remove = match &mut channel.state {
            ChannelState::Closed => {
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Reoffered => {
                if let Some(version) = version {
                    // Offer the channel to the guest again. The early return
                    // skips `release_channel` below, so the channel keeps its
                    // assigned info for the new offer.
                    channel.state = ChannelState::Closed;
                    channel.restore_state = RestoreState::New;
                    sender.send_offer(channel, version);
                    return false;
                }
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Revoked => {
                channel.state = ChannelState::ClientReleased;
                true
            }
            ChannelState::Opening { .. } => {
                channel.state = ChannelState::OpeningClientRelease;
                false
            }
            ChannelState::Open { .. } => {
                // An open channel also needs a close request.
                channel.state = ChannelState::ClosingClientRelease;
                sender.notifier.notify(offer_id, Action::Close);
                false
            }
            ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                false
            }

            // Already released; the state is left as is.
            ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease
            | ChannelState::ClientReleased => false,
        };

        assert!(channel.state.is_released());

        channel.release_channel(offer_id, assigned_channels, assigned_monitors);
        remove
    }
3113
3114 fn handle_rel_id_released(
3116 &mut self,
3117 input: &protocol::RelIdReleased,
3118 ) -> Result<(), ChannelError> {
3119 let channel_id = input.channel_id;
3120 let (offer_id, channel) = self
3121 .inner
3122 .channels
3123 .get_by_channel_id_mut(&self.inner.assigned_channels, channel_id)?;
3124
3125 match channel.state {
3126 ChannelState::Closed
3127 | ChannelState::Revoked
3128 | ChannelState::Closing { .. }
3129 | ChannelState::Reoffered => {
3130 if Self::client_release_channel(
3131 self.inner
3132 .pending_messages
3133 .sender(self.notifier, self.inner.state.is_paused()),
3134 offer_id,
3135 channel,
3136 &mut self.inner.gpadls,
3137 &mut self.inner.assigned_channels,
3138 &mut self.inner.assigned_monitors,
3139 self.inner.state.get_version(),
3140 ) {
3141 self.inner.channels.remove(offer_id);
3142 }
3143
3144 self.check_disconnected();
3145 }
3146
3147 ChannelState::Opening { .. }
3148 | ChannelState::Open { .. }
3149 | ChannelState::ClosingReopen { .. } => return Err(ChannelError::InvalidChannelState),
3150
3151 ChannelState::ClientReleased
3152 | ChannelState::OpeningClientRelease
3153 | ChannelState::ClosingClientRelease => unreachable!(),
3154 }
3155 Ok(())
3156 }
3157
3158 fn handle_tl_connect_request(&mut self, request: protocol::TlConnectRequest2) {
3161 let version = self
3162 .inner
3163 .state
3164 .get_version()
3165 .expect("must be connected")
3166 .version;
3167
3168 let hosted_silo_unaware = version < Version::Win10Rs5;
3169 self.notifier
3170 .notify_hvsock(&HvsockConnectRequest::from_message(
3171 request,
3172 hosted_silo_unaware,
3173 ));
3174 }
3175
3176 pub fn send_tl_connect_result(&mut self, result: HvsockConnectResult) {
3178 if !result.success && self.inner.state.check_version(Version::Win10Rs3_0) {
3182 self.sender().send_message(&protocol::TlConnectResult {
3186 service_id: result.service_id,
3187 endpoint_id: result.endpoint_id,
3188 status: protocol::STATUS_CONNECTION_REFUSED,
3189 })
3190 }
3191 }
3192
3193 fn handle_modify_channel(
3196 &mut self,
3197 request: &protocol::ModifyChannel,
3198 ) -> Result<(), ChannelError> {
3199 let result = self.modify_channel(request);
3200 if result.is_err() {
3201 self.send_modify_channel_response(request.channel_id, protocol::STATUS_UNSUCCESSFUL);
3202 }
3203
3204 result
3205 }
3206
3207 fn modify_channel(&mut self, request: &protocol::ModifyChannel) -> Result<(), ChannelError> {
3209 let (offer_id, channel) = self
3210 .inner
3211 .channels
3212 .get_by_channel_id_mut(&self.inner.assigned_channels, request.channel_id)?;
3213
3214 let (open_request, modify_state) = match &mut channel.state {
3215 ChannelState::Open {
3216 params,
3217 modify_state,
3218 reserved_state: None,
3219 } => (params, modify_state),
3220 _ => return Err(ChannelError::InvalidChannelState),
3221 };
3222
3223 if let ModifyState::Modifying { pending_target_vp } = modify_state {
3224 if self.inner.state.check_version(Version::Iron) {
3225 tracelimit::warn_ratelimited!(
3228 "Client sent new ModifyChannel before receiving ModifyChannelResponse."
3229 );
3230 } else {
3231 *pending_target_vp = Some(request.target_vp);
3234 }
3235 } else {
3236 self.notifier.notify(
3237 offer_id,
3238 Action::Modify {
3239 target_vp: request.target_vp,
3240 },
3241 );
3242
3243 open_request.target_vp = request.target_vp;
3245 *modify_state = ModifyState::Modifying {
3246 pending_target_vp: None,
3247 };
3248 }
3249
3250 Ok(())
3251 }
3252
3253 pub fn modify_channel_complete(&mut self, offer_id: OfferId, status: i32) {
3260 let channel = &mut self.inner.channels[offer_id];
3261
3262 if let ChannelState::Open {
3263 params,
3264 modify_state: ModifyState::Modifying { pending_target_vp },
3265 reserved_state: None,
3266 } = channel.state
3267 {
3268 channel.state = ChannelState::Open {
3269 params,
3270 modify_state: ModifyState::NotModifying,
3271 reserved_state: None,
3272 };
3273
3274 let channel_id = channel.info.as_ref().expect("assigned").channel_id;
3276 self.send_modify_channel_response(channel_id, status);
3277
3278 if let Some(target_vp) = pending_target_vp {
3280 let request = protocol::ModifyChannel {
3281 channel_id,
3282 target_vp,
3283 };
3284
3285 if let Err(error) = self.handle_modify_channel(&request) {
3286 tracelimit::warn_ratelimited!(?error, "Pending ModifyChannel request failed.")
3287 }
3288 }
3289 }
3290 }
3291
3292 fn send_modify_channel_response(&mut self, channel_id: ChannelId, status: i32) {
3293 if self.inner.state.check_version(Version::Iron) {
3294 self.sender()
3295 .send_message(&protocol::ModifyChannelResponse { channel_id, status });
3296 }
3297 }
3298
3299 fn handle_modify_connection(&mut self, request: protocol::ModifyConnection) {
3300 if let Err(err) = self.modify_connection(request) {
3301 tracelimit::error_ratelimited!(?err, "modifying connection failed");
3302 self.complete_modify_connection(ModifyConnectionResponse::Supported(
3303 protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
3304 FeatureFlags::new(),
3305 ));
3306 }
3307 }
3308
3309 fn modify_connection(&mut self, request: protocol::ModifyConnection) -> anyhow::Result<()> {
3310 let ConnectionState::Connected(info) = &mut self.inner.state else {
3311 anyhow::bail!(
3312 "Invalid state for ModifyConnection request: {:?}",
3313 self.inner.state
3314 );
3315 };
3316
3317 if info.modifying {
3318 anyhow::bail!(
3319 "Duplicate ModifyConnection request, state: {:?}",
3320 self.inner.state
3321 );
3322 }
3323
3324 if (request.child_to_parent_monitor_page_gpa == 0)
3325 != (request.parent_to_child_monitor_page_gpa == 0)
3326 {
3327 anyhow::bail!("Guest must specify either both or no monitor pages, {request:?}");
3328 }
3329
3330 let monitor_page =
3331 (request.child_to_parent_monitor_page_gpa != 0).then_some(MonitorPageGpas {
3332 child_to_parent: request.child_to_parent_monitor_page_gpa,
3333 parent_to_child: request.parent_to_child_monitor_page_gpa,
3334 });
3335
3336 info.modifying = true;
3337 info.monitor_page = monitor_page;
3338 tracing::debug!("modifying connection parameters.");
3339 self.notifier.modify_connection(request.into())?;
3340
3341 Ok(())
3342 }
3343
    /// Completes a connection modification with the given response.
    ///
    /// What happens depends on the current connection state:
    /// * `Connecting`: the response is passed to `complete_initiate_contact`.
    /// * `Disconnecting`: `complete_disconnect` is called and the response is
    ///   not used.
    /// * `Connected`: the modifying flag is cleared and a
    ///   `ModifyConnectionResponse` carrying the reported connection state is
    ///   sent to the guest.
    ///
    /// # Panics
    ///
    /// Panics if called in any other connection state, if the response is not
    /// `ModifyConnectionResponse::Supported` while connected, or if no
    /// modification was in progress while connected.
    pub fn complete_modify_connection(&mut self, response: ModifyConnectionResponse) {
        tracing::debug!(?response, "modifying connection parameters complete");

        match &mut self.inner.state {
            ConnectionState::Connecting { .. } => self.complete_initiate_contact(response),
            ConnectionState::Disconnecting { .. } => self.complete_disconnect(),
            ConnectionState::Connected(info) => {
                // Only the connection state is used from the response; the
                // remaining fields are ignored here.
                let ModifyConnectionResponse::Supported(connection_state, ..) = response else {
                    panic!(
                        "Relay should not return {:?} for a modify request with no version.",
                        response
                    );
                };

                if !info.modifying {
                    panic!(
                        "ModifyConnection response while not modifying, state: {:?}",
                        self.inner.state
                    );
                }

                info.modifying = false;
                self.sender()
                    .send_message(&protocol::ModifyConnectionResponse { connection_state });
            }
            _ => panic!(
                "Invalid state for ModifyConnection response: {:?}",
                self.inner.state
            ),
        }
    }
3378
3379 fn handle_pause(&mut self) {
3380 tracelimit::info_ratelimited!("pausing sending messages");
3381 self.sender().send_message(&protocol::PauseResponse {});
3382 let ConnectionState::Connected(info) = &mut self.inner.state else {
3383 unreachable!(
3384 "in unexpected state {:?}, should be prevented by Message::parse()",
3385 self.inner.state
3386 );
3387 };
3388 info.paused = true;
3389 }
3390
    /// Parses a synic message from the guest and dispatches it to the matching
    /// handler.
    ///
    /// # Errors
    ///
    /// * Propagates the parse error if the message cannot be parsed for the
    ///   current protocol version.
    /// * `ChannelError::UntrustedMessage` if the connection state is trusted
    ///   but the message is not.
    /// * `ChannelError::Paused` if the connection is paused and the message is
    ///   not `Resume`, `Unload`, or an `InitiateContact` variant.
    /// * Any error returned by the individual message handler.
    ///
    /// # Panics
    ///
    /// Panics if the server is resetting, or if a message type that only the
    /// server sends is received.
    pub fn handle_synic_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
        assert!(!self.is_resetting());

        let version = self.inner.state.get_version();
        let msg = Message::parse(&message.data, version)?;
        tracing::trace!(?msg, message.trusted, "received vmbus message");
        // A trusted connection only accepts trusted messages.
        if self.inner.state.is_trusted() && !message.trusted {
            tracelimit::warn_ratelimited!(?msg, "Received untrusted message");
            return Err(ChannelError::UntrustedMessage);
        }

        match &mut self.inner.state {
            // While paused, only the listed messages are accepted, and
            // receiving any of them clears the paused flag.
            ConnectionState::Connected(info) if info.paused => {
                if !matches!(
                    msg,
                    Message::Resume(..)
                        | Message::Unload(..)
                        | Message::InitiateContact { .. }
                        | Message::InitiateContact2 { .. }
                ) {
                    tracelimit::warn_ratelimited!(?msg, "Received message while paused");
                    return Err(ChannelError::Paused);
                }
                tracelimit::info_ratelimited!("resuming sending messages");
                info.paused = false;
            }
            _ => {}
        }

        match msg {
            Message::InitiateContact2(input, ..) => {
                self.handle_initiate_contact(&input, &message, true)?
            }
            Message::InitiateContact(input, ..) => {
                self.handle_initiate_contact(&input.into(), &message, false)?
            }
            Message::Unload(..) => self.handle_unload(),
            Message::RequestOffers(..) => self.handle_request_offers()?,
            // Header failures are reported to the guest by the handler itself,
            // so there is no error to propagate here.
            Message::GpadlHeader(input, range) => self.handle_gpadl_header(&input, range),
            Message::GpadlBody(input, range) => self.handle_gpadl_body(&input, range)?,
            Message::GpadlTeardown(input, ..) => self.handle_gpadl_teardown(&input)?,
            Message::OpenChannel(input, ..) => self.handle_open_channel(&input.into())?,
            Message::OpenChannel2(input, ..) => self.handle_open_channel(&input)?,
            Message::CloseChannel(input, ..) => self.handle_close_channel(&input)?,
            Message::RelIdReleased(input, ..) => self.handle_rel_id_released(&input)?,
            Message::TlConnectRequest(input, ..) => self.handle_tl_connect_request(input.into()),
            Message::TlConnectRequest2(input, ..) => self.handle_tl_connect_request(input),
            Message::ModifyChannel(input, ..) => self.handle_modify_channel(&input)?,
            Message::ModifyConnection(input, ..) => self.handle_modify_connection(input),
            Message::OpenReservedChannel(input, ..) => self.handle_open_reserved_channel(
                &input,
                version.expect("version validated by Message::parse"),
            )?,
            Message::CloseReservedChannel(input, ..) => {
                self.handle_close_reserved_channel(&input)?
            }
            Message::Pause(protocol::Pause, ..) => self.handle_pause(),
            // Resuming was already handled by the paused check above.
            Message::Resume(protocol::Resume, ..) => {}
            // These message types are sent by the server, never received.
            Message::OfferChannel(..)
            | Message::RescindChannelOffer(..)
            | Message::AllOffersDelivered(..)
            | Message::OpenResult(..)
            | Message::GpadlCreated(..)
            | Message::GpadlTorndown(..)
            | Message::VersionResponse(..)
            | Message::VersionResponse2(..)
            | Message::UnloadComplete(..)
            | Message::CloseReservedChannelResponse(..)
            | Message::TlConnectResult(..)
            | Message::ModifyChannelResponse(..)
            | Message::ModifyConnectionResponse(..)
            | Message::PauseResponse(..) => {
                unreachable!("Server received client message {:?}", msg);
            }
        }
        Ok(())
    }
3475
3476 fn get_gpadl(
3477 gpadls: &mut GpadlMap,
3478 offer_id: OfferId,
3479 gpadl_id: GpadlId,
3480 ) -> Option<&mut Gpadl> {
3481 let gpadl = gpadls.get_mut(&(gpadl_id, offer_id));
3482 if gpadl.is_none() {
3483 tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, "invalid gpadl ID for channel");
3484 }
3485 gpadl
3486 }
3487
3488 pub fn gpadl_create_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId, status: i32) {
3490 let gpadl = if let Some(gpadl) = Self::get_gpadl(&mut self.inner.gpadls, offer_id, gpadl_id)
3491 {
3492 gpadl
3493 } else {
3494 return;
3495 };
3496 let retain = match gpadl.state {
3497 GpadlState::InProgress | GpadlState::TearingDown | GpadlState::Accepted => {
3498 tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
3499 return;
3500 }
3501 GpadlState::Offered => {
3502 let channel_id = self.inner.channels[offer_id]
3503 .info
3504 .as_ref()
3505 .expect("assigned")
3506 .channel_id;
3507 self.inner
3508 .pending_messages
3509 .sender(self.notifier, self.inner.state.is_paused())
3510 .send_gpadl_created(channel_id, gpadl_id, status);
3511 if status >= 0 {
3512 gpadl.state = GpadlState::Accepted;
3513 true
3514 } else {
3515 false
3516 }
3517 }
3518 GpadlState::OfferedTearingDown => {
3519 if status >= 0 {
3520 self.notifier.notify(
3522 offer_id,
3523 Action::TeardownGpadl {
3524 gpadl_id,
3525 post_restore: false,
3526 },
3527 );
3528 gpadl.state = GpadlState::TearingDown;
3529 true
3530 } else {
3531 false
3532 }
3533 }
3534 };
3535 if !retain {
3536 self.inner
3537 .gpadls
3538 .remove(&(gpadl_id, offer_id))
3539 .expect("gpadl validated above");
3540
3541 self.check_disconnected();
3542 }
3543 }
3544
3545 pub fn gpadl_teardown_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
3547 tracing::debug!(
3548 offer_id = offer_id.0,
3549 gpadl_id = gpadl_id.0,
3550 "Gpadl teardown complete"
3551 );
3552
3553 let gpadl = if let Some(gpadl) = Self::get_gpadl(&mut self.inner.gpadls, offer_id, gpadl_id)
3554 {
3555 gpadl
3556 } else {
3557 return;
3558 };
3559 let channel = &mut self.inner.channels[offer_id];
3560 match gpadl.state {
3561 GpadlState::InProgress
3562 | GpadlState::Offered
3563 | GpadlState::OfferedTearingDown
3564 | GpadlState::Accepted => {
3565 tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
3566 }
3567 GpadlState::TearingDown => {
3568 if !channel.state.is_released() {
3569 self.sender().send_gpadl_torndown(gpadl_id);
3570 }
3571 self.inner
3572 .gpadls
3573 .remove(&(gpadl_id, offer_id))
3574 .expect("gpadl validated above");
3575
3576 self.check_disconnected();
3577 }
3578 }
3579 }
3580
3581 fn sender(&mut self) -> MessageSender<'_, N> {
3586 self.inner
3587 .pending_messages
3588 .sender(self.notifier, self.inner.state.is_paused())
3589 }
3590}
3591
/// Revokes a channel, cleaning up its GPADLs and rescinding the offer from the
/// guest where one is outstanding.
///
/// Returns `true` if the channel was already revoked, or if it is not in a
/// released state after the revoke.
/// NOTE(review): the local is named `retain`, so callers presumably keep the
/// channel when this returns `true` -- confirm against the call sites.
fn revoke<N: Notifier>(
    mut sender: MessageSender<'_, N>,
    offer_id: OfferId,
    channel: &mut Channel,
    gpadls: &mut GpadlMap,
) -> bool {
    // `info` is `Some` only for states where the guest is told about the
    // revoke (GPADL responses and the rescind below).
    let info = match channel.state {
        ChannelState::Closed
        | ChannelState::Open { .. }
        | ChannelState::Opening { .. }
        | ChannelState::Closing { .. }
        | ChannelState::ClosingReopen { .. } => {
            channel.state = ChannelState::Revoked;
            Some(channel.info.as_ref().expect("assigned"))
        }
        ChannelState::Reoffered => {
            channel.state = ChannelState::Revoked;
            None
        }
        // Released states are left unchanged.
        ChannelState::ClientReleased
        | ChannelState::OpeningClientRelease
        | ChannelState::ClosingClientRelease => None,
        // Already revoked: nothing more to do.
        ChannelState::Revoked => return true,
    };
    let retain = !channel.state.is_released();

    gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
        // GPADLs of other channels are untouched.
        if gpadl_offer_id != offer_id {
            return true;
        }

        match gpadl.state {
            GpadlState::InProgress => true,
            // The pending creation is failed towards the guest and dropped.
            GpadlState::Offered => {
                if let Some(info) = info {
                    sender.send_gpadl_created(
                        info.channel_id,
                        gpadl_id,
                        protocol::STATUS_UNSUCCESSFUL,
                    );
                }
                false
            }
            GpadlState::OfferedTearingDown => false,
            GpadlState::Accepted => true,
            // The pending teardown is completed towards the guest and dropped.
            GpadlState::TearingDown => {
                if info.is_some() {
                    sender.send_gpadl_torndown(gpadl_id);
                }
                false
            }
        }
    });
    if let Some(info) = info {
        sender.send_rescind(info);
    }
    // Any restore state other than `New` becomes `Restored`.
    if channel.restore_state != RestoreState::New {
        channel.restore_state = RestoreState::Restored;
    }
    retain
}
3656
/// Queue of outgoing messages that have not been sent yet. `MessageSender`
/// appends to it when a message cannot be sent immediately.
struct PendingMessages(VecDeque<OutgoingMessage>);
3658
3659impl PendingMessages {
3660 fn sender<'a, N: Notifier>(
3662 &'a mut self,
3663 notifier: &'a mut N,
3664 is_paused: bool,
3665 ) -> MessageSender<'a, N> {
3666 MessageSender {
3667 notifier,
3668 pending_messages: self,
3669 is_paused,
3670 }
3671 }
3672}
3673
/// Helper that sends messages through a notifier and queues the ones that
/// cannot be sent right away.
struct MessageSender<'a, N> {
    /// Notifier used to send messages and to deliver actions.
    notifier: &'a mut N,
    /// Queue that receives messages which could not be sent immediately.
    pending_messages: &'a mut PendingMessages,
    /// When set, messages to the default target are queued instead of sent.
    is_paused: bool,
}
3681
3682impl<N: Notifier> MessageSender<'_, N> {
3683 fn send_message<
3685 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3686 >(
3687 &mut self,
3688 msg: &T,
3689 ) {
3690 let message = OutgoingMessage::new(msg);
3691
3692 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3693 if !self.pending_messages.0.is_empty()
3695 || self.is_paused
3696 || !self.notifier.send_message(&message, MessageTarget::Default)
3697 {
3698 tracing::trace!("message queued");
3699 self.pending_messages.0.push_back(message);
3701 }
3702 }
3703
3704 fn send_message_with_target<
3706 T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3707 >(
3708 &mut self,
3709 msg: &T,
3710 target: MessageTarget,
3711 ) {
3712 if target == MessageTarget::Default {
3713 self.send_message(msg);
3714 } else {
3715 tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3716 let message = OutgoingMessage::new(msg);
3719 if !self.notifier.send_message(&message, target) {
3720 tracelimit::warn_ratelimited!(?target, "failed to send message");
3721 }
3722 }
3723 }
3724
3725 fn send_offer(&mut self, channel: &mut Channel, version: VersionInfo) {
3727 let info = channel.info.as_ref().expect("assigned");
3728 let mut flags = channel.offer.flags;
3729 if !version.feature_flags.confidential_channels() {
3730 flags.set_confidential_ring_buffer(false);
3731 flags.set_confidential_external_memory(false);
3732 }
3733
3734 let msg = protocol::OfferChannel {
3735 interface_id: channel.offer.interface_id,
3736 instance_id: channel.offer.instance_id,
3737 rsvd: [0; 4],
3738 flags,
3739 mmio_megabytes: channel.offer.mmio_megabytes,
3740 user_defined: channel.offer.user_defined,
3741 subchannel_index: channel.offer.subchannel_index,
3742 mmio_megabytes_optional: channel.offer.mmio_megabytes_optional,
3743 channel_id: info.channel_id,
3744 monitor_id: info.monitor_id.unwrap_or(MonitorId::INVALID).0,
3745 monitor_allocated: info.monitor_id.is_some() as u8,
3746 is_dedicated: 1,
3749 connection_id: info.connection_id,
3750 };
3751 tracing::info!(
3752 channel_id = msg.channel_id.0,
3753 connection_id = msg.connection_id,
3754 key = %channel.offer.key(),
3755 "sending offer to guest"
3756 );
3757
3758 self.send_message(&msg);
3759 }
3760
3761 fn send_open_result(
3762 &mut self,
3763 channel_id: ChannelId,
3764 open_request: &OpenRequest,
3765 result: i32,
3766 target: MessageTarget,
3767 ) {
3768 self.send_message_with_target(
3769 &protocol::OpenResult {
3770 channel_id,
3771 open_id: open_request.open_id,
3772 status: result as u32,
3773 },
3774 target,
3775 );
3776 }
3777
3778 fn send_gpadl_created(&mut self, channel_id: ChannelId, gpadl_id: GpadlId, status: i32) {
3779 self.send_message(&protocol::GpadlCreated {
3780 channel_id,
3781 gpadl_id,
3782 status,
3783 });
3784 }
3785
3786 fn send_gpadl_torndown(&mut self, gpadl_id: GpadlId) {
3787 self.send_message(&protocol::GpadlTorndown { gpadl_id });
3788 }
3789
3790 fn send_rescind(&mut self, info: &OfferedInfo) {
3791 tracing::info!(
3792 channel_id = info.channel_id.0,
3793 "rescinding channel from guest"
3794 );
3795
3796 self.send_message(&protocol::RescindChannelOffer {
3797 channel_id: info.channel_id,
3798 });
3799 }
3800}