vmbus_server/
channels.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4mod saved_state;
5
6use crate::Guid;
7use crate::SINT;
8use crate::SynicMessage;
9use crate::monitor::AssignedMonitors;
10use crate::protocol::Version;
11use hvdef::Vtl;
12use inspect::Inspect;
13pub use saved_state::RestoreError;
14pub use saved_state::SavedState;
15pub use saved_state::SavedStateData;
16use slab::Slab;
17use std::cmp::min;
18use std::collections::VecDeque;
19use std::collections::hash_map::Entry;
20use std::collections::hash_map::HashMap;
21use std::fmt::Display;
22use std::ops::Index;
23use std::ops::IndexMut;
24use std::task::Poll;
25use std::task::ready;
26use std::time::Duration;
27use thiserror::Error;
28use vmbus_channel::bus::ChannelType;
29use vmbus_channel::bus::GpadlRequest;
30use vmbus_channel::bus::OfferKey;
31use vmbus_channel::bus::OfferParams;
32use vmbus_channel::bus::OpenData;
33use vmbus_channel::bus::RestoredGpadl;
34use vmbus_core::HvsockConnectRequest;
35use vmbus_core::HvsockConnectResult;
36use vmbus_core::MaxVersionInfo;
37use vmbus_core::OutgoingMessage;
38use vmbus_core::VersionInfo;
39use vmbus_core::protocol;
40use vmbus_core::protocol::ChannelId;
41use vmbus_core::protocol::ConnectionId;
42use vmbus_core::protocol::FeatureFlags;
43use vmbus_core::protocol::GpadlId;
44use vmbus_core::protocol::Message;
45use vmbus_core::protocol::OfferFlags;
46use vmbus_core::protocol::UserDefinedData;
47use vmbus_ring::gparange;
48use vmcore::monitor::MonitorId;
49use vmcore::synic::MonitorInfo;
50use vmcore::synic::MonitorPageGpas;
51use zerocopy::FromZeros;
52use zerocopy::Immutable;
53use zerocopy::IntoBytes;
54use zerocopy::KnownLayout;
55
/// An error caused by a channel operation.
///
/// The variants cover unknown channel or GPADL IDs, malformed messages, and operations attempted
/// while a channel, GPADL, or connection is in a state that does not permit them.
#[derive(Debug, Error)]
pub enum ChannelError {
    #[error("unknown channel ID")]
    UnknownChannelId,
    #[error("unknown GPADL ID")]
    UnknownGpadlId,
    // `#[from]` generates a `From` impl, so `?` converts a `protocol::ParseError` automatically.
    #[error("parse error")]
    ParseError(#[from] protocol::ParseError),
    // `#[source]` only (no `From` impl): this variant must be constructed explicitly.
    #[error("invalid gpa range")]
    InvalidGpaRange(#[source] gparange::Error),
    #[error("duplicate GPADL ID")]
    DuplicateGpadlId,
    #[error("GPADL is already complete")]
    GpadlAlreadyComplete,
    #[error("GPADL channel ID mismatch")]
    WrongGpadlChannelId,
    #[error("trying to open an open channel")]
    ChannelAlreadyOpen,
    #[error("trying to close a closed channel")]
    ChannelNotOpen,
    #[error("invalid GPADL state for operation")]
    InvalidGpadlState,
    #[error("invalid channel state for operation")]
    InvalidChannelState,
    #[error("channel ID has already been released")]
    ChannelReleased,
    #[error("channel offers have already been sent")]
    OffersAlreadySent,
    #[error("invalid operation on reserved channel")]
    ChannelReserved,
    #[error("invalid operation on non-reserved channel")]
    ChannelNotReserved,
    #[error("received untrusted message for trusted connection")]
    UntrustedMessage,
    #[error("received a non-resuming message while paused")]
    Paused,
}
94
/// An error caused by offering a channel.
#[derive(Debug, Error)]
pub enum OfferError {
    // Returned by `AssignedChannels::set` when the channel ID is out of range.
    #[error("the channel ID {} is not valid for this operation", (.0).0)]
    InvalidChannelId(ChannelId),
    // Returned by `AssignedChannels::set` when the channel ID is already assigned.
    #[error("the channel ID {} is already in use", (.0).0)]
    ChannelIdInUse(ChannelId),
    #[error("offer {0} already exists")]
    AlreadyExists(OfferKey),
    #[error("specified resources do not match those of the existing saved or revoked offer")]
    IncompatibleResources,
    #[error("too many channels have been offered")]
    TooManyChannels,
    // Field 0 is the expected monitor ID (from the saved state); field 1 is the actual one.
    #[error("mismatched monitor ID from saved state; expected {0:?}, actual {1:?}")]
    MismatchedMonitorId(Option<MonitorId>, MonitorId),
}
110
/// A unique identifier for an offered channel.
///
/// Used to index the server's `ChannelList`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OfferId(usize);

/// Maps the ID of a GPADL that is not yet complete to the offer that owns it.
type IncompleteGpadlMap = HashMap<GpadlId, OfferId>;

/// GPADLs keyed by GPADL ID together with the owning offer, so the same GPADL ID may appear
/// under different offers.
type GpadlMap = HashMap<(GpadlId, OfferId), Gpadl>;
118
/// A struct modeling the server side of the VMBus control plane.
pub struct Server {
    /// The state of the VMBus connection.
    state: ConnectionState,
    /// Every channel known to the server, indexed by [`OfferId`].
    channels: ChannelList,
    /// Which offer owns each channel ID.
    assigned_channels: AssignedChannels,
    /// Monitor IDs handed out to channels whose MNF is handled by this server.
    assigned_monitors: AssignedMonitors,
    /// GPADLs, keyed by GPADL ID and owning offer.
    gpadls: GpadlMap,
    /// GPADLs that are not yet complete, mapped to their owning offer.
    incomplete_gpadls: IncompleteGpadlMap,
    child_connection_id: u32,
    max_version: Option<MaxVersionInfo>,
    delayed_max_version: Option<MaxVersionInfo>,
    // This must be separate from the connection state because e.g. the UnloadComplete message,
    // or messages for reserved channels, can be pending even when disconnected.
    pending_messages: PendingMessages,
}
134
/// A mutable borrow of a [`Server`] paired with a notifier.
///
/// Dropping it runs `Server::validate` on the borrowed server.
pub struct ServerWithNotifier<'a, T> {
    inner: &'a mut Server,
    notifier: &'a mut T,
}

impl<T> Drop for ServerWithNotifier<'_, T> {
    fn drop(&mut self) {
        // NOTE(review): `validate` is defined elsewhere; presumably an internal consistency
        // check run after each use of the wrapper — confirm.
        self.inner.validate();
    }
}
145
impl<T: Notifier> Inspect for ServerWithNotifier<'_, T> {
    /// Reports the connection state and info, the pending connection action (if any), the
    /// assigned-monitor bitmap, and a `channels` child listing each channel and its GPADLs.
    fn inspect(&self, req: inspect::Request<'_>) {
        let mut resp = req.respond();
        // Map the connection state to a label, plus the connection info and pending action
        // reported for it. The `next_action` of a `Connecting` state is not reported.
        let (state, info, next_action) = match &self.inner.state {
            ConnectionState::Disconnected => ("disconnected", None, None),
            ConnectionState::Connecting { info, .. } => ("connecting", Some(info), None),
            // Reported as "negotiated" until `offers_sent` is set.
            ConnectionState::Connected(info) => (
                if info.offers_sent {
                    "connected"
                } else {
                    "negotiated"
                },
                Some(info),
                None,
            ),
            ConnectionState::Disconnecting { next_action, .. } => {
                ("disconnecting", None, Some(next_action))
            }
        };

        resp.field("connection_info", info);
        // Short label for the action queued behind the disconnect.
        let next_action = next_action.map(|a| match a {
            ConnectionAction::None => "disconnect",
            ConnectionAction::Reset => "reset",
            ConnectionAction::SendUnloadComplete => "unload",
            ConnectionAction::Reconnect { .. } => "reconnect",
            ConnectionAction::SendFailedVersionResponse => "send_version_response",
        });
        resp.field("state", state)
            .field("next_action", next_action)
            .field(
                "assigned_monitors_bitmap",
                format_args!("{:x}", self.inner.assigned_monitors.bitmap()),
            )
            .child("channels", |req| {
                let mut resp = req.respond();
                self.inner
                    .channels
                    .inspect(self.notifier, self.inner.get_version(), &mut resp);
                // Each GPADL is reported under its owning channel's inspect path.
                for ((gpadl_id, offer_id), gpadl) in &self.inner.gpadls {
                    let channel = &self.inner.channels[*offer_id];
                    resp.field(
                        &channel_inspect_path(
                            &channel.offer,
                            format_args!("/gpadls/{}", gpadl_id.0),
                        ),
                        gpadl,
                    );
                }
            });
    }
}
198
/// Parameters of a VMBus connection, held while connecting and once connected.
#[derive(Debug, Copy, Clone, Inspect)]
struct ConnectionInfo {
    // The negotiated protocol version and feature flags.
    version: VersionInfo,
    // Indicates if the connection is trusted for the paravisor of a hardware-isolated VM. In other
    // cases, this value is always false.
    trusted: bool,
    // Whether channel offers have been sent on this connection.
    offers_sent: bool,
    interrupt_page: Option<u64>,
    monitor_page: Option<MonitorPageGpas>,
    target_message_vp: u32,
    // NOTE(review): presumably set while a connection modification is in flight — confirm.
    modifying: bool,
    client_id: Guid,
    // NOTE(review): presumably gates `ChannelError::Paused` for non-resuming messages — confirm.
    paused: bool,
}
213
/// The state of the VMBus connection.
#[derive(Debug)]
enum ConnectionState {
    /// There is no connection.
    Disconnected,
    /// A disconnect is in progress; `next_action` is what follows it.
    Disconnecting {
        next_action: ConnectionAction,
        // NOTE(review): presumably records whether a modify request was already sent — confirm.
        modify_sent: bool,
    },
    /// A connection is being established; `next_action` is queued behind it.
    Connecting {
        info: ConnectionInfo,
        next_action: ConnectionAction,
    },
    /// A connection is established with the given parameters.
    Connected(ConnectionInfo),
}
228
229impl ConnectionState {
230    /// Checks whether the state is connected using at least the specified version.
231    fn check_version(&self, min_version: Version) -> bool {
232        matches!(self, ConnectionState::Connected(info) if info.version.version >= min_version)
233    }
234
235    /// Checks whether the state is connected and the specified predicate holds for the feature
236    /// flags.
237    fn check_feature_flags(&self, flags: impl Fn(FeatureFlags) -> bool) -> bool {
238        matches!(self, ConnectionState::Connected(info) if flags(info.version.feature_flags))
239    }
240
241    fn get_version(&self) -> Option<VersionInfo> {
242        if let ConnectionState::Connected(info) = self {
243            Some(info.version)
244        } else {
245            None
246        }
247    }
248
249    fn is_trusted(&self) -> bool {
250        match self {
251            ConnectionState::Connected(info) => info.trusted,
252            ConnectionState::Connecting { info, .. } => info.trusted,
253            _ => false,
254        }
255    }
256
257    fn is_paused(&self) -> bool {
258        if let ConnectionState::Connected(info) = self {
259            info.paused
260        } else {
261            false
262        }
263    }
264}
265
/// An action queued behind a connection-state transition; stored as `next_action` in the
/// `Disconnecting` and `Connecting` states.
#[derive(Debug, Copy, Clone)]
enum ConnectionAction {
    /// No further action (reported as "disconnect" by inspect).
    None,
    Reset,
    SendUnloadComplete,
    /// Reconnect using the carried `InitiateContact` request.
    Reconnect {
        initiate_contact: InitiateContactRequest,
    },
    SendFailedVersionResponse,
}
276
/// The monitor page portion of an [`InitiateContactRequest`].
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum MonitorPageRequest {
    /// No monitor pages were specified.
    None,
    /// Monitor page GPAs were specified.
    Some(MonitorPageGpas),
    /// NOTE(review): presumably monitor pages were specified but could not be accepted —
    /// confirm against the message parser.
    Invalid,
}
283
/// A request to initiate contact, i.e. to establish a VMBus connection.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub struct InitiateContactRequest {
    /// The raw requested protocol version.
    pub version_requested: u32,
    pub target_message_vp: u32,
    pub monitor_page: MonitorPageRequest,
    pub target_sint: u8,
    pub target_vtl: u8,
    /// The raw requested feature flag bits.
    pub feature_flags: u32,
    pub interrupt_page: Option<u64>,
    pub client_id: Guid,
    /// NOTE(review): presumably becomes `ConnectionInfo::trusted` — confirm.
    pub trusted: bool,
}
296
/// The parameters of a request to open a channel.
#[derive(Debug, Copy, Clone)]
pub struct OpenRequest {
    pub open_id: u32,
    /// The GPADL that holds the channel's ring buffers.
    pub ring_buffer_gpadl_id: GpadlId,
    pub target_vp: u32,
    /// NOTE(review): presumably the page offset of the downstream ring within the ring buffer
    /// GPADL — confirm.
    pub downstream_ring_buffer_page_offset: u32,
    pub user_data: UserDefinedData,
    /// Interrupt signaling parameters, if the guest specified them.
    pub guest_specified_interrupt_info: Option<SignalInfo>,
    pub flags: protocol::OpenChannelFlags,
}
307
/// A requested change to an optional connection parameter.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Update<T: std::fmt::Debug + Copy + Clone> {
    /// No change is requested.
    Unchanged,
    /// The value should be cleared.
    Reset,
    /// The value should be set to the contained value.
    Set(T),
}

/// `Some(v)` converts to `Set(v)` and `None` to `Reset`; `Unchanged` is never produced.
impl<T: std::fmt::Debug + Copy + Clone> From<Option<T>> for Update<T> {
    fn from(value: Option<T>) -> Self {
        value.map_or(Self::Reset, Self::Set)
    }
}
323
/// A request to change parameters of the current connection.
///
/// The `Default` value requests no changes (`None` / `Update::Unchanged` everywhere) and has
/// `notify_relay` set.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ModifyConnectionRequest {
    /// The raw protocol version to switch to; `None` requests no version change.
    pub version: Option<u32>,
    pub monitor_page: Update<MonitorPageGpas>,
    pub interrupt_page: Update<u64>,
    pub target_message_vp: Option<u32>,
    pub notify_relay: bool,
}
332
333// Manual implementation because notify_relay should be true by default.
334impl Default for ModifyConnectionRequest {
335    fn default() -> Self {
336        Self {
337            version: None,
338            monitor_page: Update::Unchanged,
339            interrupt_page: Update::Unchanged,
340            target_message_vp: None,
341            notify_relay: true,
342        }
343    }
344}
345
346impl From<protocol::ModifyConnection> for ModifyConnectionRequest {
347    fn from(value: protocol::ModifyConnection) -> Self {
348        let monitor_page = if value.parent_to_child_monitor_page_gpa != 0 {
349            Update::Set(MonitorPageGpas {
350                parent_to_child: value.parent_to_child_monitor_page_gpa,
351                child_to_parent: value.child_to_parent_monitor_page_gpa,
352            })
353        } else {
354            Update::Reset
355        };
356
357        Self {
358            monitor_page,
359            ..Default::default()
360        }
361    }
362}
363
/// Response to a ModifyConnectionRequest.
#[derive(Debug, Copy, Clone)]
pub enum ModifyConnectionResponse {
    /// No version change was requested, or the requested version is supported. Includes all the
    /// feature flags supported by the relay host, so that supported flags reported to the guest can
    /// be limited to that. The FeatureFlags field is not relevant if no version change was
    /// requested.
    Supported(protocol::ConnectionState, FeatureFlags),
    /// A version change was requested but the relay host doesn't support that version. This
    /// response cannot be returned for a request with no version change set.
    Unsupported,
}
376
/// The modify state of an open channel, stored in `ChannelState::Open`.
#[derive(Debug, Copy, Clone)]
pub enum ModifyState {
    NotModifying,
    Modifying { pending_target_vp: Option<u32> },
}

impl ModifyState {
    /// Returns true for `Modifying`, whatever its pending target VP.
    pub fn is_modifying(&self) -> bool {
        // The enum has exactly two variants, so "not `NotModifying`" means `Modifying`.
        !matches!(self, Self::NotModifying)
    }
}
388
/// Guest-specified interrupt signaling parameters for a channel
/// (see `OpenRequest::guest_specified_interrupt_info`).
#[derive(Debug, Copy, Clone)]
pub struct SignalInfo {
    pub event_flag: u16,
    pub connection_id: u32,
}
394
/// How a channel relates to a saved state being restored.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum RestoreState {
    /// The channel has been offered newly this session.
    New,
    /// The channel was in the saved state and has been re-offered this session,
    /// but restore_channel has not yet been called on it, and revoke_unclaimed_channels
    /// has not yet been called.
    Restoring,
    /// The channel was in the saved state but has not yet been re-offered this
    /// session.
    Unmatched,
    /// The channel was in the saved state and is now in a fully restored state.
    Restored,
}
409
/// The state of a single vmbus channel.
///
/// In the `Opening`, `Open`, and `Closing` states, `reserved_state` is `Some` for reserved
/// channels (see `ChannelState::is_reserved`).
#[derive(Debug, Clone)]
enum ChannelState {
    /// The device has offered the channel but the offer has not been sent to the
    /// guest. However, there may still be GPADLs for this channel from a
    /// previous connection.
    ClientReleased,

    /// The channel has been offered to the guest.
    Closed,

    /// The guest has requested to open the channel and the device has been
    /// notified.
    Opening {
        request: OpenRequest,
        reserved_state: Option<ReservedState>,
    },

    /// The channel is open by both the guest and the device.
    Open {
        params: OpenRequest,
        modify_state: ModifyState,
        reserved_state: Option<ReservedState>,
    },

    /// The device has been notified to close the channel.
    Closing {
        params: OpenRequest,
        reserved_state: Option<ReservedState>,
    },

    /// The device has been notified to close the channel, and the guest has
    /// requested to reopen it.
    ClosingReopen {
        params: OpenRequest,
        request: OpenRequest,
    },

    /// The device has revoked the channel but the guest has not released it yet.
    Revoked,

    /// The device has been reoffered, but the guest has not released the previous
    /// offer yet.
    Reoffered,

    /// The guest has released the channel but there is still a pending close
    /// request to the device.
    ClosingClientRelease,

    /// The guest has released the channel, but there is still a pending open
    /// request to the device.
    OpeningClientRelease,
}
463
// Every predicate below lists each variant explicitly (no `_` arm), so adding a state is a
// compile error here until it has been classified.
impl ChannelState {
    /// If true, the channel is unreferenced by the guest, and the guest should
    /// not be able to perform operations on the channel.
    fn is_released(&self) -> bool {
        match self {
            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered => false,

            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => true,
        }
    }

    /// If true, the channel has been revoked.
    fn is_revoked(&self) -> bool {
        match self {
            ChannelState::Revoked | ChannelState::Reoffered => true,

            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => false,
        }
    }

    /// Returns true if the channel is `Opening`, `Open`, or `Closing` with a reserved state.
    ///
    /// Arm order matters: the `reserved_state: Some(_)` patterns must precede the `{ .. }`
    /// patterns for the same variants in the second arm.
    fn is_reserved(&self) -> bool {
        match self {
            // TODO: Should closing be included here?
            ChannelState::Open {
                reserved_state: Some(_),
                ..
            }
            | ChannelState::Opening {
                reserved_state: Some(_),
                ..
            }
            | ChannelState::Closing {
                reserved_state: Some(_),
                ..
            } => true,

            ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => false,
        }
    }
}
528
529impl Display for ChannelState {
530    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
531        let state = match self {
532            Self::ClientReleased => "ClientReleased",
533            Self::Closed => "Closed",
534            Self::Opening { .. } => "Opening",
535            Self::Open { .. } => "Open",
536            Self::Closing { .. } => "Closing",
537            Self::ClosingReopen { .. } => "ClosingReopen",
538            Self::Revoked => "Revoked",
539            Self::Reoffered => "Reoffered",
540            Self::ClosingClientRelease => "ClosingClientRelease",
541            Self::OpeningClientRelease => "OpeningClientRelease",
542        };
543        write!(f, "{}", state)
544    }
545}
546
/// Indicates how an MNF (monitored notification) interrupt should be used for a channel.
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub enum MnfUsage {
    /// The channel does not use MNF.
    #[default]
    Disabled,
    /// The channel uses MNF, handled by this server, with the specified interrupt latency.
    Enabled { latency: Duration },
    /// The channel uses MNF, handled by the relay host, with the monitor ID specified by the relay
    /// host.
    Relayed { monitor_id: u8 },
}
559
560impl MnfUsage {
561    pub fn is_enabled(&self) -> bool {
562        matches!(self, Self::Enabled { .. })
563    }
564
565    pub fn is_relayed(&self) -> bool {
566        matches!(self, Self::Relayed { .. })
567    }
568
569    pub fn enabled_and_then<T>(&self, f: impl FnOnce(Duration) -> Option<T>) -> Option<T> {
570        if let Self::Enabled { latency } = self {
571            f(*latency)
572        } else {
573            None
574        }
575    }
576}
577
578impl From<Option<Duration>> for MnfUsage {
579    fn from(value: Option<Duration>) -> Self {
580        match value {
581            None => Self::Disabled,
582            Some(latency) => Self::Enabled { latency },
583        }
584    }
585}
586
/// The server's representation of a channel offer; converted from [`OfferParams`] via `From`.
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub struct OfferParamsInternal {
    /// An informational string describing the channel type.
    pub interface_name: String,
    /// Part of the offer key (see `key`).
    pub instance_id: Guid,
    /// Part of the offer key (see `key`).
    pub interface_id: Guid,
    pub mmio_megabytes: u16,
    pub mmio_megabytes_optional: u16,
    /// Part of the offer key; 0 for the primary channel.
    pub subchannel_index: u16,
    /// How monitored notifications are handled for this channel.
    pub use_mnf: MnfUsage,
    pub offer_order: Option<u32>,
    pub flags: OfferFlags,
    pub user_defined: UserDefinedData,
}
601
602impl OfferParamsInternal {
603    /// Gets the offer key for this offer.
604    pub fn key(&self) -> OfferKey {
605        OfferKey {
606            interface_id: self.interface_id,
607            instance_id: self.instance_id,
608            subchannel_index: self.subchannel_index,
609        }
610    }
611}
612
impl From<OfferParams> for OfferParamsInternal {
    /// Converts public offer parameters into the internal form, deriving the offer flags and
    /// user-defined data from `channel_type`.
    fn from(value: OfferParams) -> Self {
        let mut user_defined = UserDefinedData::new_zeroed();

        // All non-relay channels are capable of using a confidential ring buffer, but external
        // memory is dependent on the device.
        let mut flags = OfferFlags::new()
            .with_confidential_ring_buffer(true)
            .with_confidential_external_memory(value.allow_confidential_external_memory);

        match value.channel_type {
            // Device channels use named-pipe mode (message pipe type) only when `pipe_packets`
            // is set; otherwise the user-defined data stays zeroed.
            ChannelType::Device { pipe_packets } => {
                if pipe_packets {
                    flags.set_named_pipe_mode(true);
                    user_defined.as_pipe_params_mut().pipe_type = protocol::PipeType::MESSAGE;
                }
            }
            // Interface channels replace the zeroed user-defined data with the caller's data.
            ChannelType::Interface {
                user_defined: interface_user_defined,
            } => {
                flags.set_enumerate_device_interface(true);
                user_defined = interface_user_defined;
            }
            // Pipe channels always use named-pipe mode; the pipe type follows `message_mode`.
            ChannelType::Pipe { message_mode } => {
                flags.set_enumerate_device_interface(true);
                flags.set_named_pipe_mode(true);
                user_defined.as_pipe_params_mut().pipe_type = if message_mode {
                    protocol::PipeType::MESSAGE
                } else {
                    protocol::PipeType::BYTE
                };
            }
            // Hyper-V socket channels store their parameters in the hvsock view of the
            // user-defined data.
            ChannelType::HvSocket {
                is_connect,
                is_for_container,
                silo_id,
            } => {
                flags.set_enumerate_device_interface(true);
                flags.set_tlnpi_provider(true);
                flags.set_named_pipe_mode(true);
                *user_defined.as_hvsock_params_mut() = protocol::HvsockUserDefinedParameters::new(
                    is_connect,
                    is_for_container,
                    silo_id,
                );
            }
        };

        Self {
            interface_name: value.interface_name,
            instance_id: value.instance_id,
            interface_id: value.interface_id,
            mmio_megabytes: value.mmio_megabytes,
            mmio_megabytes_optional: value.mmio_megabytes_optional,
            subchannel_index: value.subchannel_index,
            // `Some(latency)` enables server-handled MNF; `None` disables it.
            use_mnf: value.mnf_interrupt_latency.into(),
            offer_order: value.offer_order,
            user_defined,
            flags,
        }
    }
}
675
/// A VP and SINT that messages can be targeted at.
#[derive(Debug, Copy, Clone, Inspect, PartialEq, Eq)]
pub struct ConnectionTarget {
    pub vp: u32,
    pub sint: u8,
}
681
/// Where a message should be delivered.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MessageTarget {
    /// The connection's default target.
    Default,
    /// The target recorded for a reserved channel.
    ReservedChannel(OfferId, ConnectionTarget),
    /// An explicitly specified target.
    Custom(ConnectionTarget),
}
688
689impl MessageTarget {
690    pub fn for_offer(offer_id: OfferId, reserved_state: &Option<ReservedState>) -> Self {
691        if let Some(state) = reserved_state {
692            Self::ReservedChannel(offer_id, state.target)
693        } else {
694            Self::Default
695        }
696    }
697}
698
/// State kept for a reserved channel: its protocol version and the target its messages are sent
/// to (see `MessageTarget::for_offer`).
#[derive(Debug, Copy, Clone)]
pub struct ReservedState {
    version: VersionInfo,
    target: ConnectionTarget,
}
704
/// A VMBus channel.
#[derive(Debug)]
struct Channel {
    /// Set by `prepare_channel` and cleared by `release_channel`.
    info: Option<OfferedInfo>,
    offer: OfferParamsInternal,
    state: ChannelState,
    restore_state: RestoreState,
}

/// The IDs assigned to a channel when it is prepared to be sent to the guest.
#[derive(Debug, Copy, Clone)]
struct OfferedInfo {
    channel_id: ChannelId,
    /// Built from `channel_id`, the VTL, and `SINT`.
    connection_id: u32,
    /// `None` if the channel does not use MNF or no monitor ID was available.
    monitor_id: Option<MonitorId>,
}
720
impl Channel {
    /// Writes this channel's assigned IDs (if any), state, restore state, and offer details into
    /// `resp`.
    fn inspect_state(&self, resp: &mut inspect::Response<'_>) {
        // Per-state details; each stays `None` unless the current state carries it.
        let mut target_vp = None;
        let mut event_flag = None;
        let mut connection_id = None;
        let mut reserved_target = None;
        let state = match &self.state {
            ChannelState::ClientReleased => "client_released",
            ChannelState::Closed => "closed",
            ChannelState::Opening { reserved_state, .. } => {
                reserved_target = reserved_state.map(|state| state.target);
                "opening"
            }
            ChannelState::Open {
                params,
                reserved_state,
                ..
            } => {
                target_vp = Some(params.target_vp);
                if let Some(id) = params.guest_specified_interrupt_info {
                    event_flag = Some(id.event_flag);
                    connection_id = Some(id.connection_id);
                }
                reserved_target = reserved_state.map(|state| state.target);
                "open"
            }
            ChannelState::Closing { reserved_state, .. } => {
                reserved_target = reserved_state.map(|state| state.target);
                "closing"
            }
            ChannelState::ClosingReopen { .. } => "closing_reopen",
            ChannelState::Revoked => "revoked",
            ChannelState::Reoffered => "reoffered",
            ChannelState::ClosingClientRelease => "closing_client_release",
            ChannelState::OpeningClientRelease => "opening_client_release",
        };
        let restore_state = match self.restore_state {
            RestoreState::New => "new",
            RestoreState::Restoring => "restoring",
            RestoreState::Restored => "restored",
            RestoreState::Unmatched => "unmatched",
        };
        // IDs exist only between `prepare_channel` and `release_channel`.
        if let Some(info) = &self.info {
            resp.field("channel_id", info.channel_id.0)
                .field("offered_connection_id", info.connection_id)
                .field("monitor_id", info.monitor_id.map(|id| id.0));
        }
        resp.field("state", state)
            .field("restore_state", restore_state)
            .field("interface_name", self.offer.interface_name.clone())
            .display("instance_id", &self.offer.instance_id)
            .display("interface_id", &self.offer.interface_id)
            .field("mmio_megabytes", self.offer.mmio_megabytes)
            .field("target_vp", target_vp)
            .field("guest_specified_event_flag", event_flag)
            .field("guest_specified_connection_id", connection_id)
            .field("reserved_connection_target", reserved_target)
            .binary("offer_flags", self.offer.flags.into_bits());
    }

    /// Returns the monitor ID and latency only if it's being handled by this server.
    ///
    /// The monitor ID can be set while use_mnf is Relayed, which is the case if
    /// the relay host is handling MNF.
    ///
    /// Also returns `None` for reserved channels, since monitored notifications
    /// are only usable for standard channels. Otherwise, we fail later when we
    /// try to change the MNF page as part of vmbus protocol renegotiation,
    /// since the page still appears to be in use by a device.
    fn handled_monitor_info(&self) -> Option<MonitorInfo> {
        self.offer.use_mnf.enabled_and_then(|latency| {
            if self.state.is_reserved() {
                None
            } else {
                self.info.and_then(|info| {
                    info.monitor_id.map(|monitor_id| MonitorInfo {
                        monitor_id,
                        latency,
                    })
                })
            }
        })
    }

    /// Prepares a channel to be sent to the guest by allocating a channel ID if
    /// necessary and filling out channel.info.
    ///
    /// # Panics
    ///
    /// Panics if `self.info` is already set, or if no channel ID can be allocated.
    fn prepare_channel(
        &mut self,
        offer_id: OfferId,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
    ) {
        assert!(self.info.is_none());

        // Allocate a channel ID.
        let entry = assigned_channels
            .allocate()
            .expect("there are enough channel IDs for everything in ChannelList");

        let channel_id = entry.id();
        entry.insert(offer_id);
        let connection_id = ConnectionId::new(channel_id.0, assigned_channels.vtl, SINT);

        // Allocate a monitor ID if the channel uses MNF.
        // N.B. If the synic doesn't support MNF or MNF is disabled by the server, use_mnf should
        //      always be set to Disabled, except if the relay host is handling MNF in which case
        //      we should use the monitor ID it provided.
        let monitor_id = match self.offer.use_mnf {
            MnfUsage::Enabled { .. } => {
                // Running out of monitor IDs is not fatal: the channel proceeds without one.
                let monitor_id = assigned_monitors.assign_monitor();
                if monitor_id.is_none() {
                    tracelimit::warn_ratelimited!("Out of monitor IDs.");
                }

                monitor_id
            }
            MnfUsage::Relayed { monitor_id } => Some(MonitorId(monitor_id)),
            MnfUsage::Disabled => None,
        };

        self.info = Some(OfferedInfo {
            channel_id,
            connection_id: connection_id.0,
            monitor_id,
        });
    }

    /// Releases a channel's ID.
    ///
    /// Also releases the monitor ID when this server assigned it (`MnfUsage::Enabled`). Does
    /// nothing if the channel has no assigned IDs.
    fn release_channel(
        &mut self,
        offer_id: OfferId,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
    ) {
        if let Some(info) = self.info.take() {
            assigned_channels.free(info.channel_id, offer_id);

            // Only unassign the monitor ID if it was not a relayed ID provided by the offer.
            if let Some(monitor_id) = info.monitor_id {
                if self.offer.use_mnf.is_enabled() {
                    assigned_monitors.release_monitor(monitor_id);
                }
            }
        }
    }
}
867
868#[derive(Debug)]
869struct AssignedChannels {
870    assignments: Vec<Option<OfferId>>,
871    vtl: Vtl,
872    reserved_offset: usize,
873    /// The number of assigned channel IDs in the reserved range.
874    count_in_reserved_range: usize,
875}
876
877impl AssignedChannels {
878    fn new(vtl: Vtl, channel_id_offset: u16) -> Self {
879        Self {
880            assignments: vec![None; MAX_CHANNELS],
881            vtl,
882            reserved_offset: channel_id_offset as usize,
883            count_in_reserved_range: 0,
884        }
885    }
886
887    fn allowable_channel_count(&self) -> usize {
888        MAX_CHANNELS - self.reserved_offset + self.count_in_reserved_range
889    }
890
891    fn get(&self, channel_id: ChannelId) -> Option<OfferId> {
892        self.assignments
893            .get(Self::index(channel_id))
894            .copied()
895            .flatten()
896    }
897
898    fn set(&mut self, channel_id: ChannelId) -> Result<AssignmentEntry<'_>, OfferError> {
899        let index = Self::index(channel_id);
900        if self
901            .assignments
902            .get(index)
903            .ok_or(OfferError::InvalidChannelId(channel_id))?
904            .is_some()
905        {
906            return Err(OfferError::ChannelIdInUse(channel_id));
907        }
908        Ok(AssignmentEntry { list: self, index })
909    }
910
911    fn allocate(&mut self) -> Option<AssignmentEntry<'_>> {
912        let index = self.reserved_offset
913            + self.assignments[self.reserved_offset..]
914                .iter()
915                .position(|x| x.is_none())?;
916        Some(AssignmentEntry { list: self, index })
917    }
918
919    fn free(&mut self, channel_id: ChannelId, offer_id: OfferId) {
920        let index = Self::index(channel_id);
921        let slot = &mut self.assignments[index];
922        assert_eq!(slot.take(), Some(offer_id));
923        if index < self.reserved_offset {
924            self.count_in_reserved_range -= 1;
925        }
926    }
927
928    fn index(channel_id: ChannelId) -> usize {
929        channel_id.0.wrapping_sub(1) as usize
930    }
931}
932
/// A vacant slot in `AssignedChannels`, returned by `set` or `allocate`, that
/// can be claimed for an offer with `insert`.
struct AssignmentEntry<'a> {
    /// The table the slot belongs to.
    list: &'a mut AssignedChannels,
    /// Index into `list.assignments` (the channel ID minus one).
    index: usize,
}
937
938impl AssignmentEntry<'_> {
939    pub fn id(&self) -> ChannelId {
940        ChannelId(self.index as u32 + 1)
941    }
942
943    pub fn insert(self, offer_id: OfferId) {
944        assert!(
945            self.list.assignments[self.index]
946                .replace(offer_id)
947                .is_none()
948        );
949
950        if self.index < self.list.reserved_offset {
951            self.list.count_in_reserved_range += 1;
952        }
953    }
954}
955
/// The set of channel offers known to the server, addressed by `OfferId`.
struct ChannelList {
    /// Channel storage; an `OfferId` wraps the slab key of its channel.
    channels: Slab<Channel>,
}
959
960fn channel_inspect_path(offer: &OfferParamsInternal, suffix: std::fmt::Arguments<'_>) -> String {
961    if offer.subchannel_index == 0 {
962        format!("{}{}", offer.instance_id, suffix)
963    } else {
964        format!(
965            "{}/subchannels/{}{}",
966            offer.instance_id, offer.subchannel_index, suffix
967        )
968    }
969}
970
971impl ChannelList {
972    fn inspect(
973        &self,
974        notifier: &impl Notifier,
975        version: Option<VersionInfo>,
976        resp: &mut inspect::Response<'_>,
977    ) {
978        for (offer_id, channel) in self.iter() {
979            resp.child(
980                &channel_inspect_path(&channel.offer, format_args!("")),
981                |req| {
982                    let mut resp = req.respond();
983                    channel.inspect_state(&mut resp);
984
985                    // Merge in the inspection state from outside. Skip this if
986                    // the channel is revoked (and not reoffered) since in that
987                    // case the caller won't recognize the channel ID.
988                    resp.merge(inspect::adhoc(|req| {
989                        if !matches!(channel.state, ChannelState::Revoked) {
990                            notifier.inspect(version, offer_id, req);
991                        }
992                    }));
993                },
994            );
995        }
996    }
997}
998
/// The maximum number of channels, and thus the size of the channel ID table.
///
/// This is limited by the size of the synic event flags bitmap (2048 bits per
/// processor, bit 0 reserved for legacy channel bitmap multiplexing).
pub const MAX_CHANNELS: usize = 2047;
1002
1003impl ChannelList {
1004    fn new() -> Self {
1005        Self {
1006            channels: Slab::new(),
1007        }
1008    }
1009
1010    // The number of channels in the list.
1011    fn len(&self) -> usize {
1012        self.channels.len()
1013    }
1014
1015    /// Inserts a channel.
1016    fn offer(&mut self, new_channel: Channel) -> OfferId {
1017        OfferId(self.channels.insert(new_channel))
1018    }
1019
1020    /// Removes a channel by offer ID.
1021    fn remove(&mut self, offer_id: OfferId) {
1022        let channel = self.channels.remove(offer_id.0);
1023        assert!(channel.info.is_none());
1024    }
1025
1026    /// Gets a channel by guest channel ID.
1027    fn get_by_channel_id_mut(
1028        &mut self,
1029        assigned_channels: &AssignedChannels,
1030        channel_id: ChannelId,
1031    ) -> Result<(OfferId, &mut Channel), ChannelError> {
1032        let offer_id = assigned_channels
1033            .get(channel_id)
1034            .ok_or(ChannelError::UnknownChannelId)?;
1035        let channel = &mut self[offer_id];
1036        if channel.state.is_released() {
1037            return Err(ChannelError::ChannelReleased);
1038        }
1039        assert_eq!(
1040            channel.info.as_ref().map(|info| info.channel_id),
1041            Some(channel_id)
1042        );
1043        Ok((offer_id, channel))
1044    }
1045
1046    /// Gets a channel by guest channel ID.
1047    fn get_by_channel_id(
1048        &self,
1049        assigned_channels: &AssignedChannels,
1050        channel_id: ChannelId,
1051    ) -> Result<(OfferId, &Channel), ChannelError> {
1052        let offer_id = assigned_channels
1053            .get(channel_id)
1054            .ok_or(ChannelError::UnknownChannelId)?;
1055        let channel = &self[offer_id];
1056        if channel.state.is_released() {
1057            return Err(ChannelError::ChannelReleased);
1058        }
1059        assert_eq!(
1060            channel.info.as_ref().map(|info| info.channel_id),
1061            Some(channel_id)
1062        );
1063        Ok((offer_id, channel))
1064    }
1065
1066    /// Gets a channel by offer key (interface ID, instance ID, subchannel
1067    /// index).
1068    fn get_by_key_mut(&mut self, key: &OfferKey) -> Option<(OfferId, &mut Channel)> {
1069        for (offer_id, channel) in self.iter_mut() {
1070            if channel.offer.instance_id == key.instance_id
1071                && channel.offer.interface_id == key.interface_id
1072                && channel.offer.subchannel_index == key.subchannel_index
1073            {
1074                return Some((offer_id, channel));
1075            }
1076        }
1077        None
1078    }
1079
1080    /// Returns an iterator over the channels.
1081    fn iter(&self) -> impl Iterator<Item = (OfferId, &Channel)> {
1082        self.channels
1083            .iter()
1084            .map(|(id, channel)| (OfferId(id), channel))
1085    }
1086
1087    /// Returns an iterator over the channels.
1088    fn iter_mut(&mut self) -> impl Iterator<Item = (OfferId, &mut Channel)> {
1089        self.channels
1090            .iter_mut()
1091            .map(|(id, channel)| (OfferId(id), channel))
1092    }
1093
1094    /// Iterates through the channels, retaining those where `f` returns true.
1095    fn retain<F>(&mut self, mut f: F)
1096    where
1097        F: FnMut(OfferId, &mut Channel) -> bool,
1098    {
1099        self.channels.retain(|id, channel| {
1100            let retain = f(OfferId(id), channel);
1101            if !retain {
1102                assert!(channel.info.is_none());
1103            }
1104            retain
1105        })
1106    }
1107}
1108
1109impl Index<OfferId> for ChannelList {
1110    type Output = Channel;
1111
1112    fn index(&self, offer_id: OfferId) -> &Self::Output {
1113        &self.channels[offer_id.0]
1114    }
1115}
1116
1117impl IndexMut<OfferId> for ChannelList {
1118    fn index_mut(&mut self, offer_id: OfferId) -> &mut Self::Output {
1119        &mut self.channels[offer_id.0]
1120    }
1121}
1122
/// A GPADL (guest physical address descriptor list) registered by the guest.
#[derive(Debug, Inspect)]
struct Gpadl {
    /// The number of ranges described by `buf`.
    count: u16,
    /// The raw range data, accumulated across messages by `Gpadl::append`.
    #[inspect(skip)]
    buf: Vec<u64>,
    /// Where this GPADL is in its offer/accept/teardown lifecycle.
    state: GpadlState,
}
1131
/// The lifecycle state of a `Gpadl`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Inspect)]
enum GpadlState {
    /// The GPADL has not yet been fully sent to the host. `Gpadl::append`
    /// moves it to `Offered` once all of its range data has arrived.
    InProgress,
    /// The GPADL has been sent to the device but is not yet acknowledged.
    Offered,
    /// The device has not acknowledged the GPADL but the GPADL is ready to be
    /// torn down.
    OfferedTearingDown,
    /// The device has acknowledged the GPADL.
    Accepted,
    /// The device has been notified that the GPADL is being torn down.
    TearingDown,
}
1146
impl Gpadl {
    /// Creates a new GPADL with `count` ranges and `len * 8` bytes in the range
    /// buffer.
    ///
    /// The GPADL starts in `GpadlState::InProgress`. `len` (in `u64` units) is
    /// reserved up front, and `append` uses the buffer's capacity to detect
    /// completion.
    fn new(count: u16, len: usize) -> Self {
        Self {
            state: GpadlState::InProgress,
            count,
            buf: Vec::with_capacity(len),
        }
    }

    /// Appends `data` to an in-progress GPADL. Returns whether the GPADL is complete.
    ///
    /// Trailing bytes of `data` that do not form a whole `u64` are ignored, as
    /// is anything beyond the buffer's remaining capacity. When the buffer
    /// becomes full, the range data is validated and the GPADL moves to
    /// `GpadlState::Offered`.
    ///
    /// # Errors
    ///
    /// - `ChannelError::GpadlAlreadyComplete` if the GPADL is not in progress.
    /// - `ChannelError::InvalidGpaRange` if the completed range data fails
    ///   validation. The state then stays `InProgress`.
    fn append(&mut self, data: &[u8]) -> Result<bool, ChannelError> {
        if self.state == GpadlState::InProgress {
            let buf = &mut self.buf;
            // data.len() may be longer than is actually valid since some
            // clients (e.g. UEFI) always pass the maximum message length. In
            // this case, calculate the useful length from the remaining
            // capacity instead.
            //
            // NOTE(review): completion is detected via `buf.capacity()`, which
            // assumes `Vec::with_capacity(len)` in `new` produced a capacity of
            // exactly `len`. The std docs only promise "at least"; consider
            // tracking the expected length explicitly.
            let len = min(data.len() & !7, (buf.capacity() - buf.len()) * 8);
            let data = &data[..len];
            let start = buf.len();
            buf.resize(buf.len() + data.len() / 8, 0);
            // Copy the raw bytes into the newly added u64 slots.
            buf[start..].as_mut_bytes().copy_from_slice(data);
            Ok(if buf.len() == buf.capacity() {
                gparange::MultiPagedRangeBuf::<Vec<u64>>::validate(self.count as usize, buf)
                    .map_err(ChannelError::InvalidGpaRange)?;
                self.state = GpadlState::Offered;
                true
            } else {
                false
            })
        } else {
            Err(ChannelError::GpadlAlreadyComplete)
        }
    }
}
1184
/// The parameters provided by the guest when the channel is being opened.
#[derive(Debug, Copy, Clone)]
pub struct OpenParams {
    /// The open data passed along to the device.
    pub open_data: OpenData,
    /// The connection ID for the channel: the guest-specified one if the open
    /// request supplied interrupt info, otherwise the channel's own connection
    /// ID (see `OpenParams::from_request`).
    pub connection_id: u32,
    /// The event flag for the channel: guest-specified if supplied, otherwise
    /// the channel ID.
    pub event_flag: u16,
    /// Monitor information for the channel, if any.
    pub monitor_info: Option<MonitorInfo>,
    /// The guest's open flags, with the unused bits cleared.
    pub flags: protocol::OpenChannelFlags,
    /// The connection target when the channel is reserved, if any.
    pub reserved_target: Option<ConnectionTarget>,
    /// The guest channel ID assigned to the channel.
    pub channel_id: ChannelId,
}
1196
1197impl OpenParams {
1198    fn from_request(
1199        info: &OfferedInfo,
1200        request: &OpenRequest,
1201        monitor_info: Option<MonitorInfo>,
1202        reserved_target: Option<ConnectionTarget>,
1203    ) -> Self {
1204        // Determine whether to use the alternate IDs.
1205        // N.B. If not specified, the regular IDs are stored as "alternate" in the OpenData.
1206        let (event_flag, connection_id) = if let Some(id) = request.guest_specified_interrupt_info {
1207            (id.event_flag, id.connection_id)
1208        } else {
1209            (info.channel_id.0 as u16, info.connection_id)
1210        };
1211
1212        Self {
1213            open_data: OpenData {
1214                target_vp: request.target_vp,
1215                ring_offset: request.downstream_ring_buffer_page_offset,
1216                ring_gpadl_id: request.ring_buffer_gpadl_id,
1217                user_data: request.user_data,
1218                event_flag,
1219                connection_id,
1220            },
1221            connection_id,
1222            event_flag,
1223            monitor_info,
1224            flags: request.flags.with_unused(0),
1225            reserved_target,
1226            channel_id: info.channel_id,
1227        }
1228    }
1229}
1230
/// A channel action, sent to the device when a channel state changes.
#[derive(Debug)]
pub enum Action {
    /// Open the channel with these parameters and the connection's protocol
    /// version.
    Open(OpenParams, VersionInfo),
    /// Close the channel.
    Close,
    /// Register a GPADL: its ID, range count, and range data.
    Gpadl(GpadlId, u16, Vec<u64>),
    /// Tear down a GPADL.
    TeardownGpadl {
        gpadl_id: GpadlId,
        /// `true` when re-issuing a teardown that was already pending in the
        /// restored state.
        post_restore: bool,
    },
    /// Modify the channel; currently only its target VP.
    Modify {
        target_vp: u32,
    },
}
1245
/// The supported VMBus protocol versions.
///
/// NOTE(review): entries appear to be listed oldest to newest. Confirm whether
/// version negotiation relies on this ordering before reordering them.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::Win7,
    Version::Win8,
    Version::Win8_1,
    Version::Win10,
    Version::Win10Rs3_0,
    Version::Win10Rs3_1,
    Version::Win10Rs4,
    Version::Win10Rs5,
    Version::Iron,
    Version::Copper,
];
1260
/// Feature flags that are always supported.
///
/// N.B. Confidential channels are conditionally supported if running in the
/// paravisor, so they are not part of this set.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_guest_specified_signal_parameters(true)
    .with_channel_interrupt_redirection(true)
    .with_modify_connection(true)
    .with_client_id(true)
    .with_pause_resume(true);
1269
/// Trait for sending requests to devices and the guest.
pub trait Notifier: Send {
    /// Requests a channel action for the channel identified by `offer_id`.
    fn notify(&mut self, offer_id: OfferId, action: Action);

    /// Forward an unhandled InitiateContact request to an external server.
    fn forward_unhandled(&mut self, request: InitiateContactRequest);

    /// Update server state with information from the connection, and optionally notify the relay.
    ///
    /// N.B. If `ModifyConnectionRequest::notify_relay` is true and the function does not return an
    /// error, the server expects `Server::complete_modify_connection()` to be called, regardless of
    /// whether or not there is a relay.
    fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()>;

    /// Inspects a channel.
    ///
    /// The default implementation ignores the request.
    fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
        // Consume the arguments so the default body has no unused parameters.
        let _ = (version, offer_id, req);
    }

    /// Sends a synic message to the guest.
    /// Returns true if the message was sent, and false if it must be retried.
    #[must_use]
    fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool;

    /// Used to signal the hvsocket handler that there is a new connection request.
    fn notify_hvsock(&mut self, request: &HvsockConnectRequest);

    /// Notifies that a requested reset is complete.
    fn reset_complete(&mut self);

    /// Notifies that a guest-requested unload is complete.
    fn unload_complete(&mut self);
}
1304
impl Server {
    /// Creates a new VMBus server.
    ///
    /// Channel IDs handed out automatically start above `channel_id_offset`;
    /// lower IDs are reserved for explicit assignment.
    pub fn new(vtl: Vtl, child_connection_id: u32, channel_id_offset: u16) -> Self {
        Server {
            state: ConnectionState::Disconnected,
            channels: ChannelList::new(),
            assigned_channels: AssignedChannels::new(vtl, channel_id_offset),
            assigned_monitors: AssignedMonitors::new(),
            gpadls: Default::default(),
            incomplete_gpadls: Default::default(),
            child_connection_id,
            max_version: None,
            delayed_max_version: None,
            pending_messages: PendingMessages(VecDeque::new()),
        }
    }

    /// Associates a `Notifier` with the server.
    ///
    /// In debug builds this first checks the channel invariants (see
    /// `validate`).
    pub fn with_notifier<'a, T: Notifier>(
        &'a mut self,
        notifier: &'a mut T,
    ) -> ServerWithNotifier<'a, T> {
        self.validate();
        ServerWithNotifier {
            inner: self,
            notifier,
        }
    }

    /// Debug-only consistency check: every channel must have `info` (an
    /// assigned channel ID) exactly when its state is not released. Panics on
    /// violation. Compiles to nothing in release builds.
    fn validate(&self) {
        #[cfg(debug_assertions)]
        for (_, channel) in self.channels.iter() {
            let should_have_info = !channel.state.is_released();
            if channel.info.is_some() != should_have_info {
                panic!("channel invariant violation: {channel:?}");
            }
        }
    }

    /// Indicates the maximum supported version by the real host in an Underhill relay scenario.
    ///
    /// With `delay` set, the limit is stored in `delayed_max_version` instead
    /// of taking effect immediately as `max_version`.
    pub fn set_compatibility_version(&mut self, version: MaxVersionInfo, delay: bool) {
        if delay {
            self.delayed_max_version = Some(version)
        } else {
            tracing::info!(?version, "Limiting VmBus connections to version");
            self.max_version = Some(version);
        }
    }

    /// Returns the GPADLs of `offer_id` in `RestoredGpadl` form.
    ///
    /// Offered GPADLs (including ones offered and then marked for teardown)
    /// are reported as not accepted, and accepted GPADLs as accepted.
    /// GPADLs still being received or already being torn down are omitted.
    pub fn channel_gpadls(&self, offer_id: OfferId) -> Vec<RestoredGpadl> {
        self.gpadls
            .iter()
            .filter_map(|(&(gpadl_id, gpadl_offer_id), gpadl)| {
                if offer_id != gpadl_offer_id {
                    return None;
                }
                let accepted = match gpadl.state {
                    GpadlState::Offered | GpadlState::OfferedTearingDown => false,
                    GpadlState::Accepted => true,
                    GpadlState::InProgress | GpadlState::TearingDown => return None,
                };
                Some(RestoredGpadl {
                    request: GpadlRequest {
                        id: gpadl_id,
                        count: gpadl.count,
                        buf: gpadl.buf.clone(),
                    },
                    accepted,
                })
            })
            .collect()
    }

    /// Returns the negotiated protocol version, if connected.
    pub fn get_version(&self) -> Option<VersionInfo> {
        self.state.get_version()
    }

    /// Returns the open parameters for `offer_id` during restore, for a device
    /// that expects the channel to be open.
    ///
    /// # Errors
    ///
    /// - `RestoreError::MissingChannel` if the channel was not in the saved
    ///   state, has no assigned ID, or was released/reoffered.
    /// - `RestoreError::AlreadyRestored` if the channel was already restored.
    /// - `RestoreError::MismatchedOpenState` if the saved channel is closed.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) for an unmatched restore state or for channel
    /// states that are not expected after a restore.
    pub fn get_restore_open_params(&self, offer_id: OfferId) -> Result<OpenParams, RestoreError> {
        let channel = &self.channels[offer_id];

        // Check this here to avoid doing unnecessary work.
        match channel.restore_state {
            RestoreState::New => {
                // This channel was never offered, or was released by the guest during the save.
                // This is a problem since if this was called the device expects the channel to be
                // open.
                return Err(RestoreError::MissingChannel(channel.offer.key()));
            }
            RestoreState::Restoring => {}
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        // Pick the open request that applies to the saved state. A channel that
        // was closing has no reserved state to carry over.
        let (request, reserved_state) = match channel.state {
            ChannelState::Closed => {
                return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
            }
            ChannelState::Closing { params, .. } | ChannelState::ClosingReopen { params, .. } => {
                (params, None)
            }
            ChannelState::Opening {
                request,
                reserved_state,
            } => (request, reserved_state),
            ChannelState::Open {
                params,
                reserved_state,
                ..
            } => (params, reserved_state),
            ChannelState::ClientReleased | ChannelState::Reoffered => {
                return Err(RestoreError::MissingChannel(channel.offer.key()));
            }
            ChannelState::Revoked
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        };

        Ok(OpenParams::from_request(
            &info,
            &request,
            channel.handled_monitor_info(),
            reserved_state.map(|state| state.target),
        ))
    }

    /// Check if there are any messages in the pending queue.
    ///
    /// Always false while the connection is paused.
    pub fn has_pending_messages(&self) -> bool {
        !self.pending_messages.0.is_empty() && !self.state.is_paused()
    }

    /// Tries to resend pending messages using the provided `send` function.
    ///
    /// Messages are sent in queue order and removed once `send` returns
    /// `Poll::Ready`. If `send` returns `Poll::Pending`, that message stays
    /// queued and `Poll::Pending` is returned. Nothing is sent while the
    /// connection is paused.
    pub fn poll_flush_pending_messages(
        &mut self,
        mut send: impl FnMut(&OutgoingMessage) -> Poll<()>,
    ) -> Poll<()> {
        if !self.state.is_paused() {
            while let Some(message) = self.pending_messages.0.front() {
                ready!(send(message));
                self.pending_messages.0.pop_front();
            }
        }

        Poll::Ready(())
    }
}
1456
1457impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
    /// Marks a channel as restored.
    ///
    /// If this is not called for a channel but vmbus state is restored, then it
    /// is assumed that the offer is a fresh one, and the channel will be
    /// revoked and reoffered.
    ///
    /// `open` is whether the device considers the channel open. Any pending
    /// open/close transition from the saved state is reconciled with that: the
    /// device is re-notified or the guest is sent the open result.
    ///
    /// # Errors
    ///
    /// - `RestoreError::MissingChannel` if the device expects an open channel
    ///   that is not in the saved state, has no ID, or was released/reoffered.
    /// - `RestoreError::AlreadyRestored` if this channel was already restored.
    /// - `RestoreError::DuplicateMonitorId` if the channel's monitor ID is
    ///   already claimed.
    /// - `RestoreError::MismatchedOpenState` if `open` contradicts the saved
    ///   Closed/Open state.
    ///
    /// NOTE(review): the monitor ID is claimed before the open-state checks.
    /// If one of those checks fails, the claim is not released here.
    pub fn restore_channel(&mut self, offer_id: OfferId, open: bool) -> Result<(), RestoreError> {
        let channel = &mut self.inner.channels[offer_id];

        // We need to check this here as well, because get_restore_open_params may not have been
        // called.
        match channel.restore_state {
            RestoreState::New => {
                // This channel was never offered, or was released by the guest
                // during the save. This is fine as long as the device does not
                // expect the channel to be open.
                if open {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                } else {
                    return Ok(());
                }
            }
            RestoreState::Restoring => {}
            RestoreState::Unmatched => unreachable!(),
            RestoreState::Restored => {
                return Err(RestoreError::AlreadyRestored(channel.offer.key()));
            }
        }

        let info = channel
            .info
            .ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;

        // Re-claim the monitor ID recorded in the saved state so it is not
        // handed out to another channel.
        if let Some(monitor_info) = channel.handled_monitor_info() {
            if !self
                .inner
                .assigned_monitors
                .claim_monitor(monitor_info.monitor_id)
            {
                return Err(RestoreError::DuplicateMonitorId(monitor_info.monitor_id.0));
            }
        }

        if open {
            match channel.state {
                ChannelState::Closed => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                // The guest asked to close but the device still has the channel
                // open, so re-issue the close.
                ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                    self.notifier.notify(offer_id, Action::Close);
                }
                // The device reports the channel open, so complete the pending
                // open: report success to the guest and move to Open.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused())
                        .send_open_result(
                            info.channel_id,
                            &request,
                            protocol::STATUS_SUCCESS,
                            MessageTarget::for_offer(offer_id, &reserved_state),
                        );
                    channel.state = ChannelState::Open {
                        params: request,
                        modify_state: ModifyState::NotModifying,
                        reserved_state,
                    };
                }
                ChannelState::Open { .. } => {}
                ChannelState::ClientReleased | ChannelState::Reoffered => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            };
        } else {
            match channel.state {
                ChannelState::Closed => {}
                // If a channel was reoffered before the save, it was saved as revoked and then
                // restored to reoffered if the device is offering it again. If we reach this state,
                // the device has offered the channel but we are still waiting for the client to
                // release the old revoked channel, so the state must remain reoffered.
                ChannelState::Reoffered => {}
                // The device already finished closing; complete the close.
                ChannelState::Closing { .. } => {
                    channel.state = ChannelState::Closed;
                }
                // The close finished on the device; issue the pending reopen.
                ChannelState::ClosingReopen { request, .. } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                None,
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                    channel.state = ChannelState::Opening {
                        request,
                        reserved_state: None,
                    };
                }
                // The device does not have the channel open; re-send the open
                // request.
                ChannelState::Opening {
                    request,
                    reserved_state,
                } => {
                    self.notifier.notify(
                        offer_id,
                        Action::Open(
                            OpenParams::from_request(
                                &info,
                                &request,
                                channel.handled_monitor_info(),
                                reserved_state.map(|state| state.target),
                            ),
                            self.inner.state.get_version().expect("must be connected"),
                        ),
                    );
                }
                ChannelState::Open { .. } => {
                    return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
                }
                ChannelState::ClientReleased => {
                    return Err(RestoreError::MissingChannel(channel.offer.key()));
                }
                ChannelState::Revoked
                | ChannelState::ClosingClientRelease
                | ChannelState::OpeningClientRelease => unreachable!(),
            }
        }

        channel.restore_state = RestoreState::Restored;
        Ok(())
    }
1596
    /// Revoke and reoffer channels to the guest, depending on their `RestoreState`.
    /// This function should be called after [`ServerWithNotifier::restore`].
    ///
    /// - `Restored`: nothing to do.
    /// - `New`: offered to the guest now if connected and not yet offered.
    /// - `Restoring` (saved but never passed to `restore_channel`): revoked and
    ///   marked `Reoffered`.
    /// - `Unmatched` (saved but never offered again): revoked.
    ///
    /// Afterwards the devices are re-notified about GPADLs that were offered or
    /// being torn down in the saved state.
    pub fn revoke_unclaimed_channels(&mut self) {
        for (offer_id, channel) in self.inner.channels.iter_mut() {
            match channel.restore_state {
                RestoreState::Restored => {
                    // The channel is fully restored. Nothing more to do.
                }
                RestoreState::New => {
                    // This is a fresh channel offer, not in the saved state.
                    // Send the offer to the guest if it has not already been
                    // sent (which could have happened if the channel was
                    // offered after restore() but before revoke_unclaimed_channels()).
                    if let ConnectionState::Connected(info) = &self.inner.state {
                        if matches!(channel.state, ChannelState::ClientReleased) {
                            channel.prepare_channel(
                                offer_id,
                                &mut self.inner.assigned_channels,
                                &mut self.inner.assigned_monitors,
                            );
                            channel.state = ChannelState::Closed;
                            self.inner
                                .pending_messages
                                .sender(self.notifier, self.inner.state.is_paused())
                                .send_offer(channel, info.version);
                        }
                    }
                }
                RestoreState::Restoring => {
                    // restore_channel was never called for this, but it was in
                    // the saved state. This indicates the offer is meant to be
                    // fresh, so revoke and reoffer it.
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                    channel.state = ChannelState::Reoffered;
                }
                RestoreState::Unmatched => {
                    // offer_channel was never called for this, but it was in
                    // the saved state. Revoke it.
                    let retain = revoke(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                    );
                    assert!(retain, "channel has not been released");
                }
            }
        }

        // Notify the channels for any GPADLs in progress.
        for (&(gpadl_id, offer_id), gpadl) in self.inner.gpadls.iter_mut() {
            match gpadl.state {
                GpadlState::InProgress | GpadlState::Accepted => {}
                GpadlState::Offered => {
                    self.notifier.notify(
                        offer_id,
                        Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
                    );
                }
                GpadlState::TearingDown => {
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id,
                            post_restore: true,
                        },
                    );
                }
                // Assumed not to exist at this point; the code treats it as
                // impossible.
                GpadlState::OfferedTearingDown => unreachable!(),
            }
        }

        self.check_disconnected();
    }
1681
1682    /// Initiates a state reset and a closing of all channels.
1683    ///
1684    /// Only one reset is allowed at a time, and no calls to
1685    /// `handle_synic_message` are allowed during a reset operation.
1686    pub fn reset(&mut self) {
1687        assert!(!self.is_resetting());
1688        if self.request_disconnect(ConnectionAction::Reset) {
1689            self.complete_reset();
1690        }
1691    }
1692
1693    fn complete_reset(&mut self) {
1694        // Reset the restore state since everything is now in a clean state.
1695        for (_, channel) in self.inner.channels.iter_mut() {
1696            channel.restore_state = RestoreState::New;
1697        }
1698        self.inner.pending_messages.0.clear();
1699        self.notifier.reset_complete();
1700    }
1701
1702    /// Creates a new channel, returning its channel ID.
1703    pub fn offer_channel(&mut self, offer: OfferParamsInternal) -> Result<OfferId, OfferError> {
1704        // Ensure no channel with this interface and instance ID exists.
1705        if let Some((offer_id, channel)) = self.inner.channels.get_by_key_mut(&offer.key()) {
1706            // Replace the current offer if this is an unmatched restored
1707            // channel, or if this matching offer has been revoked by the host
1708            // but not yet released by the guest.
1709            if channel.restore_state != RestoreState::Unmatched
1710                && !matches!(channel.state, ChannelState::Revoked)
1711            {
1712                return Err(OfferError::AlreadyExists(offer.key()));
1713            }
1714
1715            let info = channel.info.expect("assigned");
1716            if channel.restore_state == RestoreState::Unmatched {
1717                tracing::debug!(
1718                    offer_id = offer_id.0,
1719                    key = %channel.offer.key(),
1720                    "matched channel"
1721                );
1722
1723                assert!(!matches!(channel.state, ChannelState::Revoked));
1724                // This channel was previously offered to the guest in the saved
1725                // state. Match this back up to handle future calls to
1726                // restore_channel and revoke_unclaimed_channels.
1727                channel.restore_state = RestoreState::Restoring;
1728
1729                // The relay can specify a host-determined monitor ID, which needs to match what's
1730                // in the saved state.
1731                if let MnfUsage::Relayed { monitor_id } = offer.use_mnf {
1732                    if info.monitor_id != Some(MonitorId(monitor_id)) {
1733                        return Err(OfferError::MismatchedMonitorId(
1734                            info.monitor_id,
1735                            MonitorId(monitor_id),
1736                        ));
1737                    }
1738                }
1739            } else {
1740                // The channel has been revoked but the guest still has a
1741                // reference to it. Save the offer for reoffering immediately
1742                // after the child releases it.
1743                channel.state = ChannelState::Reoffered;
1744                tracing::info!(?offer_id, key = %channel.offer.key(), "channel marked for reoffer");
1745            }
1746
1747            channel.offer = offer;
1748            return Ok(offer_id);
1749        }
1750
1751        let mut connected_version = None;
1752        let state = match self.inner.state {
1753            ConnectionState::Connected(ConnectionInfo {
1754                offers_sent: true,
1755                version,
1756                ..
1757            }) => {
1758                connected_version = Some(version);
1759                ChannelState::Closed
1760            }
1761            ConnectionState::Connected(ConnectionInfo {
1762                offers_sent: false, ..
1763            })
1764            | ConnectionState::Connecting { .. }
1765            | ConnectionState::Disconnecting { .. }
1766            | ConnectionState::Disconnected => ChannelState::ClientReleased,
1767        };
1768
1769        // Ensure there will be enough channel IDs for this channel.
1770        if self.inner.channels.len() >= self.inner.assigned_channels.allowable_channel_count() {
1771            return Err(OfferError::TooManyChannels);
1772        }
1773
1774        let key = offer.key();
1775        let confidential_ring_buffer = offer.flags.confidential_ring_buffer();
1776        let confidential_external_memory = offer.flags.confidential_external_memory();
1777        let channel = Channel {
1778            info: None,
1779            offer,
1780            state,
1781            restore_state: RestoreState::New,
1782        };
1783
1784        let offer_id = self.inner.channels.offer(channel);
1785        if let Some(version) = connected_version {
1786            let channel = &mut self.inner.channels[offer_id];
1787            channel.prepare_channel(
1788                offer_id,
1789                &mut self.inner.assigned_channels,
1790                &mut self.inner.assigned_monitors,
1791            );
1792
1793            self.inner
1794                .pending_messages
1795                .sender(self.notifier, self.inner.state.is_paused())
1796                .send_offer(channel, version);
1797        }
1798
1799        tracing::info!(?offer_id, %key, confidential_ring_buffer, confidential_external_memory, "new channel");
1800        Ok(offer_id)
1801    }
1802
1803    /// Revokes a channel by ID.
1804    pub fn revoke_channel(&mut self, offer_id: OfferId) {
1805        let channel = &mut self.inner.channels[offer_id];
1806        let retain = revoke(
1807            self.inner
1808                .pending_messages
1809                .sender(self.notifier, self.inner.state.is_paused()),
1810            offer_id,
1811            channel,
1812            &mut self.inner.gpadls,
1813        );
1814        if !retain {
1815            self.inner.channels.remove(offer_id);
1816        }
1817
1818        self.check_disconnected();
1819    }
1820
1821    /// Completes an open operation with `result`.
1822    pub fn open_complete(&mut self, offer_id: OfferId, result: i32) {
1823        tracing::debug!(offer_id = offer_id.0, result, "open complete");
1824
1825        let channel = &mut self.inner.channels[offer_id];
1826        match channel.state {
1827            ChannelState::Opening {
1828                request,
1829                reserved_state,
1830            } => {
1831                let channel_id = channel.info.expect("assigned").channel_id;
1832                if result >= 0 {
1833                    tracelimit::info_ratelimited!(
1834                        offer_id = offer_id.0,
1835                        channel_id = channel_id.0,
1836                        result,
1837                        "opened channel"
1838                    );
1839                } else {
1840                    // Log channel open failures at error level for visibility.
1841                    tracelimit::error_ratelimited!(
1842                        offer_id = offer_id.0,
1843                        channel_id = channel_id.0,
1844                        result,
1845                        "failed to open channel"
1846                    );
1847                }
1848
1849                self.inner
1850                    .pending_messages
1851                    .sender(self.notifier, self.inner.state.is_paused())
1852                    .send_open_result(
1853                        channel_id,
1854                        &request,
1855                        result,
1856                        MessageTarget::for_offer(offer_id, &reserved_state),
1857                    );
1858                channel.state = if result >= 0 {
1859                    ChannelState::Open {
1860                        params: request,
1861                        modify_state: ModifyState::NotModifying,
1862                        reserved_state,
1863                    }
1864                } else {
1865                    ChannelState::Closed
1866                };
1867            }
1868            ChannelState::OpeningClientRelease => {
1869                tracing::info!(
1870                    offer_id = offer_id.0,
1871                    result,
1872                    "opened channel (client released)"
1873                );
1874
1875                if result >= 0 {
1876                    channel.state = ChannelState::ClosingClientRelease;
1877                    self.notifier.notify(offer_id, Action::Close);
1878                } else {
1879                    channel.state = ChannelState::ClientReleased;
1880                    self.check_disconnected();
1881                }
1882            }
1883
1884            ChannelState::ClientReleased
1885            | ChannelState::Closed
1886            | ChannelState::Open { .. }
1887            | ChannelState::Closing { .. }
1888            | ChannelState::ClosingReopen { .. }
1889            | ChannelState::Revoked
1890            | ChannelState::Reoffered
1891            | ChannelState::ClosingClientRelease => {
1892                tracing::error!(?offer_id, state = ?channel.state, "invalid open complete")
1893            }
1894        }
1895    }
1896
1897    /// If true, all channels are in a reset state, with no references by the
1898    /// guest. Reserved channels should only be included if the VM is resetting.
1899    fn are_channels_reset(&self, include_reserved: bool) -> bool {
1900        self.inner.gpadls.keys().all(|(_, offer_id)| {
1901            !include_reserved && self.inner.channels[*offer_id].state.is_reserved()
1902        }) && self.inner.channels.iter().all(|(_, channel)| {
1903            matches!(channel.state, ChannelState::ClientReleased)
1904                || (!include_reserved && channel.state.is_reserved())
1905        })
1906    }
1907
1908    /// Checks if the connection state is fully disconnected and advances the
1909    /// connection state machine. Must be called any time a GPADL is deleted or
1910    /// a channel enters the ClientReleased state.
1911    fn check_disconnected(&mut self) {
1912        match self.inner.state {
1913            ConnectionState::Disconnecting {
1914                next_action,
1915                modify_sent: false,
1916            } => {
1917                if self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)) {
1918                    self.inner.state = ConnectionState::Disconnecting {
1919                        next_action,
1920                        modify_sent: true,
1921                    };
1922
1923                    // Reset server state and disconnect the relay if there is one.
1924                    self.notifier
1925                        .modify_connection(ModifyConnectionRequest {
1926                            monitor_page: Update::Reset,
1927                            interrupt_page: Update::Reset,
1928                            ..Default::default()
1929                        })
1930                        .expect("resetting state should not fail");
1931                }
1932            }
1933            ConnectionState::Disconnecting {
1934                modify_sent: true, ..
1935            }
1936            | ConnectionState::Disconnected
1937            | ConnectionState::Connected { .. }
1938            | ConnectionState::Connecting { .. } => (),
1939        }
1940    }
1941
1942    /// If true, the server is mid-reset and cannot take certain actions such
1943    /// as handling synic messages or saving state.
1944    fn is_resetting(&self) -> bool {
1945        matches!(
1946            &self.inner.state,
1947            ConnectionState::Connecting {
1948                next_action: ConnectionAction::Reset,
1949                ..
1950            } | ConnectionState::Disconnecting {
1951                next_action: ConnectionAction::Reset,
1952                ..
1953            }
1954        )
1955    }
1956
    /// Completes a channel close operation.
    ///
    /// The channel's next state depends on the kind of close that was pending:
    /// - `Closing` of a reserved channel:
    ///   - connection `Connected`: the channel becomes `Closed` and a
    ///     `CloseReservedChannelResponse` is sent to the reserved channel's
    ///     target;
    ///   - any other connection state: the channel is released and removed
    ///     from the channel table if `client_release_channel` says so.
    /// - Any other `Closing`: the channel becomes `Closed`.
    /// - `ClosingClientRelease`: the channel becomes `ClientReleased` and a
    ///   pending disconnect is re-evaluated.
    /// - `ClosingReopen`: the channel becomes `Closed` and is immediately
    ///   reopened with the saved open request.
    ///
    /// Any other state is logged as an error.
    pub fn close_complete(&mut self, offer_id: OfferId) {
        let channel = &mut self.inner.channels[offer_id];
        tracing::info!(offer_id = offer_id.0, "closed channel");
        match channel.state {
            ChannelState::Closing {
                reserved_state: Some(reserved_state),
                ..
            } => {
                channel.state = ChannelState::Closed;
                if matches!(self.inner.state, ConnectionState::Connected { .. }) {
                    let channel_id = channel.info.expect("assigned").channel_id;
                    self.send_close_reserved_channel_response(
                        channel_id,
                        offer_id,
                        reserved_state.target,
                    );
                } else {
                    // Handle closing reserved channels while disconnected/ing (or connecting).
                    // Since we weren't waiting on the channel, no need to call check_disconnected,
                    // but we do need to release it.
                    // A true return from client_release_channel means the channel should be
                    // dropped from the table (the same convention request_disconnect uses).
                    if Self::client_release_channel(
                        self.inner
                            .pending_messages
                            .sender(self.notifier, self.inner.state.is_paused()),
                        offer_id,
                        channel,
                        &mut self.inner.gpadls,
                        &mut self.inner.assigned_channels,
                        &mut self.inner.assigned_monitors,
                        None,
                    ) {
                        self.inner.channels.remove(offer_id);
                    }
                }
            }
            ChannelState::Closing { .. } => {
                channel.state = ChannelState::Closed;
            }
            ChannelState::ClosingClientRelease => {
                channel.state = ChannelState::ClientReleased;
                // A channel just became ClientReleased; a disconnect may now be able to proceed.
                self.check_disconnected();
            }
            ChannelState::ClosingReopen { request, .. } => {
                channel.state = ChannelState::Closed;
                // The close was only the first half of a reopen; reopen with the saved request.
                self.open_channel(offer_id, &request, None);
            }

            ChannelState::Closed
            | ChannelState::ClientReleased
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::OpeningClientRelease => {
                tracing::error!(?offer_id, state = ?channel.state, "invalid close complete")
            }
        }
    }
2015
2016    fn send_close_reserved_channel_response(
2017        &mut self,
2018        channel_id: ChannelId,
2019        offer_id: OfferId,
2020        target: ConnectionTarget,
2021    ) {
2022        self.sender().send_message_with_target(
2023            &protocol::CloseReservedChannelResponse { channel_id },
2024            MessageTarget::ReservedChannel(offer_id, target),
2025        );
2026    }
2027
2028    /// Handles MessageType::INITIATE_CONTACT, which requests version
2029    /// negotiation.
2030    fn handle_initiate_contact(
2031        &mut self,
2032        input: &protocol::InitiateContact2,
2033        message: &SynicMessage,
2034        includes_client_id: bool,
2035    ) -> Result<(), ChannelError> {
2036        let target_info =
2037            protocol::TargetInfo::from(input.initiate_contact.interrupt_page_or_target_info);
2038
2039        let target_sint = if message.multiclient
2040            && input.initiate_contact.version_requested >= Version::Win10Rs3_1 as u32
2041        {
2042            target_info.sint()
2043        } else {
2044            SINT
2045        };
2046
2047        let target_vtl = if message.multiclient
2048            && input.initiate_contact.version_requested >= Version::Win10Rs4 as u32
2049        {
2050            target_info.vtl()
2051        } else {
2052            0
2053        };
2054
2055        let feature_flags = if input.initiate_contact.version_requested >= Version::Copper as u32 {
2056            target_info.feature_flags()
2057        } else {
2058            0
2059        };
2060
2061        // Originally, messages were always sent to processor zero.
2062        // Post-Windows 8, it became necessary to send messages to other
2063        // processors in order to support establishing channel connections
2064        // on arbitrary processors after crashing.
2065        let target_message_vp =
2066            if input.initiate_contact.version_requested >= Version::Win8_1 as u32 {
2067                input.initiate_contact.target_message_vp
2068            } else {
2069                0
2070            };
2071
2072        // Guests can send an interrupt page up to protocol Win10Rs3_1 (at which point the
2073        // interrupt page field was reused), but as of Win8 the host can ignore it as it won't be
2074        // used for channels with dedicated interrupts (which is all channels).
2075        //
2076        // V1 doesn't support dedicated interrupts and Win7 only uses dedicated interrupts for
2077        // guest-to-host, so the interrupt page is still used for host-to-guest.
2078        let interrupt_page = (input.initiate_contact.version_requested < Version::Win8 as u32
2079            && input.initiate_contact.interrupt_page_or_target_info != 0)
2080            .then_some(input.initiate_contact.interrupt_page_or_target_info);
2081
2082        // The guest must specify both monitor pages, or neither. Store this information in the
2083        // request so the response can be sent after the version check, and to the correct VTL.
2084        let monitor_page = if (input.initiate_contact.parent_to_child_monitor_page_gpa == 0)
2085            != (input.initiate_contact.child_to_parent_monitor_page_gpa == 0)
2086        {
2087            MonitorPageRequest::Invalid
2088        } else if input.initiate_contact.parent_to_child_monitor_page_gpa != 0 {
2089            MonitorPageRequest::Some(MonitorPageGpas {
2090                parent_to_child: input.initiate_contact.parent_to_child_monitor_page_gpa,
2091                child_to_parent: input.initiate_contact.child_to_parent_monitor_page_gpa,
2092            })
2093        } else {
2094            MonitorPageRequest::None
2095        };
2096
2097        // We differentiate between InitiateContact and InitiateContact2 only by size, so we need to
2098        // check the feature flags here to ensure the client ID should actually be set to the input GUID.
2099        let client_id = if FeatureFlags::from(feature_flags).client_id() {
2100            if includes_client_id {
2101                input.client_id
2102            } else {
2103                return Err(ChannelError::ParseError(
2104                    protocol::ParseError::MessageTooSmall(Some(
2105                        protocol::MessageType::INITIATE_CONTACT,
2106                    )),
2107                ));
2108            }
2109        } else {
2110            Guid::ZERO
2111        };
2112
2113        let request = InitiateContactRequest {
2114            version_requested: input.initiate_contact.version_requested,
2115            target_message_vp,
2116            monitor_page,
2117            target_sint,
2118            target_vtl,
2119            feature_flags,
2120            interrupt_page,
2121            client_id,
2122            trusted: message.trusted,
2123        };
2124        self.initiate_contact(request);
2125        Ok(())
2126    }
2127
    /// Processes a decoded `InitiateContact` request from the guest.
    ///
    /// Processing stops at the first step that fails:
    /// - A request for a different VTL is forwarded via
    ///   `notifier.forward_unhandled`.
    /// - A request for a SINT other than `SINT` gets an unsupported-version
    ///   response on the requested VP/SINT.
    /// - Any existing connection is torn down. If that cannot finish
    ///   immediately, the request is queued as the pending `Reconnect` action.
    /// - The version is checked, and an unsupported version or an invalid
    ///   monitor page request is answered directly, without notifying the
    ///   relay.
    ///
    /// If every step succeeds, the connection enters `Connecting` and a
    /// `ModifyConnectionRequest` is issued. `complete_initiate_contact` is
    /// invoked once that finishes.
    pub fn initiate_contact(&mut self, request: InitiateContactRequest) {
        // If the request is not for this server's VTL, inform the notifier it wasn't handled so it
        // can be forwarded to the correct server.
        let vtl = self.inner.assigned_channels.vtl as u8;
        if request.target_vtl != vtl {
            // Send a notification to a linked server (which handles a different VTL).
            self.notifier.forward_unhandled(request);
            return;
        }

        if request.target_sint != SINT {
            tracelimit::warn_ratelimited!(
                "unsupported multiclient request for VTL {} SINT {}, version {:#x}",
                request.target_vtl,
                request.target_sint,
                request.version_requested,
            );

            // Send an unsupported response to the requested SINT.
            self.send_version_response_with_target(
                None,
                MessageTarget::Custom(ConnectionTarget {
                    vp: request.target_message_vp,
                    sint: request.target_sint,
                }),
            );

            return;
        }

        // Tear down any existing connection first. If channels are still in use, the request is
        // stored as the pending Reconnect action and replayed later through do_next_action.
        if !self.request_disconnect(ConnectionAction::Reconnect {
            initiate_contact: request,
        }) {
            return;
        }

        let Some(version) = self.check_version_supported(&request) else {
            tracelimit::warn_ratelimited!(
                vtl,
                version = request.version_requested,
                client_id = ?request.client_id,
                "Guest requested unsupported version"
            );

            // Do not notify the relay in this case.
            self.send_version_response(None);
            return;
        };

        tracelimit::info_ratelimited!(
            vtl,
            ?version,
            client_id = ?request.client_id,
            trusted = request.trusted,
            "Guest negotiated version"
        );

        // Make sure we can receive incoming interrupts on the monitor page. The parent to child
        // page is not used as this server doesn't send monitored interrupts.
        let monitor_page = match request.monitor_page {
            MonitorPageRequest::Some(mp) => Some(mp),
            MonitorPageRequest::None => None,
            MonitorPageRequest::Invalid => {
                // Do not notify the relay in this case.
                self.send_version_response(Some((
                    version,
                    protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
                )));

                return;
            }
        };

        // Record the negotiated parameters; offers are not sent until later (offers_sent: false).
        self.inner.state = ConnectionState::Connecting {
            info: ConnectionInfo {
                version,
                trusted: request.trusted,
                interrupt_page: request.interrupt_page,
                monitor_page,
                target_message_vp: request.target_message_vp,
                modifying: false,
                offers_sent: false,
                client_id: request.client_id,
                paused: false,
            },
            next_action: ConnectionAction::None,
        };

        // Update server state and notify the relay, if any. When complete,
        // complete_initiate_contact will be invoked.
        if let Err(err) = self.notifier.modify_connection(ModifyConnectionRequest {
            version: Some(request.version_requested),
            monitor_page: monitor_page.into(),
            interrupt_page: request.interrupt_page.into(),
            target_message_vp: Some(request.target_message_vp),
            notify_relay: true,
        }) {
            // The request could not even be issued: fall back to Disconnected and report failure.
            tracelimit::error_ratelimited!(?err, "server failed to change state");
            self.inner.state = ConnectionState::Disconnected;
            self.send_version_response(Some((
                version,
                protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
            )));
        }
    }
2233
2234    pub(crate) fn complete_initiate_contact(&mut self, response: ModifyConnectionResponse) {
2235        let ConnectionState::Connecting {
2236            mut info,
2237            next_action,
2238        } = self.inner.state
2239        else {
2240            panic!("Invalid state for completing InitiateContact.");
2241        };
2242
2243        // Some features are handled locally without needing relay support.
2244        const LOCAL_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
2245            .with_client_id(true)
2246            .with_confidential_channels(true);
2247
2248        let relay_feature_flags = match response {
2249            // There is no relay, or it successfully processed our request.
2250            ModifyConnectionResponse::Supported(
2251                protocol::ConnectionState::SUCCESSFUL,
2252                feature_flags,
2253            ) => feature_flags,
2254            // The relay supports the requested version, but encountered an error, so pass it
2255            // along to the guest.
2256            ModifyConnectionResponse::Supported(connection_state, feature_flags) => {
2257                tracelimit::error_ratelimited!(
2258                    ?connection_state,
2259                    "initiate contact failed because relay request failed"
2260                );
2261
2262                // We still report the supported feature flags with an error, so make sure those
2263                // are correct.
2264                info.version.feature_flags &= feature_flags | LOCAL_FEATURE_FLAGS;
2265
2266                self.send_version_response(Some((info.version, connection_state)));
2267                self.inner.state = ConnectionState::Disconnected;
2268                return;
2269            }
2270            // The relay doesn't support the requested version, so tell the guest to negotiate a new
2271            // one.
2272            ModifyConnectionResponse::Unsupported => {
2273                self.send_version_response(None);
2274                self.inner.state = ConnectionState::Disconnected;
2275                return;
2276            }
2277        };
2278
2279        // The relay responds with all the feature flags it supports, so limit the flags reported to
2280        // the guest to include only those handled by the relay or locally.
2281        info.version.feature_flags &= relay_feature_flags | LOCAL_FEATURE_FLAGS;
2282        self.inner.state = ConnectionState::Connected(info);
2283
2284        self.send_version_response(Some((info.version, protocol::ConnectionState::SUCCESSFUL)));
2285        if !matches!(next_action, ConnectionAction::None) && self.request_disconnect(next_action) {
2286            self.do_next_action(next_action);
2287        }
2288    }
2289
2290    /// Determine if a guest's requested version and feature flags are supported.
2291    fn check_version_supported(&self, request: &InitiateContactRequest) -> Option<VersionInfo> {
2292        let version = SUPPORTED_VERSIONS
2293            .iter()
2294            .find(|v| request.version_requested == **v as u32)
2295            .copied()?;
2296
2297        // The max version may be limited in order to test older protocol versions.
2298        if let Some(max_version) = self.inner.max_version {
2299            if version as u32 > max_version.version {
2300                return None;
2301            }
2302        }
2303
2304        let supported_flags = if version >= Version::Copper {
2305            // Confidential channels should only be enabled if the connection is trusted.
2306            let max_supported_flags =
2307                SUPPORTED_FEATURE_FLAGS.with_confidential_channels(request.trusted);
2308
2309            // The max features may be limited in order to test older protocol versions.
2310            if let Some(max_version) = self.inner.max_version {
2311                max_supported_flags & max_version.feature_flags
2312            } else {
2313                max_supported_flags
2314            }
2315        } else {
2316            FeatureFlags::new()
2317        };
2318
2319        let feature_flags = supported_flags & request.feature_flags.into();
2320
2321        assert!(version >= Version::Copper || feature_flags == FeatureFlags::new());
2322        if feature_flags.into_bits() != request.feature_flags {
2323            tracelimit::warn_ratelimited!(
2324                supported = feature_flags.into_bits(),
2325                requested = request.feature_flags,
2326                "Guest requested unsupported feature flags."
2327            );
2328        }
2329
2330        Some(VersionInfo {
2331            version,
2332            feature_flags,
2333        })
2334    }
2335
2336    fn send_version_response(&mut self, data: Option<(VersionInfo, protocol::ConnectionState)>) {
2337        self.send_version_response_with_target(data, MessageTarget::Default);
2338    }
2339
2340    fn send_version_response_with_target(
2341        &mut self,
2342        data: Option<(VersionInfo, protocol::ConnectionState)>,
2343        target: MessageTarget,
2344    ) {
2345        let mut response2 = protocol::VersionResponse2::new_zeroed();
2346        let response = &mut response2.version_response;
2347        let mut send_response2 = false;
2348        if let Some((version, state)) = data {
2349            // Pre-Win8, there is no way to report failures to the guest, so those should be treated
2350            // as unsupported.
2351            if state == protocol::ConnectionState::SUCCESSFUL || version.version >= Version::Win8 {
2352                response.version_supported = 1;
2353                response.connection_state = state;
2354                response.selected_version_or_connection_id =
2355                    if version.version >= Version::Win10Rs3_1 {
2356                        self.inner.child_connection_id
2357                    } else {
2358                        version.version as u32
2359                    };
2360
2361                if version.version >= Version::Copper {
2362                    response2.supported_features = version.feature_flags.into();
2363                    send_response2 = true;
2364                }
2365            }
2366        }
2367
2368        if send_response2 {
2369            self.sender().send_message_with_target(&response2, target);
2370        } else {
2371            self.sender().send_message_with_target(response, target);
2372        }
2373    }
2374
    /// Disconnects the guest in preparation for `new_action`, releasing every
    /// channel. Reserved channels are released only when `new_action` is a VM
    /// reset.
    ///
    /// Returns true if the connection is now fully `Disconnected`, in which
    /// case the caller should carry out `new_action` itself.
    ///
    /// Returns false otherwise, for example when channels or GPADLs are still
    /// in use, or a connect or disconnect is already in flight. In that case
    /// `new_action` is recorded as the pending `next_action` and performed
    /// once the pending operation completes.
    ///
    /// # Panics
    ///
    /// Panics in two cases:
    /// - a reset is already in progress;
    /// - the connection is already `Disconnected`, this is not a VM reset, and
    ///   some non-reserved channel or GPADL is still in use.
    fn request_disconnect(&mut self, new_action: ConnectionAction) -> bool {
        assert!(!self.is_resetting());

        // Release all channels.
        // (Borrowed separately so the retain closure below can use it alongside self.inner.channels.)
        let gpadls = &mut self.inner.gpadls;
        let vm_reset = matches!(new_action, ConnectionAction::Reset);
        self.inner.channels.retain(|offer_id, channel| {
            // Release reserved channels only if the VM is resetting
            // A channel is dropped from the table when client_release_channel returns true.
            (!vm_reset && channel.state.is_reserved())
                || !Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    None,
                )
        });

        // Transition to disconnected or one of the pending disconnect states,
        // depending on whether there are still GPADLs or channels in use by the
        // server.
        match &mut self.inner.state {
            ConnectionState::Disconnected => {
                // Cleanup open reserved channels when doing disconnected VM reset
                if vm_reset {
                    if !self.are_channels_reset(true) {
                        self.inner.state = ConnectionState::Disconnecting {
                            next_action: ConnectionAction::Reset,
                            modify_sent: false,
                        };
                    }
                } else {
                    assert!(self.are_channels_reset(false));
                }
            }

            ConnectionState::Connected { .. } => {
                if self.are_channels_reset(vm_reset) {
                    self.inner.state = ConnectionState::Disconnected;
                } else {
                    self.inner.state = ConnectionState::Disconnecting {
                        next_action: new_action,
                        modify_sent: false,
                    };
                }
            }

            // Already mid-transition: just replace whatever action was queued.
            ConnectionState::Connecting { next_action, .. }
            | ConnectionState::Disconnecting { next_action, .. } => {
                *next_action = new_action;
            }
        }

        matches!(self.inner.state, ConnectionState::Disconnected)
    }
2436
2437    pub(crate) fn complete_disconnect(&mut self) {
2438        if let ConnectionState::Disconnecting {
2439            next_action,
2440            modify_sent,
2441        } = std::mem::replace(&mut self.inner.state, ConnectionState::Disconnected)
2442        {
2443            assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)));
2444            if !modify_sent {
2445                tracelimit::warn_ratelimited!("unexpected modify response");
2446            }
2447
2448            self.inner.state = ConnectionState::Disconnected;
2449            self.do_next_action(next_action);
2450        } else {
2451            unreachable!("not ready for disconnect");
2452        }
2453    }
2454
2455    fn do_next_action(&mut self, action: ConnectionAction) {
2456        match action {
2457            ConnectionAction::None => {}
2458            ConnectionAction::Reset => {
2459                self.complete_reset();
2460            }
2461            ConnectionAction::SendUnloadComplete => {
2462                self.complete_unload();
2463            }
2464            ConnectionAction::Reconnect { initiate_contact } => {
2465                self.initiate_contact(initiate_contact);
2466            }
2467            ConnectionAction::SendFailedVersionResponse => {
2468                // Used when the relay didn't support the requested version, so send a failed
2469                // response.
2470                self.send_version_response(None);
2471            }
2472        }
2473    }
2474
2475    /// Handles MessageType::UNLOAD, which disconnects the guest.
2476    fn handle_unload(&mut self) {
2477        tracing::debug!(
2478            vtl = self.inner.assigned_channels.vtl as u8,
2479            state = ?self.inner.state,
2480            "VmBus received unload request from guest",
2481        );
2482
2483        if self.request_disconnect(ConnectionAction::SendUnloadComplete) {
2484            self.complete_unload();
2485        }
2486    }
2487
2488    fn complete_unload(&mut self) {
2489        self.notifier.unload_complete();
2490        if let Some(version) = self.inner.delayed_max_version.take() {
2491            self.inner.set_compatibility_version(version, false);
2492        }
2493
2494        self.sender().send_message(&protocol::UnloadComplete {});
2495        tracelimit::info_ratelimited!("Vmbus disconnected");
2496    }
2497
    /// Handles MessageType::REQUEST_OFFERS, which requests a list of channel offers.
    ///
    /// Each non-reserved channel (which must be `ClientReleased` with no
    /// assigned info) is prepared, moved to `Closed`, and offered to the guest
    /// in a deterministic order. An `AllOffersDelivered` message follows the
    /// offers.
    ///
    /// # Errors
    ///
    /// Returns `OffersAlreadySent` if offers were already sent on this
    /// connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection is not `Connected`; message parsing is expected
    /// to prevent that.
    fn handle_request_offers(&mut self) -> Result<(), ChannelError> {
        let ConnectionState::Connected(info) = &mut self.inner.state else {
            unreachable!(
                "in unexpected state {:?}, should be prevented by Message::parse()",
                self.inner.state
            );
        };

        if info.offers_sent {
            return Err(ChannelError::OffersAlreadySent);
        }

        info.offers_sent = true;

        // The guest expects channel IDs to stay consistent across hibernation and
        // resume, so sort the current offers before assigning channel IDs.
        let mut sorted_channels: Vec<_> = self
            .inner
            .channels
            .iter_mut()
            .filter(|(_, channel)| !channel.state.is_reserved())
            .collect();

        // Order by interface ID, then offer order (channels with no explicit
        // order sort last within an interface), then instance ID.
        sorted_channels.sort_unstable_by_key(|(_, channel)| {
            (
                channel.offer.interface_id,
                channel.offer.offer_order.unwrap_or(u32::MAX),
                channel.offer.instance_id,
            )
        });

        for (offer_id, channel) in sorted_channels {
            assert!(matches!(channel.state, ChannelState::ClientReleased));
            assert!(channel.info.is_none());

            // Presumably assigns the channel ID (and a monitor ID, if any) from
            // the shared allocators; see `prepare_channel`.
            channel.prepare_channel(
                offer_id,
                &mut self.inner.assigned_channels,
                &mut self.inner.assigned_monitors,
            );

            channel.state = ChannelState::Closed;
            // `info` is still borrowed from the connection state here, so the
            // paused flag is read from it rather than from `self.inner.state`.
            self.inner
                .pending_messages
                .sender(self.notifier, info.paused)
                .send_offer(channel, info.version);
        }
        self.sender().send_message(&protocol::AllOffersDelivered {});

        Ok(())
    }
2550
2551    /// Sends a GPADL to the device when `ranges` is Some. Returns false if the
2552    /// GPADL should be removed because the channel is already revoked.
2553    #[must_use]
2554    fn gpadl_updated(
2555        mut sender: MessageSender<'_, N>,
2556        offer_id: OfferId,
2557        channel: &Channel,
2558        gpadl_id: GpadlId,
2559        gpadl: &Gpadl,
2560    ) -> bool {
2561        if channel.state.is_revoked() {
2562            let channel_id = channel.info.as_ref().expect("assigned").channel_id;
2563            sender.send_gpadl_created(channel_id, gpadl_id, protocol::STATUS_UNSUCCESSFUL);
2564            false
2565        } else {
2566            // Notify the channel if the GPADL is done.
2567            sender.notifier.notify(
2568                offer_id,
2569                Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
2570            );
2571            true
2572        }
2573    }
2574
2575    /// Handles MessageType::GPADL_HEADER, which creates a new GPADL.
2576    fn handle_gpadl_header(
2577        &mut self,
2578        input: &protocol::GpadlHeader,
2579        range: &[u8],
2580    ) -> Result<(), ChannelError> {
2581        // Validate the channel ID.
2582        let (offer_id, channel) = self
2583            .inner
2584            .channels
2585            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
2586
2587        // GPADL body messages don't contain the channel ID, so prevent creating new
2588        // GPADLs for reserved channels to avoid GPADL ID conflicts.
2589        if channel.state.is_reserved() {
2590            return Err(ChannelError::ChannelReserved);
2591        }
2592
2593        // Create a new GPADL.
2594        let mut gpadl = Gpadl::new(input.count, input.len as usize / 8);
2595        let done = gpadl.append(range)?;
2596
2597        // Store the GPADL in the table.
2598        let gpadl = match self.inner.gpadls.entry((input.gpadl_id, offer_id)) {
2599            Entry::Vacant(entry) => entry.insert(gpadl),
2600            Entry::Occupied(_) => return Err(ChannelError::DuplicateGpadlId),
2601        };
2602
2603        // If we're not done, track the offer ID for GPADL body requests
2604        if !done
2605            && self
2606                .inner
2607                .incomplete_gpadls
2608                .insert(input.gpadl_id, offer_id)
2609                .is_some()
2610        {
2611            unreachable!("gpadl ID validated above");
2612        }
2613
2614        if done
2615            && !Self::gpadl_updated(
2616                self.inner
2617                    .pending_messages
2618                    .sender(self.notifier, self.inner.state.is_paused()),
2619                offer_id,
2620                channel,
2621                input.gpadl_id,
2622                gpadl,
2623            )
2624        {
2625            self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
2626        }
2627        Ok(())
2628    }
2629
    /// Handles MessageType::GPADL_BODY, which adds more to an in-progress
    /// GPADL.
    ///
    /// Body messages carry no channel ID, so the owning offer is found through
    /// `incomplete_gpadls` using the GPADL ID alone. When `Gpadl::append`
    /// reports that the GPADL is complete, it is removed from
    /// `incomplete_gpadls` and handed to `gpadl_updated`. If that returns false
    /// (the channel was revoked), the GPADL is dropped from the table as well.
    ///
    /// # Errors
    ///
    /// Returns `UnknownGpadlId` if no in-progress GPADL has this ID, or the
    /// error from `Gpadl::append`.
    fn handle_gpadl_body(
        &mut self,
        input: &protocol::GpadlBody,
        range: &[u8],
    ) -> Result<(), ChannelError> {
        // Find and update the GPADL.
        let &offer_id = self
            .inner
            .incomplete_gpadls
            .get(&input.gpadl_id)
            .ok_or(ChannelError::UnknownGpadlId)?;
        // NOTE(review): the GPADL may already be gone from `gpadls` (for example,
        // `client_release_channel` discards `InProgress` GPADLs). In that case
        // this returns `UnknownGpadlId` but leaves the `incomplete_gpadls` entry
        // in place. Confirm that stale entries are cleaned up elsewhere.
        let gpadl = self
            .inner
            .gpadls
            .get_mut(&(input.gpadl_id, offer_id))
            .ok_or(ChannelError::UnknownGpadlId)?;
        // Presumably the channel exists whenever its GPADL is still in the
        // table; this index panics otherwise.
        let channel = &mut self.inner.channels[offer_id];

        if gpadl.append(range)? {
            // All ranges received: stop routing body messages to this GPADL.
            self.inner.incomplete_gpadls.remove(&input.gpadl_id);
            if !Self::gpadl_updated(
                self.inner
                    .pending_messages
                    .sender(self.notifier, self.inner.state.is_paused()),
                offer_id,
                channel,
                input.gpadl_id,
                gpadl,
            ) {
                self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
            }
        }

        Ok(())
    }
2667
    /// Handles MessageType::GPADL_TEARDOWN, which tears down a GPADL.
    ///
    /// Only an `Accepted` GPADL can be torn down. If its channel has been
    /// revoked, the GPADL is removed and the teardown is acknowledged to the
    /// guest immediately. Otherwise the GPADL moves to `TearingDown` and the
    /// device is asked to tear it down.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the channel ID or GPADL ID is unknown;
    /// - the GPADL is not `Accepted`;
    /// - the channel's assigned channel ID does not match the request;
    /// - the channel is reserved.
    fn handle_gpadl_teardown(
        &mut self,
        input: &protocol::GpadlTeardown,
    ) -> Result<(), ChannelError> {
        tracing::debug!(
            channel_id = input.channel_id.0,
            gpadl_id = input.gpadl_id.0,
            "Received GPADL teardown request"
        );

        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        let gpadl = self
            .inner
            .gpadls
            .get_mut(&(input.gpadl_id, offer_id))
            .ok_or(ChannelError::UnknownGpadlId)?;

        match gpadl.state {
            // Only accepted GPADLs may be torn down; any other state, including
            // a teardown already in progress, is rejected.
            GpadlState::InProgress
            | GpadlState::Offered
            | GpadlState::OfferedTearingDown
            | GpadlState::TearingDown => {
                return Err(ChannelError::InvalidGpadlState);
            }
            GpadlState::Accepted => {
                if channel.info.as_ref().map(|info| info.channel_id) != Some(input.channel_id) {
                    return Err(ChannelError::WrongGpadlChannelId);
                }

                // GPADL IDs must be unique during teardown. Disallow reserved
                // channels to avoid collisions with non-reserved channel GPADL
                // IDs across disconnects.
                if channel.state.is_reserved() {
                    return Err(ChannelError::ChannelReserved);
                }

                if channel.state.is_revoked() {
                    tracing::trace!(
                        channel_id = input.channel_id.0,
                        gpadl_id = input.gpadl_id.0,
                        "Gpadl teardown for revoked channel"
                    );

                    // No device teardown is requested for a revoked channel;
                    // drop the GPADL and acknowledge the guest right away.
                    self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
                    self.sender().send_gpadl_torndown(input.gpadl_id);
                } else {
                    // Presumably the device completes the teardown later and
                    // the guest is acknowledged at that point.
                    gpadl.state = GpadlState::TearingDown;
                    self.notifier.notify(
                        offer_id,
                        Action::TeardownGpadl {
                            gpadl_id: input.gpadl_id,
                            post_restore: false,
                        },
                    );
                }
            }
        }
        Ok(())
    }
2732
    /// Moves a channel from the `Closed` to `Opening` state, notifying the
    /// device.
    ///
    /// `reserved_state` is `Some` when the channel is being opened as a
    /// reserved channel. It is recorded in the `Opening` state, and its target
    /// is forwarded in the open parameters.
    ///
    /// # Panics
    ///
    /// Panics if the channel is not `Closed`, has no assigned info, or the
    /// connection has no negotiated version.
    fn open_channel(
        &mut self,
        offer_id: OfferId,
        input: &OpenRequest,
        reserved_state: Option<ReservedState>,
    ) {
        let channel = &mut self.inner.channels[offer_id];
        assert!(matches!(channel.state, ChannelState::Closed));

        channel.state = ChannelState::Opening {
            request: *input,
            reserved_state,
        };

        // Do not update info with the guest-provided connection ID, since the
        // value must be remembered if the channel is closed and re-opened.
        let info = channel.info.as_ref().expect("assigned");
        self.notifier.notify(
            offer_id,
            Action::Open(
                OpenParams::from_request(
                    info,
                    input,
                    channel.handled_monitor_info(),
                    reserved_state.map(|state| state.target),
                ),
                self.inner.state.get_version().expect("must be connected"),
            ),
        );
    }
2765
    /// Handles MessageType::OPEN_CHANNEL, which opens a channel.
    ///
    /// Guest-specified signal parameters and interrupt-redirection flags are
    /// honored only if the matching feature flags were negotiated. Behavior by
    /// channel state:
    /// - `Closed`: the channel is opened.
    /// - `Closing`: the request is queued as a reopen.
    /// - revoked or reoffered: the request is ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID is unknown. Returns
    /// `ChannelAlreadyOpen` if the channel is open, opening, or already queued
    /// for a reopen.
    fn handle_open_channel(&mut self, input: &protocol::OpenChannel2) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.open_channel.channel_id)?;

        // Use the guest's event flag and connection ID only if the guest
        // negotiated support for specifying them.
        let guest_specified_interrupt_info = self
            .inner
            .state
            .check_feature_flags(|ff| ff.guest_specified_signal_parameters())
            .then_some(SignalInfo {
                event_flag: input.event_flag,
                connection_id: input.connection_id,
            });

        // Likewise, the open flags are ignored unless interrupt redirection was
        // negotiated.
        let flags = if self
            .inner
            .state
            .check_feature_flags(|ff| ff.channel_interrupt_redirection())
        {
            input.flags
        } else {
            Default::default()
        };

        let request = OpenRequest {
            open_id: input.open_channel.open_id,
            ring_buffer_gpadl_id: input.open_channel.ring_buffer_gpadl_id,
            target_vp: input.open_channel.target_vp,
            downstream_ring_buffer_page_offset: input
                .open_channel
                .downstream_ring_buffer_page_offset,
            user_data: input.open_channel.user_data,
            guest_specified_interrupt_info,
            flags,
        };

        match channel.state {
            ChannelState::Closed => self.open_channel(offer_id, &request, None),
            ChannelState::Closing { params, .. } => {
                // Since there is no close complete message, this can happen
                // after the ring buffer GPADL is released but before the server
                // completes the close request.
                channel.state = ChannelState::ClosingReopen { params, request }
            }
            // Ignore the request for revoked or reoffered channels.
            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Open { .. }
            | ChannelState::Opening { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelAlreadyOpen),

            // Presumably released channels cannot be found by channel ID, so
            // these states are never seen here.
            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }
        Ok(())
    }
2824
    /// Handles MessageType::CLOSE_CHANNEL, which closes a channel.
    ///
    /// Only a non-reserved open channel is closed this way: it moves to
    /// `Closing` and the device is notified. Requests for revoked or reoffered
    /// channels are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID is unknown. Returns `ChannelReserved`
    /// for an open reserved channel (which is closed through
    /// CLOSE_RESERVED_CHANNEL instead). Returns `ChannelNotOpen` if the channel
    /// is not open.
    fn handle_close_channel(&mut self, input: &protocol::CloseChannel) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        match channel.state {
            ChannelState::Open {
                params,
                modify_state,
                reserved_state: None,
            } => {
                // A pending modify is abandoned when the channel closes; just
                // note it.
                if modify_state.is_modifying() {
                    tracelimit::warn_ratelimited!(
                        ?modify_state,
                        "Client is closing the channel with a modify in progress"
                    )
                }

                channel.state = ChannelState::Closing {
                    params,
                    reserved_state: None,
                };
                self.notifier.notify(offer_id, Action::Close);
            }

            ChannelState::Open {
                reserved_state: Some(_),
                ..
            } => return Err(ChannelError::ChannelReserved),

            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

            // Presumably released channels cannot be found by channel ID.
            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }

        Ok(())
    }
2871
    /// Handles MessageType::OPEN_RESERVED_CHANNEL, which reserves and opens a channel.
    /// The version must have already been validated in parse_message.
    ///
    /// The channel is opened with interrupts disabled. `version` and the
    /// requested VP/SINT target are recorded as its reserved state. Requests
    /// for revoked or reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID is unknown. Returns
    /// `ChannelAlreadyOpen` if the channel is open or opening. Returns
    /// `InvalidChannelState` if it is closing.
    fn handle_open_reserved_channel(
        &mut self,
        input: &protocol::OpenReservedChannel,
        version: VersionInfo,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        // NOTE(review): if `target_sint` is wider than u8, `as u8` silently
        // truncates; confirm it is range-checked before reaching here.
        let target = ConnectionTarget {
            vp: input.target_vp,
            sint: input.target_sint as u8,
        };

        let reserved_state = Some(ReservedState { version, target });

        let request = OpenRequest {
            ring_buffer_gpadl_id: input.ring_buffer_gpadl,
            // Interrupts are disabled for reserved channels; this matches Hyper-V behavior.
            target_vp: protocol::VP_INDEX_DISABLE_INTERRUPT,
            downstream_ring_buffer_page_offset: input.downstream_page_offset,
            open_id: 0,
            user_data: UserDefinedData::new_zeroed(),
            guest_specified_interrupt_info: None,
            flags: Default::default(),
        };

        match channel.state {
            ChannelState::Closed => self.open_channel(offer_id, &request, reserved_state),
            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Open { .. } | ChannelState::Opening { .. } => {
                return Err(ChannelError::ChannelAlreadyOpen);
            }

            // Unlike a normal open, a reserved open is not queued behind a
            // pending close.
            ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                return Err(ChannelError::InvalidChannelState);
            }

            // Presumably released channels cannot be found by channel ID.
            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }
        Ok(())
    }
2920
    /// Handles MessageType::CLOSE_RESERVED_CHANNEL, which closes a reserved channel. Will send
    /// the response to the target provided in the request instead of the current reserved target.
    ///
    /// An open reserved channel moves to `Closing` with its reserved target
    /// replaced by the request's VP/SINT, and the device is notified. Any
    /// modify state the open channel had is not carried over. Requests for
    /// revoked or reoffered channels are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID is unknown. Returns
    /// `ChannelNotReserved` if the channel is open but not reserved. Returns
    /// `ChannelNotOpen` if it is not open.
    fn handle_close_reserved_channel(
        &mut self,
        input: &protocol::CloseReservedChannel,
    ) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;

        match channel.state {
            ChannelState::Open {
                params,
                reserved_state: Some(mut resvd),
                ..
            } => {
                // Redirect the close response to the target named in this
                // request.
                resvd.target.vp = input.target_vp;
                resvd.target.sint = input.target_sint as u8;
                channel.state = ChannelState::Closing {
                    params,
                    reserved_state: Some(resvd),
                };
                self.notifier.notify(offer_id, Action::Close);
            }

            ChannelState::Open {
                reserved_state: None,
                ..
            } => return Err(ChannelError::ChannelNotReserved),

            ChannelState::Revoked | ChannelState::Reoffered => {}

            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),

            // Presumably released channels cannot be found by channel ID.
            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => unreachable!(),
        }

        Ok(())
    }
2966
    /// Release all guest references on a channel, including GPADLs that are
    /// associated with the channel. Returns true if the channel should be
    /// deleted.
    ///
    /// GPADLs owned by `offer_id` are handled according to their state:
    /// - `InProgress`: dropped.
    /// - `Offered`: kept and marked `OfferedTearingDown`.
    /// - `Accepted`: dropped if the channel is revoked; otherwise marked
    ///   `TearingDown`, and a teardown is requested from the device.
    /// - already tearing down: kept unchanged.
    ///
    /// The channel then moves to a released state, and an open channel is also
    /// closed through the device. Only a `Revoked` channel is reported for
    /// deletion. If the channel was `Reoffered` and `version` is `Some`, it is
    /// instead offered to the guest again and keeps its channel ID, and false
    /// is returned.
    #[must_use]
    fn client_release_channel(
        mut sender: MessageSender<'_, N>,
        offer_id: OfferId,
        channel: &mut Channel,
        gpadls: &mut GpadlMap,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
        version: Option<VersionInfo>,
    ) -> bool {
        // Release any GPADLs that remain for this channel.
        gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
            if gpadl_offer_id != offer_id {
                return true;
            }
            match gpadl.state {
                GpadlState::InProgress => false,
                GpadlState::Offered => {
                    gpadl.state = GpadlState::OfferedTearingDown;
                    true
                }
                GpadlState::Accepted => {
                    if channel.state.is_revoked() {
                        // There is no need to tear down the GPADL.
                        false
                    } else {
                        gpadl.state = GpadlState::TearingDown;
                        sender.notifier.notify(
                            offer_id,
                            Action::TeardownGpadl {
                                gpadl_id,
                                post_restore: false,
                            },
                        );
                        true
                    }
                }
                GpadlState::OfferedTearingDown | GpadlState::TearingDown => true,
            }
        });

        let remove = match &mut channel.state {
            ChannelState::Closed => {
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Reoffered => {
                if let Some(version) = version {
                    channel.state = ChannelState::Closed;
                    channel.restore_state = RestoreState::New;
                    sender.send_offer(channel, version);
                    // Do not release the channel ID.
                    // (The early return also skips the `is_released` assertion
                    // and `release_channel` below.)
                    return false;
                }
                channel.state = ChannelState::ClientReleased;
                false
            }
            ChannelState::Revoked => {
                channel.state = ChannelState::ClientReleased;
                true
            }
            ChannelState::Opening { .. } => {
                channel.state = ChannelState::OpeningClientRelease;
                false
            }
            ChannelState::Open { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                sender.notifier.notify(offer_id, Action::Close);
                false
            }
            // A close is already in flight (or queued for reopen); drop any
            // pending reopen and wait for the close to finish.
            ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
                channel.state = ChannelState::ClosingClientRelease;
                false
            }

            ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease
            | ChannelState::ClientReleased => false,
        };

        assert!(channel.state.is_released());

        channel.release_channel(offer_id, assigned_channels, assigned_monitors);
        remove
    }
3055
    /// Handles MessageType::REL_ID_RELEASED, which releases the guest references to a channel.
    ///
    /// Allowed only while the channel is `Closed`, `Revoked`, `Closing`, or
    /// `Reoffered`. The channel is released via `client_release_channel` using
    /// the connection's current version, so a reoffered channel can be offered
    /// again. It is removed from the table if that function says so, and then
    /// `check_disconnected` runs.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID is unknown. Returns
    /// `InvalidChannelState` if the channel is opening, open, or queued for a
    /// reopen.
    fn handle_rel_id_released(
        &mut self,
        input: &protocol::RelIdReleased,
    ) -> Result<(), ChannelError> {
        let channel_id = input.channel_id;
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, channel_id)?;

        match channel.state {
            ChannelState::Closed
            | ChannelState::Revoked
            | ChannelState::Closing { .. }
            | ChannelState::Reoffered => {
                if Self::client_release_channel(
                    self.inner
                        .pending_messages
                        .sender(self.notifier, self.inner.state.is_paused()),
                    offer_id,
                    channel,
                    &mut self.inner.gpadls,
                    &mut self.inner.assigned_channels,
                    &mut self.inner.assigned_monitors,
                    self.inner.state.get_version(),
                ) {
                    self.inner.channels.remove(offer_id);
                }

                self.check_disconnected();
            }

            ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::ClosingReopen { .. } => return Err(ChannelError::InvalidChannelState),

            // Presumably released channels cannot be found by channel ID.
            ChannelState::ClientReleased
            | ChannelState::OpeningClientRelease
            | ChannelState::ClosingClientRelease => unreachable!(),
        }
        Ok(())
    }
3099
3100    /// Handles MessageType::TL_CONNECT_REQUEST, which requests for an hvsocket
3101    /// connection.
3102    fn handle_tl_connect_request(&mut self, request: protocol::TlConnectRequest2) {
3103        let version = self
3104            .inner
3105            .state
3106            .get_version()
3107            .expect("must be connected")
3108            .version;
3109
3110        let hosted_silo_unaware = version < Version::Win10Rs5;
3111        self.notifier
3112            .notify_hvsock(&HvsockConnectRequest::from_message(
3113                request,
3114                hosted_silo_unaware,
3115            ));
3116    }
3117
3118    /// Sends a message to the guest if an hvsocket connect request failed.
3119    pub fn send_tl_connect_result(&mut self, result: HvsockConnectResult) {
3120        // TODO: need save/restore handling for this... probably OK to just drop
3121        // all such requests given hvsock's general lack of save/restore
3122        // support.
3123        if !result.success && self.inner.state.check_version(Version::Win10Rs3_0) {
3124            // Windows guests care about the error code used here; using STATUS_CONNECTION_REFUSED
3125            // ensures a sensible error gets returned to the user that tried to connect to the
3126            // socket.
3127            self.sender().send_message(&protocol::TlConnectResult {
3128                service_id: result.service_id,
3129                endpoint_id: result.endpoint_id,
3130                status: protocol::STATUS_CONNECTION_REFUSED,
3131            })
3132        }
3133    }
3134
3135    /// Handles MessageType::MODIFY_CHANNEL, which allows the guest to request a
3136    /// new target VP for the channel's interrupts.
3137    fn handle_modify_channel(
3138        &mut self,
3139        request: &protocol::ModifyChannel,
3140    ) -> Result<(), ChannelError> {
3141        let result = self.modify_channel(request);
3142        if result.is_err() {
3143            self.send_modify_channel_response(request.channel_id, protocol::STATUS_UNSUCCESSFUL);
3144        }
3145
3146        result
3147    }
3148
    /// Modifies a channel's target VP.
    ///
    /// Only a non-reserved `Open` channel can be modified.
    ///
    /// If no modify is in progress, three things happen: the device is
    /// notified, the stored open parameters are updated with the new target VP,
    /// and the channel enters `Modifying`.
    ///
    /// If a modify is already in progress, the outcome depends on the client:
    /// - Iron-or-later clients get only a warning, and the new request is
    ///   dropped.
    /// - Older clients have the target VP queued, to be applied when the
    ///   current modify completes.
    ///
    /// # Errors
    ///
    /// Returns an error if the channel ID is unknown. Returns
    /// `InvalidChannelState` if the channel is not an open, non-reserved
    /// channel.
    fn modify_channel(&mut self, request: &protocol::ModifyChannel) -> Result<(), ChannelError> {
        let (offer_id, channel) = self
            .inner
            .channels
            .get_by_channel_id_mut(&self.inner.assigned_channels, request.channel_id)?;

        let (open_request, modify_state) = match &mut channel.state {
            ChannelState::Open {
                params,
                modify_state,
                reserved_state: None,
            } => (params, modify_state),
            _ => return Err(ChannelError::InvalidChannelState),
        };

        if let ModifyState::Modifying { pending_target_vp } = modify_state {
            if self.inner.state.check_version(Version::Iron) {
                // On Iron or later, the client isn't allowed to send a ModifyChannel
                // request while another one is still in progress.
                tracelimit::warn_ratelimited!(
                    "Client sent new ModifyChannel before receiving ModifyChannelResponse."
                );
            } else {
                // On older versions, the client doesn't know if the operation is complete,
                // so store the latest request to execute when the current one completes.
                *pending_target_vp = Some(request.target_vp);
            }
        } else {
            self.notifier.notify(
                offer_id,
                Action::Modify {
                    target_vp: request.target_vp,
                },
            );

            // Update the stored open_request so that save/restore will use the new value.
            open_request.target_vp = request.target_vp;
            *modify_state = ModifyState::Modifying {
                pending_target_vp: None,
            };
        }

        Ok(())
    }
3194
    /// Complete the ModifyChannel message.
    ///
    /// `status` is forwarded to the guest in the ModifyChannelResponse (where
    /// the protocol supports it). If an older client queued another target VP
    /// while this modify was in progress, that request is started next.
    /// Nothing happens unless the channel is a non-reserved open channel with
    /// a modify in progress.
    ///
    /// N.B. The guest expects no further interrupts on the old VP at this point. This
    ///      is guaranteed because notify() updates the event port synchronously before
    ///      notifying the device/relay, and all types of event port protect their VP
    ///      settings with locks.
    pub fn modify_channel_complete(&mut self, offer_id: OfferId, status: i32) {
        let channel = &mut self.inner.channels[offer_id];

        if let ChannelState::Open {
            params,
            modify_state: ModifyState::Modifying { pending_target_vp },
            reserved_state: None,
        } = channel.state
        {
            channel.state = ChannelState::Open {
                params,
                modify_state: ModifyState::NotModifying,
                reserved_state: None,
            };

            // Send the ModifyChannelResponse message if the protocol supports it.
            let channel_id = channel.info.as_ref().expect("assigned").channel_id;
            self.send_modify_channel_response(channel_id, status);

            // Handle a pending ModifyChannel request if there is one.
            if let Some(target_vp) = pending_target_vp {
                let request = protocol::ModifyChannel {
                    channel_id,
                    target_vp,
                };

                // The channel is `NotModifying` again, so this starts a fresh
                // modify rather than being queued.
                if let Err(error) = self.handle_modify_channel(&request) {
                    tracelimit::warn_ratelimited!(?error, "Pending ModifyChannel request failed.")
                }
            }
        }
    }
3233
3234    fn send_modify_channel_response(&mut self, channel_id: ChannelId, status: i32) {
3235        if self.inner.state.check_version(Version::Iron) {
3236            self.sender()
3237                .send_message(&protocol::ModifyChannelResponse { channel_id, status });
3238        }
3239    }
3240
3241    fn handle_modify_connection(&mut self, request: protocol::ModifyConnection) {
3242        if let Err(err) = self.modify_connection(request) {
3243            tracelimit::error_ratelimited!(?err, "modifying connection failed");
3244            self.complete_modify_connection(ModifyConnectionResponse::Supported(
3245                protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
3246                FeatureFlags::new(),
3247            ));
3248        }
3249    }
3250
    /// Validates a guest ModifyConnection request and forwards it to the notifier.
    ///
    /// Returns an error, without changing any state, in three cases: the connection is not
    /// `Connected`, a previous ModifyConnection is still in progress, or exactly one of the two
    /// monitor page GPAs is zero.
    ///
    /// Otherwise it marks the connection as modifying and records the requested monitor pages
    /// (`None` when both GPAs are zero). It then passes the request to the notifier. An error
    /// from the notifier is returned after that state has already been updated. The guest's
    /// response is sent by `complete_modify_connection`.
    fn modify_connection(&mut self, request: protocol::ModifyConnection) -> anyhow::Result<()> {
        let ConnectionState::Connected(info) = &mut self.inner.state else {
            anyhow::bail!(
                "Invalid state for ModifyConnection request: {:?}",
                self.inner.state
            );
        };

        if info.modifying {
            anyhow::bail!(
                "Duplicate ModifyConnection request, state: {:?}",
                self.inner.state
            );
        }

        // The two monitor pages are either both provided or both omitted (zero).
        if (request.child_to_parent_monitor_page_gpa == 0)
            != (request.parent_to_child_monitor_page_gpa == 0)
        {
            anyhow::bail!("Guest must specify either both or no monitor pages, {request:?}");
        }

        let monitor_page =
            (request.child_to_parent_monitor_page_gpa != 0).then_some(MonitorPageGpas {
                child_to_parent: request.child_to_parent_monitor_page_gpa,
                parent_to_child: request.parent_to_child_monitor_page_gpa,
            });

        // All validation has passed. From here on the request counts as in flight, even if the
        // notifier call below fails.
        info.modifying = true;
        info.monitor_page = monitor_page;
        tracing::debug!("modifying connection parameters.");
        self.notifier.modify_connection(request.into())?;

        Ok(())
    }
3285
    /// Handles the response to a ModifyConnection request previously sent to the relay.
    ///
    /// Dispatches on the current connection state:
    /// - `Connecting`: the response completes the InitiateContact.
    /// - `Disconnecting`: it completes the disconnect.
    /// - `Connected`: it finishes a guest ModifyConnection by clearing `modifying` and sending
    ///   the guest a ModifyConnectionResponse with the returned connection state.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `Connected` with a response other than `Supported`.
    /// - `Connected` while no modification is in progress.
    /// - Any other connection state.
    pub fn complete_modify_connection(&mut self, response: ModifyConnectionResponse) {
        tracing::debug!(?response, "modifying connection parameters complete");

        // InitiateContact, Unload, and actual ModifyConnection messages are all sent to the relay
        // as ModifyConnection requests, so use the server state to determine how to handle the
        // response.
        match &mut self.inner.state {
            ConnectionState::Connecting { .. } => self.complete_initiate_contact(response),
            ConnectionState::Disconnecting { .. } => self.complete_disconnect(),
            ConnectionState::Connected(info) => {
                let ModifyConnectionResponse::Supported(connection_state, ..) = response else {
                    panic!(
                        "Relay should not return {:?} for a modify request with no version.",
                        response
                    );
                };

                if !info.modifying {
                    panic!(
                        "ModifyConnection response while not modifying, state: {:?}",
                        self.inner.state
                    );
                }

                info.modifying = false;
                self.sender()
                    .send_message(&protocol::ModifyConnectionResponse { connection_state });
            }
            _ => panic!(
                "Invalid state for ModifyConnection response: {:?}",
                self.inner.state
            ),
        }
    }
3320
3321    fn handle_pause(&mut self) {
3322        tracelimit::info_ratelimited!("pausing sending messages");
3323        self.sender().send_message(&protocol::PauseResponse {});
3324        let ConnectionState::Connected(info) = &mut self.inner.state else {
3325            unreachable!(
3326                "in unexpected state {:?}, should be prevented by Message::parse()",
3327                self.inner.state
3328            );
3329        };
3330        info.paused = true;
3331    }
3332
    /// Processes an incoming message from the guest.
    ///
    /// The message is parsed using the currently negotiated version, if any, and dispatched to
    /// the matching handler.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the message fails to parse;
    /// - an untrusted message arrives on a connection that was established with a trusted
    ///   message;
    /// - a message other than Resume, Unload, InitiateContact or InitiateContact2 arrives while
    ///   responses are paused;
    /// - the selected handler fails.
    ///
    /// # Panics
    ///
    /// Panics if called while the server is resetting. Hits `unreachable!` for message types
    /// that only a vmbus client should receive.
    pub fn handle_synic_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
        assert!(!self.is_resetting());

        let version = self.inner.state.get_version();
        let msg = Message::parse(&message.data, version)?;
        tracing::trace!(?msg, message.trusted, "received vmbus message");
        // Do not allow untrusted messages if the connection was established
        // using a trusted message.
        //
        // TODO: Don't allow trusted messages if an untrusted connection was ever used.
        if self.inner.state.is_trusted() && !message.trusted {
            tracelimit::warn_ratelimited!(?msg, "Received untrusted message");
            return Err(ChannelError::UntrustedMessage);
        }

        // Unpause channel responses if they are paused. While paused, only messages that
        // resume, reset, or re-establish the connection are accepted.
        match &mut self.inner.state {
            ConnectionState::Connected(info) if info.paused => {
                if !matches!(
                    msg,
                    Message::Resume(..)
                        | Message::Unload(..)
                        | Message::InitiateContact { .. }
                        | Message::InitiateContact2 { .. }
                ) {
                    tracelimit::warn_ratelimited!(?msg, "Received message while paused");
                    return Err(ChannelError::Paused);
                }
                tracelimit::info_ratelimited!("resuming sending messages");
                info.paused = false;
            }
            _ => {}
        }

        match msg {
            Message::InitiateContact2(input, ..) => {
                self.handle_initiate_contact(&input, &message, true)?
            }
            Message::InitiateContact(input, ..) => {
                self.handle_initiate_contact(&input.into(), &message, false)?
            }
            Message::Unload(..) => self.handle_unload(),
            Message::RequestOffers(..) => self.handle_request_offers()?,
            Message::GpadlHeader(input, range) => self.handle_gpadl_header(&input, range)?,
            Message::GpadlBody(input, range) => self.handle_gpadl_body(&input, range)?,
            Message::GpadlTeardown(input, ..) => self.handle_gpadl_teardown(&input)?,
            Message::OpenChannel(input, ..) => self.handle_open_channel(&input.into())?,
            Message::OpenChannel2(input, ..) => self.handle_open_channel(&input)?,
            Message::CloseChannel(input, ..) => self.handle_close_channel(&input)?,
            Message::RelIdReleased(input, ..) => self.handle_rel_id_released(&input)?,
            Message::TlConnectRequest(input, ..) => self.handle_tl_connect_request(input.into()),
            Message::TlConnectRequest2(input, ..) => self.handle_tl_connect_request(input),
            Message::ModifyChannel(input, ..) => self.handle_modify_channel(&input)?,
            Message::ModifyConnection(input, ..) => self.handle_modify_connection(input),
            Message::OpenReservedChannel(input, ..) => self.handle_open_reserved_channel(
                &input,
                version.expect("version validated by Message::parse"),
            )?,
            Message::CloseReservedChannel(input, ..) => {
                self.handle_close_reserved_channel(&input)?
            }
            Message::Pause(protocol::Pause, ..) => self.handle_pause(),
            // Unpausing, if needed, was already done above.
            Message::Resume(protocol::Resume, ..) => {}
            // Messages that should only be received by a vmbus client.
            Message::OfferChannel(..)
            | Message::RescindChannelOffer(..)
            | Message::AllOffersDelivered(..)
            | Message::OpenResult(..)
            | Message::GpadlCreated(..)
            | Message::GpadlTorndown(..)
            | Message::VersionResponse(..)
            | Message::VersionResponse2(..)
            | Message::UnloadComplete(..)
            | Message::CloseReservedChannelResponse(..)
            | Message::TlConnectResult(..)
            | Message::ModifyChannelResponse(..)
            | Message::ModifyConnectionResponse(..)
            | Message::PauseResponse(..) => {
                unreachable!("Server received client message {:?}", msg);
            }
        }
        Ok(())
    }
3417
3418    fn get_gpadl(
3419        gpadls: &mut GpadlMap,
3420        offer_id: OfferId,
3421        gpadl_id: GpadlId,
3422    ) -> Option<&mut Gpadl> {
3423        let gpadl = gpadls.get_mut(&(gpadl_id, offer_id));
3424        if gpadl.is_none() {
3425            tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, "invalid gpadl ID for channel");
3426        }
3427        gpadl
3428    }
3429
3430    /// Completes a GPADL creation, accepting it if `status >= 0`, rejecting it otherwise.
3431    pub fn gpadl_create_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId, status: i32) {
3432        let gpadl = if let Some(gpadl) = Self::get_gpadl(&mut self.inner.gpadls, offer_id, gpadl_id)
3433        {
3434            gpadl
3435        } else {
3436            return;
3437        };
3438        let retain = match gpadl.state {
3439            GpadlState::InProgress | GpadlState::TearingDown | GpadlState::Accepted => {
3440                tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
3441                return;
3442            }
3443            GpadlState::Offered => {
3444                let channel_id = self.inner.channels[offer_id]
3445                    .info
3446                    .as_ref()
3447                    .expect("assigned")
3448                    .channel_id;
3449                self.inner
3450                    .pending_messages
3451                    .sender(self.notifier, self.inner.state.is_paused())
3452                    .send_gpadl_created(channel_id, gpadl_id, status);
3453                if status >= 0 {
3454                    gpadl.state = GpadlState::Accepted;
3455                    true
3456                } else {
3457                    false
3458                }
3459            }
3460            GpadlState::OfferedTearingDown => {
3461                if status >= 0 {
3462                    // Tear down the GPADL immediately.
3463                    self.notifier.notify(
3464                        offer_id,
3465                        Action::TeardownGpadl {
3466                            gpadl_id,
3467                            post_restore: false,
3468                        },
3469                    );
3470                    gpadl.state = GpadlState::TearingDown;
3471                    true
3472                } else {
3473                    false
3474                }
3475            }
3476        };
3477        if !retain {
3478            self.inner
3479                .gpadls
3480                .remove(&(gpadl_id, offer_id))
3481                .expect("gpadl validated above");
3482
3483            self.check_disconnected();
3484        }
3485    }
3486
3487    /// Releases a GPADL that is being torn down.
3488    pub fn gpadl_teardown_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
3489        tracing::debug!(
3490            offer_id = offer_id.0,
3491            gpadl_id = gpadl_id.0,
3492            "Gpadl teardown complete"
3493        );
3494
3495        let gpadl = if let Some(gpadl) = Self::get_gpadl(&mut self.inner.gpadls, offer_id, gpadl_id)
3496        {
3497            gpadl
3498        } else {
3499            return;
3500        };
3501        let channel = &mut self.inner.channels[offer_id];
3502        match gpadl.state {
3503            GpadlState::InProgress
3504            | GpadlState::Offered
3505            | GpadlState::OfferedTearingDown
3506            | GpadlState::Accepted => {
3507                tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
3508            }
3509            GpadlState::TearingDown => {
3510                if !channel.state.is_released() {
3511                    self.sender().send_gpadl_torndown(gpadl_id);
3512                }
3513                self.inner
3514                    .gpadls
3515                    .remove(&(gpadl_id, offer_id))
3516                    .expect("gpadl validated above");
3517
3518                self.check_disconnected();
3519            }
3520        }
3521    }
3522
3523    /// Creates a sender, in a convenient way for callers that are able to borrow all of `self`.
3524    ///
3525    /// If you cannot borrow all of `self`, you will need to use the `PendingMessages::sender`
3526    /// method instead.
3527    fn sender(&mut self) -> MessageSender<'_, N> {
3528        self.inner
3529            .pending_messages
3530            .sender(self.notifier, self.inner.state.is_paused())
3531    }
3532}
3533
/// Revokes the channel `offer_id` and releases the GPADLs it no longer needs.
///
/// When the guest currently knows the channel (`info` is `Some`), `sender` is used to notify it:
/// - GPADLs still in `Offered` are failed with `STATUS_UNSUCCESSFUL`.
/// - GPADLs in `TearingDown` are reported as torn down.
/// - A rescind message is sent for the channel.
///
/// GPADLs in `InProgress` or `Accepted` are kept. Revoking an already-revoked channel does
/// nothing and returns `true`.
///
/// Returns `false` when the channel's state is a released one (per `is_released()`), and `true`
/// otherwise. Presumably the caller uses this to decide whether to keep the channel entry.
fn revoke<N: Notifier>(
    mut sender: MessageSender<'_, N>,
    offer_id: OfferId,
    channel: &mut Channel,
    gpadls: &mut GpadlMap,
) -> bool {
    // `info` is `Some` only when the guest must be told about the revocation.
    let info = match channel.state {
        ChannelState::Closed
        | ChannelState::Open { .. }
        | ChannelState::Opening { .. }
        | ChannelState::Closing { .. }
        | ChannelState::ClosingReopen { .. } => {
            channel.state = ChannelState::Revoked;
            Some(channel.info.as_ref().expect("assigned"))
        }
        ChannelState::Reoffered => {
            channel.state = ChannelState::Revoked;
            None
        }
        // Client-released states are left as-is; no guest notification is sent.
        ChannelState::ClientReleased
        | ChannelState::OpeningClientRelease
        | ChannelState::ClosingClientRelease => None,
        // If the channel is being dropped, it may already have been revoked explicitly.
        ChannelState::Revoked => return true,
    };
    let retain = !channel.state.is_released();

    // Release any GPADLs.
    gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
        // Leave GPADLs of other channels alone.
        if gpadl_offer_id != offer_id {
            return true;
        }

        match gpadl.state {
            GpadlState::InProgress => true,
            GpadlState::Offered => {
                // Fail the pending creation so the guest isn't left waiting for it.
                if let Some(info) = info {
                    sender.send_gpadl_created(
                        info.channel_id,
                        gpadl_id,
                        protocol::STATUS_UNSUCCESSFUL,
                    );
                }
                false
            }
            GpadlState::OfferedTearingDown => false,
            GpadlState::Accepted => true,
            GpadlState::TearingDown => {
                if info.is_some() {
                    sender.send_gpadl_torndown(gpadl_id);
                }
                false
            }
        }
    });
    if let Some(info) = info {
        sender.send_rescind(info);
    }
    // Revoking a channel effectively completes the restore operation for it.
    if channel.restore_state != RestoreState::New {
        channel.restore_state = RestoreState::Restored;
    }
    retain
}
3598
3599struct PendingMessages(VecDeque<OutgoingMessage>);
3600
3601impl PendingMessages {
3602    /// Creates a sender for the specified notifier.
3603    fn sender<'a, N: Notifier>(
3604        &'a mut self,
3605        notifier: &'a mut N,
3606        is_paused: bool,
3607    ) -> MessageSender<'a, N> {
3608        MessageSender {
3609            notifier,
3610            pending_messages: self,
3611            is_paused,
3612        }
3613    }
3614}
3615
/// Wraps the state needed to send messages to the guest through the notifier, and queue them if
/// they are not immediately sent.
struct MessageSender<'a, N> {
    /// Used to deliver messages to the guest.
    notifier: &'a mut N,
    /// Messages not yet delivered. While this is non-empty, new default-target messages are
    /// appended behind the queued ones instead of being sent directly.
    pending_messages: &'a mut PendingMessages,
    /// When set, default-target messages are queued instead of sent.
    is_paused: bool,
}
3623
3624impl<N: Notifier> MessageSender<'_, N> {
3625    /// Sends a VMBus channel message to the guest.
3626    fn send_message<
3627        T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3628    >(
3629        &mut self,
3630        msg: &T,
3631    ) {
3632        let message = OutgoingMessage::new(msg);
3633
3634        tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3635        // Don't try to send the message if there are already pending messages.
3636        if !self.pending_messages.0.is_empty()
3637            || self.is_paused
3638            || !self.notifier.send_message(&message, MessageTarget::Default)
3639        {
3640            tracing::trace!("message queued");
3641            // Queue the message for retry later.
3642            self.pending_messages.0.push_back(message);
3643        }
3644    }
3645
3646    /// Sends a VMBus channel message to the guest via an alternate port.
3647    fn send_message_with_target<
3648        T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
3649    >(
3650        &mut self,
3651        msg: &T,
3652        target: MessageTarget,
3653    ) {
3654        if target == MessageTarget::Default {
3655            self.send_message(msg);
3656        } else {
3657            tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
3658            // Messages for other targets are not queued, nor are they affected
3659            // by the paused state.
3660            let message = OutgoingMessage::new(msg);
3661            if !self.notifier.send_message(&message, target) {
3662                tracelimit::warn_ratelimited!(?target, "failed to send message");
3663            }
3664        }
3665    }
3666
3667    /// Sends a channel offer message to the guest.
3668    fn send_offer(&mut self, channel: &mut Channel, version: VersionInfo) {
3669        let info = channel.info.as_ref().expect("assigned");
3670        let mut flags = channel.offer.flags;
3671        if !version.feature_flags.confidential_channels() {
3672            flags.set_confidential_ring_buffer(false);
3673            flags.set_confidential_external_memory(false);
3674        }
3675
3676        let msg = protocol::OfferChannel {
3677            interface_id: channel.offer.interface_id,
3678            instance_id: channel.offer.instance_id,
3679            rsvd: [0; 4],
3680            flags,
3681            mmio_megabytes: channel.offer.mmio_megabytes,
3682            user_defined: channel.offer.user_defined,
3683            subchannel_index: channel.offer.subchannel_index,
3684            mmio_megabytes_optional: channel.offer.mmio_megabytes_optional,
3685            channel_id: info.channel_id,
3686            monitor_id: info.monitor_id.unwrap_or(MonitorId::INVALID).0,
3687            monitor_allocated: info.monitor_id.is_some() as u8,
3688            // All channels are dedicated with Win8+ hosts.
3689            // These fields are sent to V1 guests as well, which will ignore them.
3690            is_dedicated: 1,
3691            connection_id: info.connection_id,
3692        };
3693        tracing::info!(
3694            channel_id = msg.channel_id.0,
3695            connection_id = msg.connection_id,
3696            key = %channel.offer.key(),
3697            "sending offer to guest"
3698        );
3699
3700        self.send_message(&msg);
3701    }
3702
3703    fn send_open_result(
3704        &mut self,
3705        channel_id: ChannelId,
3706        open_request: &OpenRequest,
3707        result: i32,
3708        target: MessageTarget,
3709    ) {
3710        self.send_message_with_target(
3711            &protocol::OpenResult {
3712                channel_id,
3713                open_id: open_request.open_id,
3714                status: result as u32,
3715            },
3716            target,
3717        );
3718    }
3719
3720    fn send_gpadl_created(&mut self, channel_id: ChannelId, gpadl_id: GpadlId, status: i32) {
3721        self.send_message(&protocol::GpadlCreated {
3722            channel_id,
3723            gpadl_id,
3724            status,
3725        });
3726    }
3727
3728    fn send_gpadl_torndown(&mut self, gpadl_id: GpadlId) {
3729        self.send_message(&protocol::GpadlTorndown { gpadl_id });
3730    }
3731
3732    fn send_rescind(&mut self, info: &OfferedInfo) {
3733        tracing::info!(
3734            channel_id = info.channel_id.0,
3735            "rescinding channel from guest"
3736        );
3737
3738        self.send_message(&protocol::RescindChannelOffer {
3739            channel_id: info.channel_id,
3740        });
3741    }
3742}
3743
3744#[cfg(test)]
3745mod tests {
3746    use crate::MESSAGE_CONNECTION_ID;
3747
3748    use super::*;
3749    use guid::Guid;
3750    use protocol::VmbusMessage;
3751    use std::collections::VecDeque;
3752    use std::sync::mpsc;
3753    use test_with_tracing::test;
3754    use vmbus_core::protocol::TargetInfo;
3755    use zerocopy::FromBytes;
3756
3757    fn in_msg<T: IntoBytes + Immutable + KnownLayout>(
3758        message_type: protocol::MessageType,
3759        t: T,
3760    ) -> SynicMessage {
3761        in_msg_ex(message_type, t, false, false)
3762    }
3763
3764    fn in_msg_ex<T: IntoBytes + Immutable + KnownLayout>(
3765        message_type: protocol::MessageType,
3766        t: T,
3767        multiclient: bool,
3768        trusted: bool,
3769    ) -> SynicMessage {
3770        let mut data = Vec::new();
3771        data.extend_from_slice(&message_type.0.to_ne_bytes());
3772        data.extend_from_slice(&0u32.to_ne_bytes());
3773        data.extend_from_slice(t.as_bytes());
3774        SynicMessage {
3775            data,
3776            multiclient,
3777            trusted,
3778        }
3779    }
3780
3781    #[test]
3782    fn test_version_negotiation_not_supported() {
3783        let (mut notifier, _recv) = TestNotifier::new();
3784        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3785
3786        test_initiate_contact(&mut server, &mut notifier, 0xffffffff, 0, false, 0);
3787    }
3788
3789    #[test]
3790    fn test_version_negotiation_success() {
3791        let (mut notifier, _recv) = TestNotifier::new();
3792        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3793
3794        test_initiate_contact(
3795            &mut server,
3796            &mut notifier,
3797            Version::Win10 as u32,
3798            0,
3799            true,
3800            0,
3801        );
3802    }
3803
3804    #[test]
3805    fn test_version_negotiation_multiclient_sint() {
3806        let (mut notifier, _recv) = TestNotifier::new();
3807        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3808
3809        let target_info = TargetInfo::new()
3810            .with_sint(3)
3811            .with_vtl(0)
3812            .with_feature_flags(FeatureFlags::new().into());
3813
3814        server
3815            .with_notifier(&mut notifier)
3816            .handle_synic_message(in_msg_ex(
3817                protocol::MessageType::INITIATE_CONTACT,
3818                protocol::InitiateContact {
3819                    version_requested: Version::Win10Rs3_1 as u32,
3820                    target_message_vp: 0,
3821                    interrupt_page_or_target_info: target_info.into(),
3822                    parent_to_child_monitor_page_gpa: 0,
3823                    child_to_parent_monitor_page_gpa: 0,
3824                },
3825                true,
3826                false,
3827            ))
3828            .unwrap();
3829
3830        // No action is taken when a different SINT is requested, since it's not supported. An
3831        // unsupported message is sent to the requested SINT.
3832        assert!(notifier.modify_requests.is_empty());
3833        assert!(matches!(server.state, ConnectionState::Disconnected));
3834        notifier.check_message_with_target(
3835            OutgoingMessage::new(&protocol::VersionResponse {
3836                version_supported: 0,
3837                connection_state: protocol::ConnectionState::SUCCESSFUL,
3838                padding: 0,
3839                selected_version_or_connection_id: 0,
3840            }),
3841            MessageTarget::Custom(ConnectionTarget { vp: 0, sint: 3 }),
3842        );
3843
3844        // SINT is ignored if the multiclient port is not used.
3845        test_initiate_contact(
3846            &mut server,
3847            &mut notifier,
3848            Version::Win10Rs3_1 as u32,
3849            target_info.into(),
3850            true,
3851            0,
3852        );
3853    }
3854
3855    #[test]
3856    fn test_version_negotiation_multiclient_vtl() {
3857        let (mut notifier, _recv) = TestNotifier::new();
3858        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3859
3860        let target_info = TargetInfo::new()
3861            .with_sint(SINT)
3862            .with_vtl(2)
3863            .with_feature_flags(FeatureFlags::new().into());
3864
3865        server
3866            .with_notifier(&mut notifier)
3867            .handle_synic_message(in_msg_ex(
3868                protocol::MessageType::INITIATE_CONTACT,
3869                protocol::InitiateContact {
3870                    version_requested: Version::Win10Rs4 as u32,
3871                    target_message_vp: 0,
3872                    interrupt_page_or_target_info: target_info.into(),
3873                    parent_to_child_monitor_page_gpa: 0,
3874                    child_to_parent_monitor_page_gpa: 0,
3875                },
3876                true,
3877                false,
3878            ))
3879            .unwrap();
3880
3881        let action = notifier.forward_request.take().unwrap();
3882        assert!(matches!(action, InitiateContactRequest { .. }));
3883
3884        // The VTL contact message was forwarded but no action was taken by this server.
3885        assert!(notifier.messages.is_empty());
3886        assert!(matches!(server.state, ConnectionState::Disconnected));
3887
3888        // VTL is ignored if the multiclient port is not used.
3889        test_initiate_contact(
3890            &mut server,
3891            &mut notifier,
3892            Version::Win10Rs4 as u32,
3893            target_info.into(),
3894            true,
3895            0,
3896        );
3897
3898        assert!(notifier.forward_request.is_none());
3899    }
3900
3901    #[test]
3902    fn test_version_negotiation_feature_flags() {
3903        let (mut notifier, _recv) = TestNotifier::new();
3904        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3905
3906        // Test with no feature flags.
3907        let mut target_info = TargetInfo::new()
3908            .with_sint(SINT)
3909            .with_vtl(0)
3910            .with_feature_flags(FeatureFlags::new().into());
3911        test_initiate_contact(
3912            &mut server,
3913            &mut notifier,
3914            Version::Copper as u32,
3915            target_info.into(),
3916            true,
3917            0,
3918        );
3919
3920        // Request supported feature flags.
3921        target_info.set_feature_flags(
3922            FeatureFlags::new()
3923                .with_guest_specified_signal_parameters(true)
3924                .into(),
3925        );
3926        test_initiate_contact(
3927            &mut server,
3928            &mut notifier,
3929            Version::Copper as u32,
3930            target_info.into(),
3931            true,
3932            FeatureFlags::new()
3933                .with_guest_specified_signal_parameters(true)
3934                .into(),
3935        );
3936
3937        // Request unsupported feature flags. This will succeed and report back the supported ones.
3938        target_info.set_feature_flags(
3939            u32::from(FeatureFlags::new().with_guest_specified_signal_parameters(true))
3940                | 0xf0000000,
3941        );
3942        test_initiate_contact(
3943            &mut server,
3944            &mut notifier,
3945            Version::Copper as u32,
3946            target_info.into(),
3947            true,
3948            FeatureFlags::new()
3949                .with_guest_specified_signal_parameters(true)
3950                .into(),
3951        );
3952
3953        // Verify client ID feature flag.
3954        target_info.set_feature_flags(FeatureFlags::new().with_client_id(true).into());
3955        test_initiate_contact(
3956            &mut server,
3957            &mut notifier,
3958            Version::Copper as u32,
3959            target_info.into(),
3960            true,
3961            FeatureFlags::new().with_client_id(true).into(),
3962        );
3963    }
3964
3965    #[test]
3966    fn test_version_negotiation_interrupt_page() {
3967        let (mut notifier, _recv) = TestNotifier::new();
3968        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3969        test_initiate_contact(
3970            &mut server,
3971            &mut notifier,
3972            Version::V1 as u32,
3973            1234,
3974            true,
3975            0,
3976        );
3977
3978        let (mut notifier, _recv) = TestNotifier::new();
3979        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3980        test_initiate_contact(
3981            &mut server,
3982            &mut notifier,
3983            Version::Win7 as u32,
3984            1234,
3985            true,
3986            0,
3987        );
3988
3989        let (mut notifier, _recv) = TestNotifier::new();
3990        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
3991        test_initiate_contact(
3992            &mut server,
3993            &mut notifier,
3994            Version::Win8 as u32,
3995            1234,
3996            true,
3997            0,
3998        );
3999    }
4000
4001    fn test_initiate_contact(
4002        server: &mut Server,
4003        notifier: &mut TestNotifier,
4004        version: u32,
4005        target_info: u64,
4006        expect_supported: bool,
4007        expected_features: u32,
4008    ) {
4009        server
4010            .with_notifier(notifier)
4011            .handle_synic_message(in_msg(
4012                protocol::MessageType::INITIATE_CONTACT,
4013                protocol::InitiateContact2 {
4014                    initiate_contact: protocol::InitiateContact {
4015                        version_requested: version,
4016                        target_message_vp: 1,
4017                        interrupt_page_or_target_info: target_info,
4018                        parent_to_child_monitor_page_gpa: 0,
4019                        child_to_parent_monitor_page_gpa: 0,
4020                    },
4021                    client_id: guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6"),
4022                },
4023            ))
4024            .unwrap();
4025
4026        let selected_version_or_connection_id = if expect_supported {
4027            let request = notifier.next_action();
4028            let interrupt_page = if version < Version::Win8 as u32 {
4029                Update::Set(target_info)
4030            } else {
4031                Update::Reset
4032            };
4033
4034            let target_message_vp = if version < Version::Win8_1 as u32 {
4035                Some(0)
4036            } else {
4037                Some(1)
4038            };
4039
4040            assert_eq!(
4041                request,
4042                ModifyConnectionRequest {
4043                    version: Some(version),
4044                    monitor_page: Update::Reset,
4045                    interrupt_page,
4046                    target_message_vp,
4047                    ..Default::default()
4048                }
4049            );
4050
4051            server.with_notifier(notifier).complete_initiate_contact(
4052                ModifyConnectionResponse::Supported(
4053                    protocol::ConnectionState::SUCCESSFUL,
4054                    SUPPORTED_FEATURE_FLAGS,
4055                ),
4056            );
4057
4058            if version >= Version::Win10Rs3_1 as u32 {
4059                1
4060            } else {
4061                version
4062            }
4063        } else {
4064            0
4065        };
4066
4067        let version_response = protocol::VersionResponse {
4068            version_supported: if expect_supported { 1 } else { 0 },
4069            connection_state: protocol::ConnectionState::SUCCESSFUL,
4070            padding: 0,
4071            selected_version_or_connection_id,
4072        };
4073
4074        if version >= Version::Copper as u32 && expect_supported {
4075            notifier.check_message(OutgoingMessage::new(&protocol::VersionResponse2 {
4076                version_response,
4077                supported_features: expected_features,
4078            }));
4079        } else {
4080            notifier.check_message(OutgoingMessage::new(&version_response));
4081            assert_eq!(expected_features, 0);
4082        }
4083
4084        assert!(notifier.messages.is_empty());
4085        if expect_supported {
4086            assert!(matches!(server.state, ConnectionState::Connected { .. }));
4087            if version < Version::Win8_1 as u32 {
4088                assert_eq!(Some(0), notifier.target_message_vp);
4089            } else {
4090                assert_eq!(Some(1), notifier.target_message_vp);
4091            }
4092        } else {
4093            assert!(matches!(server.state, ConnectionState::Disconnected));
4094            assert!(notifier.target_message_vp.is_none());
4095        }
4096
4097        if version < Version::Win8 as u32 {
4098            assert_eq!(notifier.interrupt_page, Some(target_info));
4099        } else {
4100            assert!(notifier.interrupt_page.is_none());
4101        }
4102    }
4103
4104    struct TestNotifier {
4105        send: mpsc::Sender<(OfferId, Action)>,
4106        modify_requests: VecDeque<ModifyConnectionRequest>,
4107        messages: VecDeque<(OutgoingMessage, MessageTarget)>,
4108        hvsock_requests: Vec<HvsockConnectRequest>,
4109        forward_request: Option<InitiateContactRequest>,
4110        interrupt_page: Option<u64>,
4111        reset: bool,
4112        monitor_page: Option<MonitorPageGpas>,
4113        target_message_vp: Option<u32>,
4114        pend_messages: bool,
4115    }
4116
4117    impl TestNotifier {
4118        fn new() -> (Self, mpsc::Receiver<(OfferId, Action)>) {
4119            let (send, recv) = mpsc::channel();
4120            (
4121                Self {
4122                    send,
4123                    modify_requests: VecDeque::new(),
4124                    messages: VecDeque::new(),
4125                    hvsock_requests: Vec::new(),
4126                    forward_request: None,
4127                    interrupt_page: None,
4128                    reset: false,
4129                    monitor_page: None,
4130                    target_message_vp: None,
4131                    pend_messages: false,
4132                },
4133                recv,
4134            )
4135        }
4136
4137        fn check_message(&mut self, message: OutgoingMessage) {
4138            self.check_message_with_target(message, MessageTarget::Default);
4139        }
4140
4141        fn check_message_with_target(&mut self, message: OutgoingMessage, target: MessageTarget) {
4142            assert_eq!(self.messages.pop_front().unwrap(), (message, target));
4143            assert!(self.messages.is_empty());
4144        }
4145
4146        fn get_message<T: VmbusMessage + FromBytes + Immutable + KnownLayout>(&mut self) -> T {
4147            let (message, _) = self.messages.pop_front().unwrap();
4148            let (header, data) = protocol::MessageHeader::read_from_prefix(message.data()).unwrap();
4149
4150            assert_eq!(header.message_type(), T::MESSAGE_TYPE);
4151            T::read_from_prefix(data).unwrap().0 // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
4152        }
4153
4154        fn check_messages(&mut self, messages: &[OutgoingMessage]) {
4155            let messages: Vec<_> = messages
4156                .iter()
4157                .map(|m| (m.clone(), MessageTarget::Default))
4158                .collect();
4159            assert_eq!(self.messages, messages.as_slice());
4160            self.messages.clear();
4161        }
4162
4163        fn is_reset(&mut self) -> bool {
4164            std::mem::replace(&mut self.reset, false)
4165        }
4166
4167        fn check_reset(&mut self) {
4168            assert!(self.is_reset());
4169            assert!(self.monitor_page.is_none());
4170            assert!(self.target_message_vp.is_none());
4171        }
4172
4173        fn next_action(&mut self) -> ModifyConnectionRequest {
4174            self.modify_requests.pop_front().unwrap()
4175        }
4176    }
4177
4178    impl Notifier for TestNotifier {
4179        fn notify(&mut self, offer_id: OfferId, action: Action) {
4180            tracing::debug!(?offer_id, ?action, "notify");
4181            self.send.send((offer_id, action)).unwrap()
4182        }
4183
4184        fn forward_unhandled(&mut self, request: InitiateContactRequest) {
4185            assert!(self.forward_request.is_none());
4186            self.forward_request = Some(request);
4187        }
4188
4189        fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()> {
4190            match request.monitor_page {
4191                Update::Unchanged => (),
4192                Update::Reset => self.monitor_page = None,
4193                Update::Set(value) => self.monitor_page = Some(value),
4194            }
4195
4196            if let Some(vp) = request.target_message_vp {
4197                self.target_message_vp = Some(vp);
4198            }
4199
4200            match request.interrupt_page {
4201                Update::Unchanged => (),
4202                Update::Reset => self.interrupt_page = None,
4203                Update::Set(value) => self.interrupt_page = Some(value),
4204            }
4205
4206            self.modify_requests.push_back(request);
4207            Ok(())
4208        }
4209
4210        fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
4211            if self.pend_messages {
4212                return false;
4213            }
4214
4215            self.messages.push_back((message.clone(), target));
4216            true
4217        }
4218
4219        fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
4220            tracing::debug!(?request, "notify_hvsock");
4221            // There is no hvsocket listener, so just drop everything.
4222            // N.B. No HvsockConnectResult will be sent to indicate failure.
4223            self.hvsock_requests.push(*request);
4224        }
4225
4226        fn reset_complete(&mut self) {
4227            self.monitor_page = None;
4228            self.target_message_vp = None;
4229            self.reset = true;
4230        }
4231
4232        fn unload_complete(&mut self) {}
4233    }
4234
4235    #[test]
4236    fn test_channel_lifetime() {
4237        test_channel_lifetime_helper(Version::Win10Rs5, FeatureFlags::new());
4238    }
4239
4240    #[test]
4241    fn test_channel_lifetime_iron() {
4242        test_channel_lifetime_helper(Version::Iron, FeatureFlags::new());
4243    }
4244
4245    #[test]
4246    fn test_channel_lifetime_copper() {
4247        test_channel_lifetime_helper(Version::Copper, FeatureFlags::new());
4248    }
4249
4250    #[test]
4251    fn test_channel_lifetime_copper_guest_signal() {
4252        test_channel_lifetime_helper(
4253            Version::Copper,
4254            FeatureFlags::new().with_guest_specified_signal_parameters(true),
4255        );
4256    }
4257
4258    #[test]
4259    fn test_channel_lifetime_copper_open_flags() {
4260        test_channel_lifetime_helper(
4261            Version::Copper,
4262            FeatureFlags::new().with_channel_interrupt_redirection(true),
4263        );
4264    }
4265
    /// Drives one channel through its whole lifetime for the given protocol `version`.
    ///
    /// The steps are: offer, INITIATE_CONTACT, REQUEST_OFFERS, OPEN_CHANNEL, open completion,
    /// MODIFY_CHANNEL, revoke, and REL_ID_RELEASED.
    ///
    /// `feature_flags` are placed in the target info only for Copper and later. In that case
    /// the version response must echo them back. They also decide which parts of the extended
    /// `OpenChannel2` message the server is expected to honor.
    fn test_channel_lifetime_helper(version: Version, feature_flags: FeatureFlags) {
        let (mut notifier, recv) = TestNotifier::new();
        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
        let interface_id = Guid::new_random();
        let instance_id = Guid::new_random();
        let offer_id = server
            .with_notifier(&mut notifier)
            .offer_channel(OfferParamsInternal {
                interface_name: "test".to_owned(),
                instance_id,
                interface_id,
                ..Default::default()
            })
            .unwrap();

        // Feature flags are only carried in the target info for Copper and later.
        let mut target_info = TargetInfo::new()
            .with_sint(SINT)
            .with_vtl(2)
            .with_feature_flags(FeatureFlags::new().into());
        if version >= Version::Copper {
            target_info.set_feature_flags(feature_flags.into());
        }

        server
            .with_notifier(&mut notifier)
            .handle_synic_message(in_msg(
                protocol::MessageType::INITIATE_CONTACT,
                protocol::InitiateContact {
                    version_requested: version as u32,
                    target_message_vp: 0,
                    interrupt_page_or_target_info: target_info.into(),
                    parent_to_child_monitor_page_gpa: 0,
                    child_to_parent_monitor_page_gpa: 0,
                },
            ))
            .unwrap();

        // The server should ask to apply the requested version, with the monitor and interrupt
        // pages reset and VP 0 as the message target.
        let request = notifier.next_action();
        assert_eq!(
            request,
            ModifyConnectionRequest {
                version: Some(version as u32),
                monitor_page: Update::Reset,
                interrupt_page: Update::Reset,
                target_message_vp: Some(0),
                ..Default::default()
            }
        );

        server
            .with_notifier(&mut notifier)
            .complete_initiate_contact(ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                SUPPORTED_FEATURE_FLAGS,
            ));

        // Expect a successful response with connection ID 1. Copper and later use the extended
        // response, which also carries the requested feature flags.
        let version_response = protocol::VersionResponse {
            version_supported: 1,
            selected_version_or_connection_id: 1,
            ..FromZeros::new_zeroed()
        };

        if version >= Version::Copper {
            notifier.check_message(OutgoingMessage::new(&protocol::VersionResponse2 {
                version_response,
                supported_features: feature_flags.into(),
            }));
        } else {
            notifier.check_message(OutgoingMessage::new(&version_response));
        }

        // Requesting offers should yield the single offered channel (ID 1, connection ID
        // 0x2001, no monitor ID), followed by AllOffersDelivered.
        server
            .with_notifier(&mut notifier)
            .handle_synic_message(in_msg(protocol::MessageType::REQUEST_OFFERS, ()))
            .unwrap();

        let channel_id = ChannelId(1);
        notifier.check_messages(&[
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id,
                instance_id,
                channel_id,
                connection_id: 0x2001,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::AllOffersDelivered {}),
        ]);

        let open_channel = protocol::OpenChannel {
            channel_id,
            open_id: 1,
            ring_buffer_gpadl_id: GpadlId(1),
            target_vp: 3,
            downstream_ring_buffer_page_offset: 2,
            user_data: UserDefinedData::new_zeroed(),
        };

        // Defaults expected when the guest's open parameters are not honored: event flag 1,
        // the offered connection ID, and no open flags.
        let mut event_flag = 1;
        let mut connection_id = 0x2001;
        let mut expected_flags = protocol::OpenChannelFlags::new();
        if version >= Version::Copper
            && (feature_flags.guest_specified_signal_parameters()
                || feature_flags.channel_interrupt_redirection())
        {
            // The OpenChannel2 below always requests event flag 2, connection ID 0x2002, and
            // redirect_interrupt plus junk bits. Only the parts enabled by the negotiated
            // features should take effect, and the junk bits must be dropped.
            if feature_flags.channel_interrupt_redirection() {
                expected_flags.set_redirect_interrupt(true);
            }

            if feature_flags.guest_specified_signal_parameters() {
                event_flag = 2;
                connection_id = 0x2002;
            }

            server
                .with_notifier(&mut notifier)
                .handle_synic_message(in_msg(
                    protocol::MessageType::OPEN_CHANNEL,
                    protocol::OpenChannel2 {
                        open_channel,
                        event_flag: 2,
                        connection_id: 0x2002,
                        flags: (u16::from(
                            protocol::OpenChannelFlags::new().with_redirect_interrupt(true),
                        ) | 0xabc)
                            .into(), // a real flag and some junk
                    },
                ))
                .unwrap();
        } else {
            server
                .with_notifier(&mut notifier)
                .handle_synic_message(in_msg(protocol::MessageType::OPEN_CHANNEL, open_channel))
                .unwrap();
        }

        // The open should reach the device as an Action::Open that carries the ring buffer
        // parameters from the message and the expected signaling parameters and flags.
        let (id, action) = recv.recv().unwrap();
        assert_eq!(id, offer_id);
        let Action::Open(op, ..) = action else {
            panic!("unexpected action: {:?}", action);
        };
        assert_eq!(op.open_data.ring_gpadl_id, GpadlId(1));
        assert_eq!(op.open_data.ring_offset, 2);
        assert_eq!(op.open_data.target_vp, 3);
        assert_eq!(op.open_data.event_flag, event_flag);
        assert_eq!(op.open_data.connection_id, connection_id);
        assert_eq!(op.connection_id, connection_id);
        assert_eq!(op.event_flag, event_flag);
        assert_eq!(op.monitor_info, None);
        assert_eq!(op.flags, expected_flags);

        // Completing the open with status 0 should produce a matching OpenResult.
        server
            .with_notifier(&mut notifier)
            .open_complete(offer_id, 0);

        notifier.check_message(OutgoingMessage::new(&protocol::OpenResult {
            channel_id,
            open_id: 1,
            status: 0,
        }));

        // Retarget the channel to VP 4. The device should see Action::Modify.
        server
            .with_notifier(&mut notifier)
            .handle_synic_message(in_msg(
                protocol::MessageType::MODIFY_CHANNEL,
                protocol::ModifyChannel {
                    channel_id,
                    target_vp: 4,
                },
            ))
            .unwrap();

        let (id, action) = recv.recv().unwrap();
        assert_eq!(id, offer_id);
        assert!(matches!(action, Action::Modify { target_vp: 4 }));

        server
            .with_notifier(&mut notifier)
            .modify_channel_complete(id, 0);

        // A ModifyChannelResponse is only expected for Iron and later. Older versions must
        // send nothing.
        if version >= Version::Iron {
            notifier.check_message(OutgoingMessage::new(&protocol::ModifyChannelResponse {
                channel_id,
                status: 0,
            }));
        }

        assert!(notifier.messages.is_empty());

        // Revoke the offer, then have the guest release the channel ID.
        server.with_notifier(&mut notifier).revoke_channel(offer_id);

        server
            .with_notifier(&mut notifier)
            .handle_synic_message(in_msg(
                protocol::MessageType::REL_ID_RELEASED,
                protocol::RelIdReleased { channel_id },
            ))
            .unwrap();
    }
4466
4467    #[test]
4468    fn test_hvsock() {
4469        test_hvsock_helper(Version::Win10, false);
4470    }
4471
4472    #[test]
4473    fn test_hvsock_rs3() {
4474        test_hvsock_helper(Version::Win10Rs3_0, false);
4475    }
4476
4477    #[test]
4478    fn test_hvsock_rs5() {
4479        test_hvsock_helper(Version::Win10Rs5, false);
4480        test_hvsock_helper(Version::Win10Rs5, true);
4481    }
4482
4483    fn test_hvsock_helper(version: Version, force_small_message: bool) {
4484        let (mut notifier, _recv) = TestNotifier::new();
4485        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
4486
4487        server
4488            .with_notifier(&mut notifier)
4489            .handle_synic_message(in_msg(
4490                protocol::MessageType::INITIATE_CONTACT,
4491                protocol::InitiateContact {
4492                    version_requested: version as u32,
4493                    target_message_vp: 0,
4494                    interrupt_page_or_target_info: 0,
4495                    parent_to_child_monitor_page_gpa: 0,
4496                    child_to_parent_monitor_page_gpa: 0,
4497                },
4498            ))
4499            .unwrap();
4500
4501        let request = notifier.next_action();
4502        assert_eq!(
4503            request,
4504            ModifyConnectionRequest {
4505                version: Some(version as u32),
4506                monitor_page: Update::Reset,
4507                interrupt_page: Update::Reset,
4508                target_message_vp: Some(0),
4509                ..Default::default()
4510            }
4511        );
4512
4513        server
4514            .with_notifier(&mut notifier)
4515            .complete_initiate_contact(ModifyConnectionResponse::Supported(
4516                protocol::ConnectionState::SUCCESSFUL,
4517                SUPPORTED_FEATURE_FLAGS,
4518            ));
4519
4520        // Discard the version response message.
4521        notifier.messages.pop_front();
4522
4523        let service_id = Guid::new_random();
4524        let endpoint_id = Guid::new_random();
4525        let request_msg = if version >= Version::Win10Rs5 && !force_small_message {
4526            in_msg(
4527                protocol::MessageType::TL_CONNECT_REQUEST,
4528                protocol::TlConnectRequest2 {
4529                    base: protocol::TlConnectRequest {
4530                        service_id,
4531                        endpoint_id,
4532                    },
4533                    silo_id: Guid::ZERO,
4534                },
4535            )
4536        } else {
4537            in_msg(
4538                protocol::MessageType::TL_CONNECT_REQUEST,
4539                protocol::TlConnectRequest {
4540                    service_id,
4541                    endpoint_id,
4542                },
4543            )
4544        };
4545
4546        server
4547            .with_notifier(&mut notifier)
4548            .handle_synic_message(request_msg)
4549            .unwrap();
4550
4551        let request = notifier.hvsock_requests.pop().unwrap();
4552        assert_eq!(request.service_id, service_id);
4553        assert_eq!(request.endpoint_id, endpoint_id);
4554        assert!(notifier.hvsock_requests.is_empty());
4555
4556        // Notify the guest of connection failure.
4557        server
4558            .with_notifier(&mut notifier)
4559            .send_tl_connect_result(HvsockConnectResult::from_request(&request, false));
4560
4561        if version >= Version::Win10Rs3_0 {
4562            notifier.check_message(OutgoingMessage::new(&protocol::TlConnectResult {
4563                service_id: request.service_id,
4564                endpoint_id: request.endpoint_id,
4565                status: protocol::STATUS_CONNECTION_REFUSED,
4566            }));
4567        }
4568
4569        assert!(notifier.messages.is_empty());
4570    }
4571
4572    struct TestEnv {
4573        server: Server,
4574        notifier: TestNotifier,
4575        version: Option<VersionInfo>,
4576        _recv: mpsc::Receiver<(OfferId, Action)>,
4577    }
4578
4579    impl TestEnv {
4580        fn new() -> Self {
4581            let (notifier, _recv) = TestNotifier::new();
4582            let server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
4583            Self {
4584                server,
4585                notifier,
4586                version: None,
4587                _recv,
4588            }
4589        }
4590
4591        fn c(&mut self) -> ServerWithNotifier<'_, TestNotifier> {
4592            self.server.with_notifier(&mut self.notifier)
4593        }
4594
4595        // Completes a reset operation if the server send a modify request as part of it. This
4596        // shouldn't be called if the server was not connected or had no open channels or gpadls
4597        // during the reset.
4598        fn complete_reset(&mut self) {
4599            let _ = self.next_action();
4600            self.c()
4601                .complete_modify_connection(ModifyConnectionResponse::Supported(
4602                    protocol::ConnectionState::SUCCESSFUL,
4603                    SUPPORTED_FEATURE_FLAGS,
4604                ));
4605        }
4606
4607        fn offer(&mut self, id: u32) -> OfferId {
4608            self.offer_inner(id, id, MnfUsage::Disabled, None, OfferFlags::new())
4609        }
4610
4611        fn offer_with_mnf(&mut self, id: u32) -> OfferId {
4612            self.offer_inner(
4613                id,
4614                id,
4615                MnfUsage::Enabled {
4616                    latency: Duration::from_micros(100),
4617                },
4618                None,
4619                OfferFlags::new(),
4620            )
4621        }
4622
4623        fn offer_with_preset_mnf(&mut self, id: u32, monitor_id: u8) -> OfferId {
4624            self.offer_inner(
4625                id,
4626                id,
4627                MnfUsage::Relayed { monitor_id },
4628                None,
4629                OfferFlags::new(),
4630            )
4631        }
4632
4633        fn offer_with_order(
4634            &mut self,
4635            interface_id: u32,
4636            instance_id: u32,
4637            order: Option<u32>,
4638        ) -> OfferId {
4639            self.offer_inner(
4640                interface_id,
4641                instance_id,
4642                MnfUsage::Disabled,
4643                order,
4644                OfferFlags::new(),
4645            )
4646        }
4647
4648        fn offer_with_flags(&mut self, id: u32, flags: OfferFlags) -> OfferId {
4649            self.offer_inner(id, id, MnfUsage::Disabled, None, flags)
4650        }
4651
4652        fn offer_inner(
4653            &mut self,
4654            interface_id: u32,
4655            instance_id: u32,
4656            use_mnf: MnfUsage,
4657            offer_order: Option<u32>,
4658            flags: OfferFlags,
4659        ) -> OfferId {
4660            self.c()
4661                .offer_channel(OfferParamsInternal {
4662                    instance_id: Guid {
4663                        data1: instance_id,
4664                        ..Guid::ZERO
4665                    },
4666                    interface_id: Guid {
4667                        data1: interface_id,
4668                        ..Guid::ZERO
4669                    },
4670                    use_mnf,
4671                    offer_order,
4672                    flags,
4673                    ..Default::default()
4674                })
4675                .unwrap()
4676        }
4677
4678        fn open(&mut self, id: u32) {
4679            self.c()
4680                .handle_open_channel(&protocol::OpenChannel2 {
4681                    open_channel: protocol::OpenChannel {
4682                        channel_id: ChannelId(id),
4683                        ..FromZeros::new_zeroed()
4684                    },
4685                    ..FromZeros::new_zeroed()
4686                })
4687                .unwrap()
4688        }
4689
4690        fn close(&mut self, id: u32) -> Result<(), ChannelError> {
4691            self.c().handle_close_channel(&protocol::CloseChannel {
4692                channel_id: ChannelId(id),
4693            })
4694        }
4695
4696        fn open_reserved(&mut self, id: u32, target_vp: u32, target_sint: u32) {
4697            let version = self.server.state.get_version().expect("vmbus connected");
4698
4699            self.c()
4700                .handle_open_reserved_channel(
4701                    &protocol::OpenReservedChannel {
4702                        channel_id: ChannelId(id),
4703                        target_vp,
4704                        target_sint,
4705                        ring_buffer_gpadl: GpadlId(id),
4706                        ..FromZeros::new_zeroed()
4707                    },
4708                    version,
4709                )
4710                .unwrap()
4711        }
4712
4713        fn close_reserved(&mut self, id: u32, target_vp: u32, target_sint: u32) {
4714            self.c()
4715                .handle_close_reserved_channel(&protocol::CloseReservedChannel {
4716                    channel_id: ChannelId(id),
4717                    target_vp,
4718                    target_sint,
4719                })
4720                .unwrap();
4721        }
4722
4723        fn gpadl(&mut self, channel_id: u32, gpadl_id: u32) {
4724            self.c()
4725                .handle_gpadl_header(
4726                    &protocol::GpadlHeader {
4727                        channel_id: ChannelId(channel_id),
4728                        gpadl_id: GpadlId(gpadl_id),
4729                        count: 1,
4730                        len: 16,
4731                    },
4732                    [1u64, 0u64].as_bytes(),
4733                )
4734                .unwrap();
4735        }
4736
4737        fn teardown_gpadl(&mut self, channel_id: u32, gpadl_id: u32) {
4738            self.c()
4739                .handle_gpadl_teardown(&protocol::GpadlTeardown {
4740                    channel_id: ChannelId(channel_id),
4741                    gpadl_id: GpadlId(gpadl_id),
4742                })
4743                .unwrap();
4744        }
4745
4746        fn release(&mut self, id: u32) {
4747            self.c()
4748                .handle_rel_id_released(&protocol::RelIdReleased {
4749                    channel_id: ChannelId(id),
4750                })
4751                .unwrap();
4752        }
4753
4754        fn connect(&mut self, version: Version, feature_flags: FeatureFlags) {
4755            self.start_connect(version, feature_flags, false);
4756            self.complete_connect();
4757        }
4758
4759        fn connect_trusted(&mut self, version: Version, feature_flags: FeatureFlags) {
4760            self.start_connect(version, feature_flags, true);
4761            self.complete_connect();
4762        }
4763
4764        fn start_connect(&mut self, version: Version, feature_flags: FeatureFlags, trusted: bool) {
4765            self.version = Some(VersionInfo {
4766                version,
4767                feature_flags,
4768            });
4769
4770            let result = self.c().handle_synic_message(in_msg_ex(
4771                protocol::MessageType::INITIATE_CONTACT,
4772                protocol::InitiateContact2 {
4773                    initiate_contact: protocol::InitiateContact {
4774                        version_requested: version as u32,
4775                        interrupt_page_or_target_info: TargetInfo::new()
4776                            .with_sint(SINT)
4777                            .with_vtl(0)
4778                            .with_feature_flags(feature_flags.into())
4779                            .into(),
4780                        child_to_parent_monitor_page_gpa: 0x123f000,
4781                        parent_to_child_monitor_page_gpa: 0x321f000,
4782                        ..FromZeros::new_zeroed()
4783                    },
4784                    client_id: Guid::ZERO,
4785                },
4786                false,
4787                trusted,
4788            ));
4789            assert!(result.is_ok());
4790
4791            let request = self.notifier.next_action();
4792            assert_eq!(
4793                request,
4794                ModifyConnectionRequest {
4795                    version: Some(version as u32),
4796                    monitor_page: Update::Set(MonitorPageGpas {
4797                        child_to_parent: 0x123f000,
4798                        parent_to_child: 0x321f000,
4799                    }),
4800                    interrupt_page: Update::Reset,
4801                    target_message_vp: Some(0),
4802                    ..Default::default()
4803                }
4804            );
4805        }
4806
4807        fn complete_connect(&mut self) {
4808            self.c()
4809                .complete_initiate_contact(ModifyConnectionResponse::Supported(
4810                    protocol::ConnectionState::SUCCESSFUL,
4811                    SUPPORTED_FEATURE_FLAGS,
4812                ));
4813
4814            let version = self.version.unwrap();
4815            if version.version >= Version::Copper {
4816                let response = self.notifier.get_message::<protocol::VersionResponse2>();
4817                assert_eq!(response.version_response.version_supported, 1);
4818                self.version = Some(VersionInfo {
4819                    version: version.version,
4820                    feature_flags: version.feature_flags & response.supported_features.into(),
4821                })
4822            } else {
4823                let response = self.notifier.get_message::<protocol::VersionResponse>();
4824                assert_eq!(response.version_supported, 1);
4825            }
4826        }
4827
4828        fn send_message(&mut self, message: SynicMessage) {
4829            self.try_send_message(message).unwrap();
4830        }
4831
4832        fn try_send_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
4833            self.c().handle_synic_message(message)
4834        }
4835
4836        fn next_action(&mut self) -> ModifyConnectionRequest {
4837            self.notifier.next_action()
4838        }
4839    }
4840
    /// Ensure that channels can be offered at each stage of connection.
    ///
    /// One offer is added before `InitiateContact`, one while the initiate-contact request is
    /// still outstanding, one after it completes, and one after `RequestOffers`. All four are
    /// then opened, and a reset is driven to completion by completing each channel's close.
    #[test]
    fn test_hot_add() {
        let mut env = TestEnv::new();
        // Offered before the guest initiates contact.
        let offer_id1 = env.offer(1);
        let result = env.c().handle_initiate_contact(
            &protocol::InitiateContact2 {
                initiate_contact: protocol::InitiateContact {
                    version_requested: Version::Win10 as u32,
                    ..FromZeros::new_zeroed()
                },
                ..FromZeros::new_zeroed()
            },
            &SynicMessage::default(),
            true,
        );
        assert!(result.is_ok());
        // Offered while the initiate-contact request has not yet been completed.
        let offer_id2 = env.offer(2);
        env.c()
            .complete_initiate_contact(ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                SUPPORTED_FEATURE_FLAGS,
            ));
        // Offered after the connection is established but before the guest requests offers.
        let offer_id3 = env.offer(3);
        env.c().handle_request_offers().unwrap();
        // Offered after the guest has requested offers.
        let offer_id4 = env.offer(4);
        // Every offer, regardless of when it was added, must be openable.
        env.open(1);
        env.open(2);
        env.open(3);
        env.open(4);
        env.c().open_complete(offer_id1, 0);
        env.c().open_complete(offer_id2, 0);
        env.c().open_complete(offer_id3, 0);
        env.c().open_complete(offer_id4, 0);
        // Reset with all channels open; complete each close before finishing the reset.
        env.c().reset();
        env.c().close_complete(offer_id1);
        env.c().close_complete(offer_id2);
        env.c().close_complete(offer_id3);
        env.c().close_complete(offer_id4);
        env.complete_reset();
        assert!(env.notifier.is_reset());
    }
4883
    /// Save and restore while no guest has connected.
    ///
    /// The state is saved with two offers present, the server is reset, and the saved state is
    /// restored. Only offer 1 is then explicitly restored; the `false` argument to
    /// `restore_channel` presumably means "the channel was not open" (TODO confirm).
    #[test]
    fn test_save_restore_with_no_connection() {
        let mut env = TestEnv::new();

        let offer_id1 = env.offer(1);
        let _offer_id2 = env.offer(2);

        let state = env.server.save();
        // With no connection and no open channels, the reset completes immediately.
        env.c().reset();
        assert!(env.notifier.is_reset());
        env.c().restore(state).unwrap();
        env.c().restore_channel(offer_id1, false).unwrap();
    }
4897
    /// Save and restore with a connected guest and channels in many different states.
    ///
    /// Covers open and opening channels, completed and pending GPADLs, reserved channels that
    /// are opening, open, or closing, MNF monitor allocation, offers revoked or re-offered
    /// around the restore, and restoring the same saved state twice.
    #[test]
    fn test_save_restore_with_connection() {
        let mut env = TestEnv::new();

        // Offers 1, 3 and 5 use MNF (monitored notifications); the rest do not.
        let offer_id1 = env.offer_with_mnf(1);
        let offer_id2 = env.offer(2);
        let offer_id3 = env.offer_with_mnf(3);
        let offer_id4 = env.offer(4);
        let offer_id5 = env.offer_with_mnf(5);
        let offer_id6 = env.offer(6);
        let offer_id7 = env.offer(7);
        let offer_id8 = env.offer(8);
        let offer_id9 = env.offer(9);
        let offer_id10 = env.offer(10);

        let expected_monitor = MonitorPageGpas {
            child_to_parent: 0x123f000,
            parent_to_child: 0x321f000,
        };

        env.connect(Version::Win10, FeatureFlags::new());
        assert_eq!(env.notifier.monitor_page, Some(expected_monitor));

        // Presumably one monitor bit per MNF offer (bits 0-2 for offers 1, 3, 5).
        env.c().handle_request_offers().unwrap();
        assert_eq!(env.server.assigned_monitors.bitmap(), 7);

        // Channels 1, 2 and 5 finish opening; channel 3 is left in the opening state.
        env.open(1);
        env.open(2);
        env.open(3);
        env.open(5);

        env.c().open_complete(offer_id1, 0);
        env.c().open_complete(offer_id2, 0);
        env.c().open_complete(offer_id5, 0);

        // Each of channels 1-3 gets one completed GPADL (x0) and one still-pending GPADL (x1).
        env.gpadl(1, 10);
        env.c().gpadl_create_complete(offer_id1, GpadlId(10), 0);
        env.gpadl(1, 11);
        env.gpadl(2, 20);
        env.c().gpadl_create_complete(offer_id2, GpadlId(20), 0);
        env.gpadl(2, 21);
        env.gpadl(3, 30);
        env.c().gpadl_create_complete(offer_id3, GpadlId(30), 0);
        env.gpadl(3, 31);

        // Test Opening, Open, and Closing save for reserved channels
        // (7: opening, 8: open, 9: closing).
        env.open_reserved(7, 1, SINT.into());
        env.open_reserved(8, 2, SINT.into());
        env.open_reserved(9, 3, SINT.into());
        env.c().open_complete(offer_id8, 0);
        env.c().open_complete(offer_id9, 0);
        env.close_reserved(9, 3, SINT.into());

        // Revoke an offer but don't have the "guest" release it, so we can then mark it as
        // reoffered.
        env.c().revoke_channel(offer_id10);
        let offer_id10 = env.offer(10);

        let state = env.server.save();

        // Reset and complete every outstanding open, close and GPADL operation so the reset can
        // finish. Pending opens and GPADL creates are completed with a failure status (-1).
        env.c().reset();

        env.c().close_complete(offer_id1);
        env.c().close_complete(offer_id2);
        env.c().open_complete(offer_id3, -1);
        env.c().close_complete(offer_id5);
        env.c().open_complete(offer_id7, -1);
        env.c().close_complete(offer_id8);
        env.c().close_complete(offer_id9);

        env.c().gpadl_teardown_complete(offer_id1, GpadlId(10));
        env.c().gpadl_create_complete(offer_id1, GpadlId(11), -1);
        env.c().gpadl_teardown_complete(offer_id2, GpadlId(20));
        env.c().gpadl_create_complete(offer_id2, GpadlId(21), -1);
        env.c().gpadl_teardown_complete(offer_id3, GpadlId(30));
        env.c().gpadl_create_complete(offer_id3, GpadlId(31), -1);

        env.complete_reset();
        env.notifier.check_reset();

        // Offers 5 and 6 disappear before the restore; offer 5 is re-offered afterwards.
        env.c().revoke_channel(offer_id5);
        env.c().revoke_channel(offer_id6);

        env.c().restore(state.clone()).unwrap();

        // Offers 1 and 4 are revoked after the restore instead of being restored.
        env.c().revoke_channel(offer_id1);
        env.c().revoke_channel(offer_id4);
        env.c().restore_channel(offer_id3, false).unwrap();
        let offer_id5 = env.offer_with_mnf(5);
        env.c().restore_channel(offer_id5, true).unwrap();
        env.c().restore_channel(offer_id7, false).unwrap();
        env.c().restore_channel(offer_id8, true).unwrap();
        env.c().restore_channel(offer_id9, true).unwrap();
        env.c().restore_channel(offer_id10, false).unwrap();
        assert!(matches!(
            env.server.channels[offer_id10].state,
            ChannelState::Reoffered
        ));

        env.c().revoke_unclaimed_channels();

        assert_eq!(env.notifier.monitor_page, Some(expected_monitor));
        assert_eq!(env.notifier.target_message_vp, Some(0));

        // Presumably only offers 3 and 5 still hold monitors (bits 1-2); offer 1's was freed
        // when it was revoked.
        assert_eq!(env.server.assigned_monitors.bitmap(), 6);
        // Guest releases channels 1, 2 and 4 (2 was never restored, so presumably it was
        // revoked by `revoke_unclaimed_channels`).
        env.release(1);
        env.release(2);
        env.release(4);

        // Check reserved channels have been restored to the same state
        env.c().open_complete(offer_id7, 0);
        env.close_reserved(8, 2, SINT.into());
        env.c().close_complete(offer_id8);
        env.c().close_complete(offer_id9);

        // Reset again and check that the same saved state can be restored a second time.
        env.c().reset();

        env.c().open_complete(offer_id3, -1);
        env.c().gpadl_teardown_complete(offer_id3, GpadlId(30));
        env.c().gpadl_create_complete(offer_id3, GpadlId(31), -1);
        env.c().close_complete(offer_id5);
        env.c().close_complete(offer_id7);

        env.complete_reset();
        env.notifier.check_reset();

        env.c().restore(state).unwrap();
        env.c().restore_channel(offer_id3, false).unwrap();
        assert_eq!(env.notifier.monitor_page, Some(expected_monitor));
        assert_eq!(env.notifier.target_message_vp, Some(0));
    }
5029
    /// Save and restore while an `InitiateContact` request is still being processed.
    ///
    /// After restore, the server must re-issue the modify-connection request with the original
    /// version and monitor pages, and the connection must still be completable.
    #[test]
    fn test_save_restore_connecting() {
        let mut env = TestEnv::new();

        let offer_id1 = env.offer_with_mnf(1);
        let _offer_id2 = env.offer(2);

        // Begin connecting without completing, so the state is saved mid-connect.
        env.start_connect(Version::Win10, FeatureFlags::new(), false);
        assert_eq!(
            env.notifier.monitor_page,
            Some(MonitorPageGpas {
                child_to_parent: 0x123f000,
                parent_to_child: 0x321f000
            })
        );

        let state = env.server.save();

        env.c().reset();
        // We have to "complete" the connection to let the reset go through.
        env.complete_connect();
        env.notifier.check_reset();

        env.c().restore(state).unwrap();
        env.c().restore_channel(offer_id1, false).unwrap();
        assert_eq!(
            env.notifier.monitor_page,
            Some(MonitorPageGpas {
                child_to_parent: 0x123f000,
                parent_to_child: 0x321f000
            })
        );

        // Restore should resend the modify connection request.
        let request = env.next_action();
        assert_eq!(
            request,
            ModifyConnectionRequest {
                version: Some(Version::Win10 as u32),
                monitor_page: Update::Set(MonitorPageGpas {
                    child_to_parent: 0x123f000,
                    parent_to_child: 0x321f000,
                }),
                interrupt_page: Update::Reset,
                target_message_vp: Some(0),
                ..Default::default()
            }
        );

        assert_eq!(Some(0), env.notifier.target_message_vp);

        // We can successfully complete connecting after restore.
        env.complete_connect();
    }
5084
    /// Save and restore while a guest-initiated `ModifyConnection` is still pending.
    ///
    /// After restore, the server must re-issue the modify-connection request. Completing it must
    /// then send the guest a `ModifyConnectionResponse`.
    #[test]
    fn test_save_restore_modifying() {
        let mut env = TestEnv::new();
        env.connect(
            Version::Copper,
            FeatureFlags::new().with_modify_connection(true),
        );

        let expected = MonitorPageGpas {
            parent_to_child: 0x123f000,
            child_to_parent: 0x321f000,
        };

        env.send_message(in_msg(
            protocol::MessageType::MODIFY_CONNECTION,
            protocol::ModifyConnection {
                parent_to_child_monitor_page_gpa: expected.parent_to_child,
                child_to_parent_monitor_page_gpa: expected.child_to_parent,
            },
        ));

        // Discard ModifyConnectionRequest
        env.next_action();

        assert_eq!(env.notifier.monitor_page, Some(expected));

        // Save without completing the modify request, so it is captured as in-flight.
        let state = env.server.save();
        env.c().reset();
        env.notifier.check_reset();

        env.c().restore(state).unwrap();

        // Restore should have resent the request.
        let request = env.next_action();
        assert_eq!(
            request,
            ModifyConnectionRequest {
                monitor_page: Update::Set(MonitorPageGpas {
                    parent_to_child: 0x123f000,
                    child_to_parent: 0x321f000,
                }),
                interrupt_page: Update::Reset,
                target_message_vp: Some(0),
                ..Default::default()
            }
        );

        assert_eq!(env.notifier.monitor_page, Some(expected));

        // We can complete the modify request after restore.
        env.c()
            .complete_modify_connection(ModifyConnectionResponse::Supported(
                protocol::ConnectionState::SUCCESSFUL,
                SUPPORTED_FEATURE_FLAGS,
            ));

        env.notifier
            .check_message(OutgoingMessage::new(&protocol::ModifyConnectionResponse {
                connection_state: protocol::ConnectionState::SUCCESSFUL,
            }));
    }
5146
    /// A reserved channel and its GPADL that remain open after the guest unloads must survive a
    /// save/restore into a freshly constructed server.
    #[test]
    fn test_save_restore_disconnected_reserved() {
        let mut env = TestEnv::new();

        let offer_id1 = env.offer(1);
        let _offer_id2 = env.offer(2);
        let _offer_id3 = env.offer(3);

        env.connect(Version::Copper, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        env.gpadl(1, 1);
        env.c().gpadl_create_complete(offer_id1, GpadlId(1), 0);
        env.open_reserved(1, 0, 3);
        env.c().open_complete(offer_id1, protocol::STATUS_SUCCESS);
        // Disconnect; the reserved channel stays open.
        env.c().handle_unload();

        // Save while disconnected and restore into a brand-new environment.
        let state = env.server.save();
        let mut env = TestEnv::new();
        let offer_id1 = env.offer(1);
        let offer_id2 = env.offer(2);
        let offer_id3 = env.offer(3);

        env.c().restore(state).unwrap();

        // This will panic if the reserved channel was not restored.
        env.c().restore_channel(offer_id1, true).unwrap();
        env.c().restore_channel(offer_id2, false).unwrap();
        env.c().restore_channel(offer_id3, false).unwrap();

        // Make sure the gpadl was restored as well.
        assert!(env.server.gpadls.contains_key(&(GpadlId(1), offer_id1)));
    }
5180
5181    #[test]
5182    fn test_pending_messages() {
5183        let mut env = TestEnv::new();
5184
5185        let offer_id1 = env.offer(1);
5186        let offer_id2 = env.offer(2);
5187        let offer_id3 = env.offer(3);
5188
5189        env.connect(Version::Copper, FeatureFlags::new());
5190        env.c().handle_request_offers().unwrap();
5191
5192        env.notifier.messages.clear();
5193        env.notifier.pend_messages = true;
5194        env.open_reserved(2, 4, SINT.into());
5195        env.c().open_complete(offer_id2, protocol::STATUS_SUCCESS);
5196
5197        // Reserved channel message should not be queued, but just discarded if it cannot be sent.
5198        assert!(env.notifier.messages.is_empty());
5199        assert!(!env.server.has_pending_messages());
5200
5201        env.gpadl(1, 10);
5202        env.c()
5203            .gpadl_create_complete(offer_id1, GpadlId(10), protocol::STATUS_SUCCESS);
5204
5205        // The next message should still be queued because there is already a queued message.
5206        env.notifier.pend_messages = true;
5207        env.open(3);
5208        env.c().open_complete(offer_id3, protocol::STATUS_SUCCESS);
5209
5210        // No messages were received.
5211        assert!(env.notifier.messages.is_empty());
5212        assert!(env.server.has_pending_messages());
5213        env.notifier.pend_messages = false;
5214
5215        let state = env.server.save();
5216
5217        // Create a new env instead of resetting because the gpadl blocks the reset until released.
5218        let mut env = TestEnv::new();
5219
5220        let offer_id1 = env.offer(1);
5221        let offer_id2 = env.offer(2);
5222        let offer_id3 = env.offer(3);
5223
5224        env.c().restore(state).unwrap();
5225        env.c().restore_channel(offer_id1, false).unwrap();
5226        env.c().restore_channel(offer_id2, true).unwrap();
5227        env.c().restore_channel(offer_id3, true).unwrap();
5228
5229        // The messages should be pending again.
5230        assert!(env.server.has_pending_messages());
5231        let mut pending_messages = Vec::new();
5232        let r = env.server.poll_flush_pending_messages(|msg| {
5233            pending_messages.push(msg.clone());
5234            Poll::Ready(())
5235        });
5236        assert!(r.is_ready());
5237        assert_eq!(pending_messages.len(), 2);
5238        assert_eq!(
5239            protocol::MessageHeader::read_from_prefix(pending_messages[0].data())
5240                .unwrap()
5241                .0
5242                .message_type(),
5243            protocol::MessageType::GPADL_CREATED
5244        );
5245
5246        assert_eq!(
5247            protocol::MessageHeader::read_from_prefix(pending_messages[1].data())
5248                .unwrap()
5249                .0
5250                .message_type(),
5251            protocol::MessageType::OPEN_CHANNEL_RESULT
5252        );
5253
5254        assert!(!env.server.has_pending_messages());
5255    }
5256
    /// A guest `ModifyConnection` updates the monitor pages and produces a modify-connection
    /// request. The status the host returns is forwarded to the guest; this test uses a failure
    /// status.
    #[test]
    fn test_modify_connection() {
        let mut env = TestEnv::new();
        env.connect(
            Version::Copper,
            FeatureFlags::new().with_modify_connection(true),
        );

        env.send_message(in_msg(
            protocol::MessageType::MODIFY_CONNECTION,
            protocol::ModifyConnection {
                parent_to_child_monitor_page_gpa: 5,
                child_to_parent_monitor_page_gpa: 6,
            },
        ));

        assert_eq!(
            env.notifier.monitor_page,
            Some(MonitorPageGpas {
                parent_to_child: 5,
                child_to_parent: 6
            })
        );

        let request = env.next_action();
        assert_eq!(
            request,
            ModifyConnectionRequest {
                monitor_page: Update::Set(MonitorPageGpas {
                    child_to_parent: 6,
                    parent_to_child: 5,
                }),
                ..Default::default()
            }
        );

        // The host reports failure; the guest must receive that same status.
        env.c()
            .complete_modify_connection(ModifyConnectionResponse::Supported(
                protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
                SUPPORTED_FEATURE_FLAGS,
            ));

        env.notifier
            .check_message(OutgoingMessage::new(&protocol::ModifyConnectionResponse {
                connection_state: protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
            }));
    }
5304
    /// Without the `modify_connection` feature flag negotiated, a `MODIFY_CONNECTION` message is
    /// rejected as a parse error with an invalid message type.
    #[test]
    fn test_modify_connection_unsupported() {
        let mut env = TestEnv::new();
        env.connect(Version::Copper, FeatureFlags::new());

        let err = env
            .try_send_message(in_msg(
                protocol::MessageType::MODIFY_CONNECTION,
                protocol::ModifyConnection {
                    parent_to_child_monitor_page_gpa: 5,
                    child_to_parent_monitor_page_gpa: 6,
                },
            ))
            .unwrap_err();

        assert!(matches!(
            err,
            ChannelError::ParseError(protocol::ParseError::InvalidMessageType(
                protocol::MessageType::MODIFY_CONNECTION
            ))
        ));
    }
5327
    /// Lifecycle of reserved channels across unload, reconnect and reset.
    ///
    /// Open and close responses go to the VP/SINT target supplied by the guest. Reserved
    /// channels cannot be closed with a normal close, and they survive unload and reconnect.
    /// A reset forces them closed.
    #[test]
    fn test_reserved_channels() {
        let mut env = TestEnv::new();

        let offer_id1 = env.offer(1);
        let offer_id2 = env.offer(2);
        let offer_id3 = env.offer(3);

        env.connect(Version::Win10, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        // Check gpadl doesn't prevent unload or get torndown on disconnect
        env.gpadl(1, 10);
        env.c().gpadl_create_complete(offer_id1, GpadlId(10), 0);

        env.notifier.messages.clear();

        // Open responses should be sent to the provided target
        env.open_reserved(1, 1, SINT.into());
        env.c().open_complete(offer_id1, 0);
        env.notifier.check_message_with_target(
            OutgoingMessage::new(&protocol::OpenResult {
                channel_id: ChannelId(1),
                ..FromZeros::new_zeroed()
            }),
            MessageTarget::ReservedChannel(offer_id1, ConnectionTarget { vp: 1, sint: SINT }),
        );
        env.open_reserved(2, 2, SINT.into());
        env.c().open_complete(offer_id2, 0);
        env.open_reserved(3, 3, SINT.into());
        env.c().open_complete(offer_id3, 0);

        // This should fail: reserved channels cannot be closed with a regular close.
        assert!(matches!(env.close(2), Err(ChannelError::ChannelReserved)));

        // Reserved channels and gpadls should stay open across unloads
        env.c().handle_unload();

        // Closing while disconnected should work
        env.close_reserved(2, 2, SINT.into());
        env.c().close_complete(offer_id2);

        env.notifier.messages.clear();
        env.connect(Version::Copper, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        // Check reserved gpadl gets torndown on reset
        // Duplicate GPADL IDs across different channels should also work
        env.gpadl(2, 10);
        env.c().gpadl_create_complete(offer_id2, GpadlId(10), 0);

        // Reopening the same offer should work
        env.open_reserved(2, 3, SINT.into());
        env.c().open_complete(offer_id2, 0);

        env.notifier.messages.clear();

        // The channel should still be open after disconnect/reconnect
        // and close responses should be sent to the provided target
        env.close_reserved(1, 4, SINT.into());
        env.c().close_complete(offer_id1);
        env.notifier.check_message_with_target(
            OutgoingMessage::new(&protocol::CloseReservedChannelResponse {
                channel_id: ChannelId(1),
            }),
            MessageTarget::ReservedChannel(offer_id1, ConnectionTarget { vp: 4, sint: SINT }),
        );
        env.teardown_gpadl(1, 10);
        env.c().gpadl_teardown_complete(offer_id1, GpadlId(10));

        // Reset should force reserved channels closed
        env.c().reset();
        env.c().close_complete(offer_id2);
        env.c().gpadl_teardown_complete(offer_id2, GpadlId(10));
        env.c().close_complete(offer_id3);

        env.complete_reset();
        assert!(env.notifier.is_reset());
    }
5407
    /// Reset while the guest is disconnected (after unload).
    ///
    /// First phase: a reserved channel and its GPADL are still open at reset time, and their
    /// close and teardown must be completed before the reset finishes. Second phase: everything
    /// is closed before the reset, and `is_reset` holds right after calling `reset`, with no
    /// `complete_reset`.
    #[test]
    fn test_disconnected_reset() {
        let mut env = TestEnv::new();

        let offer_id1 = env.offer(1);

        env.connect(Version::Win10, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        env.gpadl(1, 10);
        env.c().gpadl_create_complete(offer_id1, GpadlId(10), 0);
        env.open_reserved(1, 1, SINT.into());
        env.c().open_complete(offer_id1, 0);

        env.c().handle_unload();

        // Reset while disconnected should cleanup reserved channels
        // and complete disconnect automatically
        env.c().reset();
        env.c().close_complete(offer_id1);
        env.c().gpadl_teardown_complete(offer_id1, GpadlId(10));

        env.complete_reset();
        assert!(env.notifier.is_reset());

        // Second phase: the server must be usable again after the reset.
        let offer_id2 = env.offer(2);

        env.notifier.messages.clear();
        env.connect(Version::Win10, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        env.gpadl(2, 20);
        env.c().gpadl_create_complete(offer_id2, GpadlId(20), 0);
        env.open_reserved(2, 2, SINT.into());
        env.c().open_complete(offer_id2, 0);

        env.c().handle_unload();

        // Close the reserved channel and tear down its GPADL before resetting.
        env.close_reserved(2, 2, SINT.into());
        env.c().close_complete(offer_id2);
        env.c().gpadl_teardown_complete(offer_id2, GpadlId(20));

        env.c().reset();
        assert!(env.notifier.is_reset());
    }
5453
    /// Offer messages carry the correct monitor fields for channels without MNF, with a
    /// server-allocated MNF monitor ID, and with a preset monitor ID.
    ///
    /// The expected messages show `monitor_id: 0xff` with no `monitor_allocated` for the
    /// non-MNF offer. The server-allocated offer gets ID 0, and the preset offer keeps ID 5.
    #[test]
    fn test_mnf_channel() {
        let mut env = TestEnv::new();

        // This test combines server-handled and preset MNF IDs, which can't happen normally, but
        // it simplifies the test.
        let _offer_id1 = env.offer(1);
        let _offer_id2 = env.offer_with_mnf(2);
        let _offer_id3 = env.offer_with_preset_mnf(3, 5);

        env.connect(Version::Copper, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        // Preset monitor ID should not be in the bitmap.
        assert_eq!(env.server.assigned_monitors.bitmap(), 1);

        env.notifier.check_messages(&[
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 1,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 1,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(1),
                connection_id: 0x2001,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 2,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 2,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(2),
                connection_id: 0x2002,
                is_dedicated: 1,
                monitor_id: 0,
                monitor_allocated: 1,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 3,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 3,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(3),
                connection_id: 0x2003,
                is_dedicated: 1,
                monitor_id: 5,
                monitor_allocated: 1,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::AllOffersDelivered {}),
        ])
    }
5521
    /// Channel IDs are assigned in a deterministic order that does not depend on offer order.
    ///
    /// From the expected messages: offers are sorted by interface ID. Within interface 5,
    /// explicitly ordered offers come first (order 1, then order 2). Those are followed by
    /// unordered offers sorted by instance ID (1, then 5). The instance-5 offer comes from
    /// `offer(5)`, which presumably has no explicit order. Connection IDs are `0x2000 + channel
    /// ID` in every expected message.
    #[test]
    fn test_channel_id_order() {
        let mut env = TestEnv::new();

        let _offer_id1 = env.offer(3);
        let _offer_id2 = env.offer(10);
        let _offer_id3 = env.offer(5);
        let _offer_id4 = env.offer(17);
        let _offer_id5 = env.offer_with_order(5, 6, Some(2));
        let _offer_id6 = env.offer_with_order(5, 8, Some(1));
        let _offer_id7 = env.offer_with_order(5, 1, None);

        env.connect(Version::Win10, FeatureFlags::new());
        env.c().handle_request_offers().unwrap();

        env.notifier.check_messages(&[
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 3,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 3,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(1),
                connection_id: 0x2001,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 5,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 8,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(2),
                connection_id: 0x2002,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 5,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 6,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(3),
                connection_id: 0x2003,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 5,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 1,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(4),
                connection_id: 0x2004,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 5,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 5,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(5),
                connection_id: 0x2005,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 10,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 10,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(6),
                connection_id: 0x2006,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::OfferChannel {
                interface_id: Guid {
                    data1: 17,
                    ..Guid::ZERO
                },
                instance_id: Guid {
                    data1: 17,
                    ..Guid::ZERO
                },
                channel_id: ChannelId(7),
                connection_id: 0x2007,
                is_dedicated: 1,
                monitor_id: 0xff,
                ..protocol::OfferChannel::new_zeroed()
            }),
            OutgoingMessage::new(&protocol::AllOffersDelivered {}),
        ])
    }
5646
    /// A trusted connection that negotiates confidential channels rejects untrusted messages and
    /// accepts trusted ones. Offers then carry their confidential ring buffer / external memory
    /// flags.
    #[test]
    fn test_confidential_connection() {
        let mut env = TestEnv::new();
        env.connect_trusted(
            Version::Copper,
            FeatureFlags::new().with_confidential_channels(true),
        );

        // The confidential_channels feature flag survives negotiation.
        assert_eq!(
            env.version.unwrap(),
            VersionInfo {
                version: Version::Copper,
                feature_flags: FeatureFlags::new().with_confidential_channels(true)
            }
        );

        env.offer(1); // non-confidential
        env.offer_with_flags(2, OfferFlags::new().with_confidential_ring_buffer(true));
        env.offer_with_flags(
            3,
            OfferFlags::new()
                .with_confidential_ring_buffer(true)
                .with_confidential_external_memory(true),
        );

        // Untrusted messages are rejected when the connection is trusted.
        let error = env
            .try_send_message(in_msg(
                protocol::MessageType::REQUEST_OFFERS,
                protocol::RequestOffers {},
            ))
            .unwrap_err();

        assert!(matches!(error, ChannelError::UntrustedMessage));
        assert!(env.notifier.messages.is_empty());

        // Trusted messages are accepted.
        env.send_message(in_msg_ex(
            protocol::MessageType::REQUEST_OFFERS,
            protocol::RequestOffers {},
            false,
            true,
        ));

        // Each offer's confidential flags are passed through to the guest unchanged.
        let offer = env.notifier.get_message::<protocol::OfferChannel>();
        assert_eq!(offer.channel_id, ChannelId(1));
        assert_eq!(offer.flags, OfferFlags::new());

        let offer = env.notifier.get_message::<protocol::OfferChannel>();
        assert_eq!(offer.channel_id, ChannelId(2));
        assert_eq!(
            offer.flags,
            OfferFlags::new().with_confidential_ring_buffer(true)
        );

        let offer = env.notifier.get_message::<protocol::OfferChannel>();
        assert_eq!(offer.channel_id, ChannelId(3));
        assert_eq!(
            offer.flags,
            OfferFlags::new()
                .with_confidential_ring_buffer(true)
                .with_confidential_external_memory(true)
        );

        env.notifier
            .check_message(OutgoingMessage::new(&protocol::AllOffersDelivered {}));
    }
5714
5715    #[test]
5716    fn test_confidential_channels_unsupported() {
5717        let mut env = TestEnv::new();
5718
5719        // A trusted connection without confidential channels is weird, but it makes sure the server
5720        // looks at the flag, not the trusted state.
5721        env.connect_trusted(Version::Copper, FeatureFlags::new());
5722
5723        assert_eq!(
5724            env.version.unwrap(),
5725            VersionInfo {
5726                version: Version::Copper,
5727                feature_flags: FeatureFlags::new()
5728            }
5729        );
5730
5731        env.offer_with_flags(1, OfferFlags::new().with_enumerate_device_interface(true)); // non-confidential
5732        env.offer_with_flags(
5733            2,
5734            OfferFlags::new()
5735                .with_named_pipe_mode(true)
5736                .with_confidential_ring_buffer(true)
5737                .with_confidential_external_memory(true),
5738        );
5739
5740        env.send_message(in_msg_ex(
5741            protocol::MessageType::REQUEST_OFFERS,
5742            protocol::RequestOffers {},
5743            false,
5744            true,
5745        ));
5746
5747        let offer = env.notifier.get_message::<protocol::OfferChannel>();
5748        assert_eq!(offer.channel_id, ChannelId(1));
5749        assert_eq!(
5750            offer.flags,
5751            OfferFlags::new().with_enumerate_device_interface(true)
5752        );
5753
5754        // The confidential channel flags are not sent without the feature flag.
5755        let offer = env.notifier.get_message::<protocol::OfferChannel>();
5756        assert_eq!(offer.channel_id, ChannelId(2));
5757        assert_eq!(offer.flags, OfferFlags::new().with_named_pipe_mode(true));
5758
5759        env.notifier
5760            .check_message(OutgoingMessage::new(&protocol::AllOffersDelivered {}));
5761    }
5762
5763    #[test]
5764    fn test_confidential_channels_untrusted() {
5765        let mut env = TestEnv::new();
5766
5767        env.connect(
5768            Version::Copper,
5769            FeatureFlags::new().with_confidential_channels(true),
5770        );
5771
5772        // The server should not offer confidential channel support to untrusted clients, even if
5773        // requested.
5774        assert_eq!(
5775            env.version.unwrap(),
5776            VersionInfo {
5777                version: Version::Copper,
5778                feature_flags: FeatureFlags::new()
5779            }
5780        );
5781    }
5782}