Skip to main content

vmbus_server/channels/
saved_state.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use super::MnfUsage;
5use super::Notifier;
6use super::OfferError;
7use super::OfferParamsInternal;
8use super::OfferedInfo;
9use super::RestoreState;
10use super::SUPPORTED_FEATURE_FLAGS;
11use guid::Guid;
12pub use inner::SavedState;
13use mesh::payload::Protobuf;
14use std::fmt::Display;
15use thiserror::Error;
16use vmbus_channel::bus::OfferKey;
17use vmbus_core::MaxVersionInfo;
18use vmbus_core::protocol;
19use vmbus_core::protocol::ChannelId;
20use vmbus_core::protocol::FeatureFlags;
21use vmbus_core::protocol::GpadlId;
22use vmbus_core::protocol::Version;
23use vmbus_ring::gparange;
24use vmcore::monitor::MonitorId;
25
26impl super::Server {
27    fn restore_one_channel(&mut self, saved_channel: Channel) -> Result<(), RestoreError> {
28        let (info, stub_offer, state) = saved_channel.restore(self.max_version)?;
29        if let Some((offer_id, channel)) = self.channels.get_by_key_mut(&saved_channel.key) {
30            // There is an existing channel. Restore on top of it.
31
32            if !matches!(channel.state, super::ChannelState::ClientReleased)
33                || channel.restore_state != RestoreState::New
34            {
35                return Err(RestoreError::AlreadyRestored(saved_channel.key));
36            }
37
38            // The channel's monitor ID can be already set if it was set by the device, which is
39            // the case with relay channels. In that case, it must match the saved ID.
40            if let MnfUsage::Relayed { monitor_id } = channel.offer.use_mnf {
41                if info.monitor_id != Some(MonitorId(monitor_id)) {
42                    return Err(RestoreError::MismatchedMonitorId(
43                        monitor_id,
44                        saved_channel.monitor_id,
45                    ));
46                }
47            }
48
49            self.assigned_channels
50                .set(info.channel_id)?
51                .insert(offer_id);
52
53            channel.state = state;
54            channel.restore_state = RestoreState::Restoring;
55            channel.info = Some(info);
56        } else {
57            // There is no existing channel.
58
59            let entry = self
60                .assigned_channels
61                .set(ChannelId(saved_channel.channel_id))?;
62
63            let channel = super::Channel {
64                info: Some(info),
65                offer: stub_offer,
66                state,
67                restore_state: RestoreState::Unmatched,
68            };
69
70            let offer_id = self.channels.offer(channel);
71            entry.insert(offer_id);
72        }
73        Ok(())
74    }
75
76    fn restore_one_gpadl(&mut self, saved_gpadl: Gpadl) -> Result<(), RestoreError> {
77        let gpadl_id = GpadlId(saved_gpadl.id);
78        let channel_id = ChannelId(saved_gpadl.channel_id);
79        let (offer_id, channel) = self
80            .channels
81            .get_by_channel_id(&self.assigned_channels, channel_id)
82            .map_err(|_| RestoreError::MissingGpadlChannel(gpadl_id, channel_id))?;
83
84        if channel.restore_state == RestoreState::New || channel.state.is_released() {
85            return Err(RestoreError::MissingGpadlChannel(gpadl_id, channel_id));
86        }
87
88        let gpadl = saved_gpadl.restore(channel)?;
89        let state = gpadl.state;
90        if self.gpadls.insert((gpadl_id, offer_id), gpadl).is_some() {
91            return Err(RestoreError::GpadlIdInUse(gpadl_id, channel_id));
92        }
93
94        if state == super::GpadlState::InProgress
95            && self.incomplete_gpadls.insert(gpadl_id, offer_id).is_some()
96        {
97            unreachable!("gpadl ID validated above");
98        }
99
100        Ok(())
101    }
102
103    /// Saves state.
104    pub fn save(&self) -> SavedState {
105        SavedStateData {
106            state: if let Some(state) = self.save_connected_state() {
107                SavedConnectionState::Connected(state)
108            } else {
109                SavedConnectionState::Disconnected(self.save_disconnected_state())
110            },
111            pending_messages: self.save_pending_messages(),
112        }
113        .into()
114    }
115
116    fn save_connected_state(&self) -> Option<ConnectedState> {
117        let connection = Connection::save(&self.state)?;
118        let channels = self
119            .channels
120            .iter()
121            .filter_map(|(_, channel)| Channel::save(channel))
122            .collect();
123
124        let gpadls = self.save_gpadls();
125        Some(ConnectedState {
126            connection,
127            channels,
128            gpadls,
129        })
130    }
131
132    fn save_gpadls(&self) -> Vec<Gpadl> {
133        self.gpadls
134            .iter()
135            .filter_map(|((gpadl_id, offer_id), gpadl)| {
136                Gpadl::save(*gpadl_id, self.channels[*offer_id].info?.channel_id, gpadl)
137            })
138            .collect()
139    }
140
141    fn save_disconnected_state(&self) -> DisconnectedState {
142        // Save reserved channels only.
143        let channels = self
144            .channels
145            .iter()
146            .filter_map(|(_, channel)| {
147                channel
148                    .state
149                    .is_reserved()
150                    .then(|| Channel::save(channel))
151                    .flatten()
152            })
153            .collect();
154
155        // Save the GPADLs for reserved channels.
156        // N.B. There cannot be any other GPADLs while disconnected.
157        let gpadls = self.save_gpadls();
158        DisconnectedState {
159            reserved_channels: channels,
160            reserved_gpadls: gpadls,
161        }
162    }
163
164    fn save_pending_messages(&self) -> Vec<OutgoingMessage> {
165        self.pending_messages
166            .0
167            .iter()
168            .map(OutgoingMessage::save)
169            .collect()
170    }
171}
172
173impl<'a, N: 'a + Notifier> super::ServerWithNotifier<'a, N> {
174    /// Restores state.
175    ///
176    /// This may be called before or after channels have been offered. After
177    /// calling this routine, [`restore_channel`] should be
178    /// called for each channel to be restored, possibly interleaved with
179    /// additional calls to offer or revoke channels.
180    ///
181    /// Once all channels are in the appropriate state,
182    /// [`revoke_unclaimed_channels`] should be called. This will revoke
183    /// any channels that were in the saved state but were not restored via
184    /// [`restore_channel`].
185    ///
186    /// [`revoke_unclaimed_channels`]: super::ServerWithNotifier::revoke_unclaimed_channels
187    /// [`restore_channel`]: super::ServerWithNotifier::restore_channel
188    pub fn restore(&mut self, saved: SavedState) -> Result<(), RestoreError> {
189        tracing::trace!(?saved, "restoring channel state");
190
191        let saved = SavedStateData::from(saved);
192        match saved.state {
193            SavedConnectionState::Connected(saved) => {
194                self.inner.state = saved.connection.restore(self.inner.max_version)?;
195
196                // Restore server state, and resend server notifications if needed. If these notifications
197                // were processed before the save, it's harmless as the values will be the same.
198                let request = match self.inner.state {
199                    super::ConnectionState::Connecting {
200                        info,
201                        next_action: _,
202                    } => Some(super::ModifyConnectionRequest {
203                        version: Some(info.version),
204                        interrupt_page: info.interrupt_page.into(),
205                        monitor_page: info.monitor_page.map(|mp| mp.gpas).into(),
206                        target_message_vp: Some(info.target_message_vp),
207                        notify_relay: true,
208                    }),
209                    super::ConnectionState::Connected(info) => {
210                        Some(super::ModifyConnectionRequest {
211                            version: None,
212                            monitor_page: info.monitor_page.map(|mp| mp.gpas).into(),
213                            interrupt_page: info.interrupt_page.into(),
214                            target_message_vp: Some(info.target_message_vp),
215                            // If the save didn't happen while modifying, the relay doesn't need to be notified
216                            // of this info as it doesn't constitute a change, we're just restoring existing
217                            // connection state.
218                            notify_relay: info.modifying,
219                        })
220                    }
221                    // No action needed for these states; if disconnecting, check_disconnected will resend
222                    // the reset request if needed.
223                    super::ConnectionState::Disconnected
224                    | super::ConnectionState::Disconnecting { .. } => None,
225                };
226
227                if let Some(request) = request {
228                    self.notifier.modify_connection(request)?;
229                }
230
231                for saved_channel in saved.channels {
232                    self.inner.restore_one_channel(saved_channel)?;
233                }
234
235                for saved_gpadl in saved.gpadls {
236                    self.inner.restore_one_gpadl(saved_gpadl)?;
237                }
238            }
239            SavedConnectionState::Disconnected(saved) => {
240                self.inner.state = super::ConnectionState::Disconnected;
241                for saved_channel in saved.reserved_channels {
242                    self.inner.restore_one_channel(saved_channel)?;
243                }
244
245                for saved_gpadl in saved.reserved_gpadls {
246                    self.inner.restore_one_gpadl(saved_gpadl)?;
247                }
248            }
249        }
250
251        self.inner
252            .pending_messages
253            .0
254            .reserve(saved.pending_messages.len());
255
256        for message in saved.pending_messages {
257            self.inner.pending_messages.0.push_back(message.restore()?);
258        }
259
260        Ok(())
261    }
262}
263
/// Errors that can occur while restoring the server from a saved state.
#[derive(Debug, Error)]
pub enum RestoreError {
    /// An [`OfferError`] was raised during restore; converted automatically by `?`.
    #[error(transparent)]
    Offer(#[from] OfferError),

    /// The saved channel matches an offered channel that is not eligible for restore: it is
    /// either not in the client-released state or has already been through restore.
    #[error("channel {0} has already been restored")]
    AlreadyRestored(OfferKey),

    /// A saved GPADL refers to a channel ID that has no restored, unreleased channel.
    #[error("gpadl {} is for missing channel {}", (.0).0, (.1).0)]
    MissingGpadlChannel(GpadlId, ChannelId),

    /// A saved GPADL refers to a revoked channel.
    // NOTE(review): not constructed in the portion of the file reviewed here.
    #[error("gpadl {} is for revoked channel {}", (.0).0, (.1).0)]
    GpadlForRevokedChannel(GpadlId, ChannelId),

    /// A GPADL with this ID was already restored. The channel ID is carried for diagnostics but
    /// is not part of the message.
    #[error("gpadl {} is already restored", (.0).0)]
    GpadlIdInUse(GpadlId, ChannelId),

    /// The saved protocol version (raw value) is not a supported version, or exceeds the
    /// server's maximum version.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// A GPA range error, converted automatically by `?`.
    #[error("invalid gpadl")]
    InvalidGpadl(#[from] gparange::Error),

    /// The saved feature flags (raw value) include flags that are not supported for the saved
    /// protocol version, trust level and maximum version.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    /// The open state of the identified channel does not match what was expected.
    #[error("channel {0} has a mismatched open state")]
    MismatchedOpenState(OfferKey),

    /// The identified channel is not present in the saved state.
    #[error("channel {0} is missing from the saved state")]
    MissingChannel(OfferKey),

    /// The protocol version (raw value) saved for a reserved channel is not supported.
    #[error("unsupported reserved channel protocol version {0:#x}")]
    UnsupportedReserveVersion(u32),

    /// The feature flags (raw value) saved for a reserved channel are not supported.
    #[error("unsupported reserved channel feature flags {0:#x}")]
    UnsupportedReserveFeatureFlags(u32),

    /// A relayed channel's monitor ID (first field) differs from the monitor ID in the saved
    /// channel (second field).
    #[error("mismatched monitor id; expected {0}, actual {1:?}")]
    MismatchedMonitorId(u8, Option<u8>),

    /// The same monitor ID is used by more than one saved channel. The ID is carried in the
    /// variant but is not part of the message.
    #[error("monitor ID used by multiple channels in the saved state")]
    DuplicateMonitorId(u8),

    /// An arbitrary server-side error, converted automatically by `?`.
    #[error(transparent)]
    ServerError(#[from] anyhow::Error),

    /// A pending message targets a reserved channel (identified by channel ID) that is not in
    /// the saved state.
    #[error(
        "reserved channel with ID {0} has a pending message but is missing from the saved state"
    )]
    MissingReservedChannel(u32),
    /// A saved pending message exceeds the maximum message size.
    #[error("a saved pending message is larger than the maximum message size")]
    MessageTooLarge,
}
318
319mod inner {
320    use super::*;
321
322    /// The top-level saved state for the VMBus channels library. It is placed in its own module to
323    /// keep the internals private, and the only thing you can do with it is convert to/from
324    /// `SavedStateData`. This enforces that users always consider both the connected and
325    /// disconnected states.
326    #[derive(Debug, Protobuf, Clone)]
327    #[mesh(package = "vmbus.server.channels")]
328    pub struct SavedState {
329        #[mesh(1)]
330        state: Option<ConnectedState>,
331        // Disconnected state is used to save any open reserved channels while the guest is
332        // disconnected. It is mutually exclusive with `state`, but is separate to maintain saved
333        // state compatibility.
334        // N.B. In a saved state created by the current version, either state or disconnected_state
335        //      is always `Some`, but for older versions, it is possible that both are `None`. They
336        //      can never both be `Some`.
337        #[mesh(2)]
338        disconnected_state: Option<DisconnectedState>,
339        #[mesh(3)]
340        pending_messages: Vec<OutgoingMessage>,
341    }
342
343    impl From<SavedStateData> for SavedState {
344        fn from(value: SavedStateData) -> Self {
345            let (state, disconnected_state) = match value.state {
346                SavedConnectionState::Connected(connected) => (Some(connected), None),
347                SavedConnectionState::Disconnected(disconnected) => (None, Some(disconnected)),
348            };
349
350            Self {
351                state,
352                disconnected_state,
353                pending_messages: value.pending_messages,
354            }
355        }
356    }
357
358    impl From<SavedState> for SavedStateData {
359        fn from(value: SavedState) -> Self {
360            Self {
361                state: if let Some(connected) = value.state {
362                    SavedConnectionState::Connected(connected)
363                } else {
364                    // Older saved state versions may not have a disconnected state, in which case
365                    // we use an empty value which has no channels or gpadls.
366                    SavedConnectionState::Disconnected(value.disconnected_state.unwrap_or_default())
367                },
368                pending_messages: value.pending_messages,
369            }
370        }
371    }
372}
373
/// Represents either connected or disconnected saved state.
///
/// Exactly one of the two is present; see `SavedStateData` for the containing type.
#[derive(Debug, Clone)]
pub enum SavedConnectionState {
    /// A connection state was saved, along with all channels and GPADLs.
    Connected(ConnectedState),
    /// No connection state was saved; only reserved channels and their GPADLs are present.
    Disconnected(DisconnectedState),
}
380
/// Alternative representation of the saved state that ensures that all code paths deal with either
/// the connected or disconnected state, and cannot neglect one.
///
/// Convert to and from [`SavedState`] with `From`/`Into`.
#[derive(Debug, Clone)]
pub struct SavedStateData {
    /// The connected or disconnected portion of the saved state.
    pub state: SavedConnectionState,
    /// Outgoing messages that were still queued when the state was saved.
    pub pending_messages: Vec<OutgoingMessage>,
}
388
389impl SavedStateData {
390    /// Finds a channel in the saved state.
391    pub fn find_channel(&self, offer: OfferKey) -> Option<&Channel> {
392        let (channels, _) = self.channels_and_gpadls();
393        channels.iter().find(|c| c.key == offer)
394    }
395
396    /// Retrieves all the channels and GPADLs from the saved state.
397    /// If disconnected, returns any reserved channels and their GPADLs.
398    pub fn channels_and_gpadls(&self) -> (&[Channel], &[Gpadl]) {
399        match &self.state {
400            SavedConnectionState::Connected(connected) => (&connected.channels, &connected.gpadls),
401            SavedConnectionState::Disconnected(disconnected) => (
402                &disconnected.reserved_channels,
403                &disconnected.reserved_gpadls,
404            ),
405        }
406    }
407}
408
/// Saved state written when the server has a connection state to save.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct ConnectedState {
    /// The saved connection state.
    #[mesh(1)]
    pub connection: Connection,
    /// The saved channels.
    #[mesh(2)]
    pub channels: Vec<Channel>,
    /// The saved GPADLs; each one records the ID of the channel it belongs to.
    #[mesh(3)]
    pub gpadls: Vec<Gpadl>,
}
419
/// Saved state written when the server has no connection state to save.
///
/// The `Default` value (no channels, no GPADLs) stands in for saved states that do not contain
/// a disconnected state at all.
#[derive(Default, Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct DisconnectedState {
    /// Channels that were in a reserved state when the save happened.
    #[mesh(1)]
    pub reserved_channels: Vec<Channel>,
    /// GPADLs belonging to the reserved channels.
    #[mesh(2)]
    pub reserved_gpadls: Vec<Gpadl>,
}
428
/// Saved form of the negotiated protocol version and feature flags, stored as raw integers.
#[derive(Debug, PartialEq, Eq, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct VersionInfo {
    /// Raw numeric value of the protocol version.
    #[mesh(1)]
    pub version: u32,
    /// Raw bits of the feature flags.
    #[mesh(2)]
    pub feature_flags: u32,
}
437
438impl VersionInfo {
439    fn save(value: &super::VersionInfo) -> Self {
440        Self {
441            version: value.version as u32,
442            feature_flags: value.feature_flags.into(),
443        }
444    }
445
446    fn restore(
447        self,
448        trusted: bool,
449        max_version: Option<MaxVersionInfo>,
450    ) -> Result<vmbus_core::VersionInfo, RestoreError> {
451        let version = super::SUPPORTED_VERSIONS
452            .iter()
453            .find(|v| self.version == **v as u32)
454            .copied()
455            .ok_or(RestoreError::UnsupportedVersion(self.version))?;
456
457        // Enforce the server's compatibility version limit, if any. This may
458        // restrict both the protocol version and the set of feature flags
459        // permitted, similar to negotiation in `check_version_supported`.
460        if let Some(max_version) = max_version {
461            if version as u32 > max_version.version {
462                return Err(RestoreError::UnsupportedVersion(self.version));
463            }
464        }
465
466        let feature_flags = FeatureFlags::from(self.feature_flags);
467        let supported_flags = if version >= Version::Copper {
468            let mut flags = SUPPORTED_FEATURE_FLAGS.with_confidential_channels(trusted);
469            // Limit feature flags to those supported by the server's
470            // compatibility version, if specified.
471            if let Some(max_version) = max_version {
472                flags &= max_version.feature_flags;
473            }
474
475            flags
476        } else {
477            // Earlier protocol versions don't have feature flags.
478            FeatureFlags::new()
479        };
480
481        if !supported_flags.contains(feature_flags) {
482            return Err(RestoreError::UnsupportedFeatureFlags(feature_flags.into()));
483        }
484
485        Ok(super::VersionInfo {
486            version,
487            feature_flags,
488        })
489    }
490}
491
/// Saved form of the server's connection state.
///
/// There is no variant for the disconnected state: `Connection::save` returns `None` for it.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum Connection {
    /// The connection was being torn down. Whether the modify request had been sent is not
    /// saved.
    #[mesh(1)]
    Disconnecting {
        /// Action to take once disconnecting completes.
        #[mesh(1)]
        next_action: ConnectionAction,
    },
    /// The connection was in the process of being established.
    #[mesh(2)]
    Connecting {
        /// Protocol version and feature flags.
        #[mesh(1)]
        version: VersionInfo,
        /// Interrupt page value, if any.
        #[mesh(2)]
        interrupt_page: Option<u64>,
        /// Monitor page GPAs, if any.
        #[mesh(3)]
        monitor_page: Option<MonitorPageGpas>,
        /// Target VP for messages.
        #[mesh(4)]
        target_message_vp: u32,
        /// Action associated with the connecting state.
        #[mesh(5)]
        next_action: ConnectionAction,
        /// Client ID; a missing value is restored as `Guid::ZERO`.
        #[mesh(6)]
        client_id: Option<Guid>,
        /// Whether the connection is trusted; this affects which feature flags are accepted on
        /// restore.
        #[mesh(7)]
        trusted: bool,
    },
    /// The connection was established.
    #[mesh(3)]
    Connected {
        /// Protocol version and feature flags.
        #[mesh(1)]
        version: VersionInfo,
        /// Saved value of the connection's `offers_sent` flag.
        #[mesh(2)]
        offers_sent: bool,
        /// Interrupt page value, if any.
        #[mesh(3)]
        interrupt_page: Option<u64>,
        /// Monitor page GPAs, if any.
        #[mesh(4)]
        monitor_page: Option<MonitorPageGpas>,
        /// Target VP for messages.
        #[mesh(5)]
        target_message_vp: u32,
        /// Saved value of the connection's `modifying` flag; when set, the relay is notified
        /// again on restore.
        #[mesh(6)]
        modifying: bool,
        /// Client ID; a missing value is restored as `Guid::ZERO`.
        #[mesh(7)]
        client_id: Option<Guid>,
        /// Whether the connection is trusted; this affects which feature flags are accepted on
        /// restore.
        #[mesh(8)]
        trusted: bool,
        /// Saved value of the connection's `paused` flag.
        #[mesh(9)]
        paused: bool,
    },
}
539
540impl Connection {
541    fn save(value: &super::ConnectionState) -> Option<Self> {
542        match value {
543            super::ConnectionState::Disconnected => {
544                // No state to save.
545                None
546            }
547            super::ConnectionState::Connecting { info, next_action } => {
548                Some(Connection::Connecting {
549                    version: VersionInfo::save(&info.version),
550                    interrupt_page: info.interrupt_page,
551                    monitor_page: info.monitor_page.map(MonitorPageGpas::save),
552                    target_message_vp: info.target_message_vp,
553                    next_action: ConnectionAction::save(next_action),
554                    client_id: Some(info.client_id),
555                    trusted: info.trusted,
556                })
557            }
558            super::ConnectionState::Connected(info) => Some(Connection::Connected {
559                version: VersionInfo::save(&info.version),
560                offers_sent: info.offers_sent,
561                interrupt_page: info.interrupt_page,
562                monitor_page: info.monitor_page.map(MonitorPageGpas::save),
563                target_message_vp: info.target_message_vp,
564                modifying: info.modifying,
565                client_id: Some(info.client_id),
566                trusted: info.trusted,
567                paused: info.paused,
568            }),
569            super::ConnectionState::Disconnecting {
570                next_action,
571                modify_sent: _,
572            } => Some(Connection::Disconnecting {
573                next_action: ConnectionAction::save(next_action),
574            }),
575        }
576    }
577
578    fn restore(
579        self,
580        max_version: Option<MaxVersionInfo>,
581    ) -> Result<super::ConnectionState, RestoreError> {
582        Ok(match self {
583            Connection::Connecting {
584                version,
585                interrupt_page,
586                monitor_page,
587                target_message_vp,
588                next_action,
589                client_id,
590                trusted,
591            } => super::ConnectionState::Connecting {
592                info: super::ConnectionInfo {
593                    version: version.restore(trusted, max_version)?,
594                    trusted,
595                    interrupt_page,
596                    monitor_page: monitor_page.map(MonitorPageGpas::restore),
597                    target_message_vp,
598                    offers_sent: false,
599                    modifying: false,
600                    client_id: client_id.unwrap_or(Guid::ZERO),
601                    paused: false,
602                },
603                next_action: next_action.restore(),
604            },
605            Connection::Connected {
606                version,
607                offers_sent,
608                interrupt_page,
609                monitor_page,
610                target_message_vp,
611                modifying,
612                client_id,
613                trusted,
614                paused,
615            } => super::ConnectionState::Connected(super::ConnectionInfo {
616                version: version.restore(trusted, max_version)?,
617                trusted,
618                offers_sent,
619                interrupt_page,
620                monitor_page: monitor_page.map(MonitorPageGpas::restore),
621                target_message_vp,
622                modifying,
623                client_id: client_id.unwrap_or(Guid::ZERO),
624                paused,
625            }),
626            Connection::Disconnecting { next_action } => super::ConnectionState::Disconnecting {
627                next_action: next_action.restore(),
628                // If the modify request was sent, it will be resent.
629                modify_sent: false,
630            },
631        })
632    }
633}
634
/// Saved form of the action to perform after the current connection transition.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum ConnectionAction {
    /// No follow-up action. Also used when a reset was in progress, since a reset is not saved.
    #[mesh(1)]
    None,
    /// Saved form of the live `SendUnloadComplete` action.
    #[mesh(2)]
    SendUnloadComplete,
    /// Saved form of the live `Reconnect` action.
    #[mesh(3)]
    Reconnect {
        /// The initiate contact request carried by the action.
        #[mesh(1)]
        initiate_contact: InitiateContactRequest,
    },
    /// Saved form of the live `SendFailedVersionResponse` action.
    #[mesh(4)]
    SendFailedVersionResponse,
}
650
651impl ConnectionAction {
652    fn save(value: &super::ConnectionAction) -> Self {
653        match value {
654            super::ConnectionAction::Reset | super::ConnectionAction::None => {
655                // The caller is responsible for remembering that a
656                // reset was in progress and reissuing it.
657                Self::None
658            }
659            super::ConnectionAction::SendUnloadComplete => Self::SendUnloadComplete,
660            super::ConnectionAction::Reconnect { initiate_contact } => Self::Reconnect {
661                initiate_contact: InitiateContactRequest::save(initiate_contact),
662            },
663            super::ConnectionAction::SendFailedVersionResponse => Self::SendFailedVersionResponse,
664        }
665    }
666
667    fn restore(self) -> super::ConnectionAction {
668        match self {
669            Self::None => super::ConnectionAction::None,
670            Self::SendUnloadComplete => super::ConnectionAction::SendUnloadComplete,
671            Self::Reconnect { initiate_contact } => super::ConnectionAction::Reconnect {
672                initiate_contact: initiate_contact.restore(),
673            },
674            Self::SendFailedVersionResponse => super::ConnectionAction::SendFailedVersionResponse,
675        }
676    }
677}
678
/// Saved form of a single channel.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct Channel {
    /// Key of the offer this channel belongs to; used to match the saved channel to an offered
    /// channel on restore.
    #[mesh(1)]
    pub key: OfferKey,
    /// Raw channel ID that was assigned to the channel.
    #[mesh(2)]
    pub channel_id: u32,
    /// Connection ID from the channel's offered info.
    #[mesh(3)]
    pub offered_connection_id: u32,
    /// Saved channel state.
    #[mesh(4)]
    pub state: ChannelState,
    /// Raw monitor ID from the channel's offered info, if one was set.
    #[mesh(5)]
    pub monitor_id: Option<u8>,
}
693
694impl Channel {
695    fn save(value: &super::Channel) -> Option<Self> {
696        let info = value.info.as_ref()?;
697        let key = value.offer.key();
698        if let Some(state) = ChannelState::save(&value.state) {
699            tracing::trace!(%key, %state, "channel saved");
700            Some(Channel {
701                channel_id: info.channel_id.0,
702                offered_connection_id: info.connection_id,
703                key,
704                state,
705                monitor_id: info.monitor_id.map(|id| id.0),
706            })
707        } else {
708            tracing::info!(%key, state = %value.state, "skipping channel save");
709            None
710        }
711    }
712
713    fn restore(
714        &self,
715        max_version: Option<MaxVersionInfo>,
716    ) -> Result<(OfferedInfo, OfferParamsInternal, super::ChannelState), RestoreError> {
717        let info = OfferedInfo {
718            channel_id: ChannelId(self.channel_id),
719            connection_id: self.offered_connection_id,
720            monitor_id: self.monitor_id.map(MonitorId),
721        };
722
723        let stub_offer = OfferParamsInternal {
724            instance_id: self.key.instance_id,
725            interface_id: self.key.interface_id,
726            subchannel_index: self.key.subchannel_index,
727            ..Default::default()
728        };
729
730        let state = self.state.restore(max_version)?;
731        tracing::info!(key = %self.key, %state, "channel restored");
732        Ok((info, stub_offer, state))
733    }
734
735    pub fn channel_id(&self) -> u32 {
736        self.channel_id
737    }
738
739    pub fn key(&self) -> OfferKey {
740        self.key
741    }
742
743    pub fn open_request(&self) -> Option<OpenRequest> {
744        match self.state {
745            ChannelState::Closed => None,
746            ChannelState::Opening { request, .. } => Some(request),
747            ChannelState::Open { params, .. } => Some(params),
748            ChannelState::Closing { params, .. } => Some(params),
749            ChannelState::ClosingReopen { params, .. } => Some(params),
750            ChannelState::Revoked => None,
751        }
752    }
753}
754
/// Saved form of an initiate contact request.
///
/// Every field is copied from the live request of the same name; only `monitor_page` goes
/// through its own saved type.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct InitiateContactRequest {
    /// Requested protocol version, as a raw value.
    #[mesh(1)]
    pub version_requested: u32,
    /// Target VP for messages.
    #[mesh(2)]
    pub target_message_vp: u32,
    /// Monitor page request.
    #[mesh(3)]
    pub monitor_page: MonitorPageRequest,
    /// Target SINT.
    #[mesh(4)]
    pub target_sint: u8,
    /// Target VTL.
    #[mesh(5)]
    pub target_vtl: u8,
    /// Requested feature flags, as raw bits.
    #[mesh(6)]
    pub feature_flags: u32,
    /// Interrupt page value, if any.
    #[mesh(7)]
    pub interrupt_page: Option<u64>,
    /// Client ID.
    #[mesh(8)]
    pub client_id: Guid,
    /// Whether the request is trusted.
    #[mesh(9)]
    pub trusted: bool,
}
777
778impl InitiateContactRequest {
779    fn save(value: &super::InitiateContactRequest) -> Self {
780        Self {
781            version_requested: value.version_requested,
782            target_message_vp: value.target_message_vp,
783            monitor_page: MonitorPageRequest::save(value.monitor_page),
784            target_sint: value.target_sint,
785            target_vtl: value.target_vtl,
786            feature_flags: value.feature_flags,
787            interrupt_page: value.interrupt_page,
788            client_id: value.client_id,
789            trusted: value.trusted,
790        }
791    }
792
793    fn restore(self) -> super::InitiateContactRequest {
794        super::InitiateContactRequest {
795            version_requested: self.version_requested,
796            target_message_vp: self.target_message_vp,
797            monitor_page: self.monitor_page.restore(),
798            target_sint: self.target_sint,
799            target_vtl: self.target_vtl,
800            feature_flags: self.feature_flags,
801            interrupt_page: self.interrupt_page,
802            client_id: self.client_id,
803            trusted: self.trusted,
804        }
805    }
806}
807
/// Saved form of the pair of monitor page GPAs.
///
/// Saving panics if the live monitor pages are server-allocated (see `MonitorPageGpas::save`).
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct MonitorPageGpas {
    /// GPA of the parent-to-child monitor page.
    #[mesh(1)]
    pub parent_to_child: u64,
    /// GPA of the child-to-parent monitor page.
    #[mesh(2)]
    pub child_to_parent: u64,
}
816
817impl MonitorPageGpas {
818    fn save(value: super::MonitorPageGpaInfo) -> Self {
819        assert!(
820            !value.server_allocated,
821            "cannot save with server-allocated monitor pages"
822        );
823        Self {
824            child_to_parent: value.gpas.child_to_parent,
825            parent_to_child: value.gpas.parent_to_child,
826        }
827    }
828
829    fn restore(self) -> super::MonitorPageGpaInfo {
830        super::MonitorPageGpaInfo::from_guest_gpas(super::MonitorPageGpas {
831            child_to_parent: self.child_to_parent,
832            parent_to_child: self.parent_to_child,
833        })
834    }
835}
836
/// Saved form of the monitor page portion of an initiate contact request; each variant maps to
/// the live variant of the same name.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum MonitorPageRequest {
    /// Saved form of the live `None` variant.
    #[mesh(1)]
    None,
    /// Saved form of the live `Some` variant, carrying the monitor page GPAs.
    #[mesh(2)]
    Some(#[mesh(1)] MonitorPageGpas),
    /// Saved form of the live `Invalid` variant.
    #[mesh(3)]
    Invalid,
}
847
848impl MonitorPageRequest {
849    fn save(value: super::MonitorPageRequest) -> Self {
850        match value {
851            super::MonitorPageRequest::None => MonitorPageRequest::None,
852            super::MonitorPageRequest::Some(mp) => MonitorPageRequest::Some(MonitorPageGpas::save(
853                super::MonitorPageGpaInfo::from_guest_gpas(mp),
854            )),
855            super::MonitorPageRequest::Invalid => MonitorPageRequest::Invalid,
856        }
857    }
858
859    fn restore(self) -> super::MonitorPageRequest {
860        match self {
861            MonitorPageRequest::None => super::MonitorPageRequest::None,
862            MonitorPageRequest::Some(mp) => super::MonitorPageRequest::Some(mp.restore().gpas),
863            MonitorPageRequest::Invalid => super::MonitorPageRequest::Invalid,
864        }
865    }
866}
867
868#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
869#[mesh(package = "vmbus.server.channels")]
870pub struct SignalInfo {
871    #[mesh(1)]
872    pub event_flag: u16,
873    #[mesh(2)]
874    pub connection_id: u32,
875}
876
877impl SignalInfo {
878    fn save(value: &super::SignalInfo) -> Self {
879        Self {
880            event_flag: value.event_flag,
881            connection_id: value.connection_id,
882        }
883    }
884
885    fn restore(self) -> super::SignalInfo {
886        super::SignalInfo {
887            event_flag: self.event_flag,
888            connection_id: self.connection_id,
889        }
890    }
891}
892
893#[derive(Copy, PartialEq, Eq, Clone, Protobuf)]
894#[mesh(package = "vmbus.server.channels")]
895pub struct OpenRequest {
896    #[mesh(1)]
897    pub open_id: u32,
898    #[mesh(2)]
899    pub ring_buffer_gpadl_id: GpadlId,
900    #[mesh(3)]
901    pub target_vp: u32,
902    #[mesh(4)]
903    pub downstream_ring_buffer_page_offset: u32,
904    #[mesh(5)]
905    pub user_data: [u8; 120],
906    #[mesh(6)]
907    pub guest_specified_interrupt_info: Option<SignalInfo>,
908    #[mesh(7)]
909    pub flags: u16,
910}
911
912impl std::fmt::Debug for OpenRequest {
913    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
914        let Self {
915            open_id,
916            ring_buffer_gpadl_id,
917            target_vp,
918            downstream_ring_buffer_page_offset,
919            user_data,
920            guest_specified_interrupt_info,
921            flags,
922        } = self;
923
924        let user_data_display: &dyn std::fmt::Debug = if self.user_data.iter().all(|&b| b == 0) {
925            &"[<all-zeroes>]"
926        } else {
927            struct HexDisplay<'a>(&'a [u8; 120]);
928            impl std::fmt::Debug for HexDisplay<'_> {
929                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
930                    write!(f, "[")?;
931                    for byte in self.0 {
932                        write!(f, "{:02X}", byte)?;
933                    }
934                    write!(f, "]")
935                }
936            }
937            &HexDisplay(user_data)
938        };
939
940        f.debug_struct("OpenRequest")
941            .field("open_id", open_id)
942            .field("ring_buffer_gpadl_id", ring_buffer_gpadl_id)
943            .field("target_vp", target_vp)
944            .field(
945                "downstream_ring_buffer_page_offset",
946                downstream_ring_buffer_page_offset,
947            )
948            .field("user_data", user_data_display)
949            .field(
950                "guest_specified_interrupt_info",
951                guest_specified_interrupt_info,
952            )
953            .field("flags", flags)
954            .finish()
955    }
956}
957
958impl OpenRequest {
959    fn save(value: &super::OpenRequest) -> Self {
960        Self {
961            open_id: value.open_id,
962            ring_buffer_gpadl_id: value.ring_buffer_gpadl_id,
963            target_vp: value
964                .target_vp
965                .unwrap_or(protocol::VP_INDEX_DISABLE_INTERRUPT),
966            downstream_ring_buffer_page_offset: value.downstream_ring_buffer_page_offset,
967            user_data: value.user_data.into(),
968            guest_specified_interrupt_info: value
969                .guest_specified_interrupt_info
970                .as_ref()
971                .map(SignalInfo::save),
972            flags: value.flags.into(),
973        }
974    }
975
976    fn restore(self) -> super::OpenRequest {
977        super::OpenRequest {
978            open_id: self.open_id,
979            ring_buffer_gpadl_id: self.ring_buffer_gpadl_id,
980            target_vp: protocol::vp_index_if_enabled(self.target_vp),
981            downstream_ring_buffer_page_offset: self.downstream_ring_buffer_page_offset,
982            user_data: self.user_data.into(),
983            guest_specified_interrupt_info: self
984                .guest_specified_interrupt_info
985                .map(SignalInfo::restore),
986            flags: self.flags.into(),
987        }
988    }
989}
990
991#[derive(Debug, Copy, Clone, Protobuf)]
992#[mesh(package = "vmbus.server.channels")]
993pub enum ModifyState {
994    #[mesh(1)]
995    NotModifying,
996    #[mesh(2)]
997    Modifying {
998        #[mesh(1)]
999        pending_target_vp: Option<u32>,
1000    },
1001}
1002
1003impl ModifyState {
1004    fn save(value: &super::ModifyState) -> Self {
1005        match value {
1006            super::ModifyState::NotModifying => Self::NotModifying,
1007            super::ModifyState::Modifying { pending_target_vp } => Self::Modifying {
1008                pending_target_vp: *pending_target_vp,
1009            },
1010        }
1011    }
1012
1013    fn restore(self) -> super::ModifyState {
1014        match self {
1015            ModifyState::NotModifying => super::ModifyState::NotModifying,
1016            ModifyState::Modifying { pending_target_vp } => {
1017                super::ModifyState::Modifying { pending_target_vp }
1018            }
1019        }
1020    }
1021}
1022
1023#[derive(Debug, PartialEq, Eq, Clone, Protobuf)]
1024#[mesh(package = "vmbus.server.channels")]
1025pub struct ReservedState {
1026    #[mesh(1)]
1027    pub version: VersionInfo,
1028    #[mesh(2)]
1029    pub vp: u32,
1030    #[mesh(3)]
1031    pub sint: u8,
1032}
1033
1034impl ReservedState {
1035    fn save(reserved_state: &super::ReservedState) -> Self {
1036        Self {
1037            version: VersionInfo::save(&reserved_state.version),
1038            vp: reserved_state.target.vp,
1039            sint: reserved_state.target.sint,
1040        }
1041    }
1042
1043    fn restore(
1044        &self,
1045        max_version: Option<MaxVersionInfo>,
1046    ) -> Result<super::ReservedState, RestoreError> {
1047        // We don't know if the connection when the channel was reserved was trusted, so assume it
1048        // was for what feature flags are accepted here; it doesn't affect any actual behavior.
1049        let version = self
1050            .version
1051            .clone()
1052            .restore(true, max_version)
1053            .map_err(|e| match e {
1054                RestoreError::UnsupportedVersion(v) => RestoreError::UnsupportedReserveVersion(v),
1055                RestoreError::UnsupportedFeatureFlags(f) => {
1056                    RestoreError::UnsupportedReserveFeatureFlags(f)
1057                }
1058                err => err,
1059            })?;
1060
1061        if version.version < Version::Win10 {
1062            return Err(RestoreError::UnsupportedReserveVersion(
1063                version.version as u32,
1064            ));
1065        }
1066
1067        Ok(super::ReservedState {
1068            version,
1069            target: super::ConnectionTarget {
1070                vp: self.vp,
1071                sint: self.sint,
1072            },
1073        })
1074    }
1075}
1076
1077#[derive(Debug, Clone, Protobuf)]
1078#[mesh(package = "vmbus.server.channels")]
1079pub enum ChannelState {
1080    #[mesh(1)]
1081    Closed,
1082    #[mesh(2)]
1083    Opening {
1084        #[mesh(1)]
1085        request: OpenRequest,
1086        #[mesh(2)]
1087        reserved_state: Option<ReservedState>,
1088    },
1089    #[mesh(3)]
1090    Open {
1091        #[mesh(1)]
1092        params: OpenRequest,
1093        #[mesh(2)]
1094        modify_state: ModifyState,
1095        #[mesh(3)]
1096        reserved_state: Option<ReservedState>,
1097    },
1098    #[mesh(4)]
1099    Closing {
1100        #[mesh(1)]
1101        params: OpenRequest,
1102        #[mesh(2)]
1103        reserved_state: Option<ReservedState>,
1104    },
1105    #[mesh(5)]
1106    ClosingReopen {
1107        #[mesh(1)]
1108        params: OpenRequest,
1109        #[mesh(2)]
1110        request: OpenRequest,
1111    },
1112    #[mesh(6)]
1113    Revoked,
1114}
1115
1116impl ChannelState {
1117    fn save(value: &super::ChannelState) -> Option<Self> {
1118        Some(match value {
1119            super::ChannelState::Closed => ChannelState::Closed,
1120            super::ChannelState::Opening {
1121                request,
1122                reserved_state,
1123            } => ChannelState::Opening {
1124                request: OpenRequest::save(request),
1125                reserved_state: reserved_state.as_ref().map(ReservedState::save),
1126            },
1127            super::ChannelState::ClosingReopen { params, request } => ChannelState::ClosingReopen {
1128                params: OpenRequest::save(params),
1129                request: OpenRequest::save(request),
1130            },
1131            super::ChannelState::Open {
1132                params,
1133                modify_state,
1134                reserved_state,
1135            } => ChannelState::Open {
1136                params: OpenRequest::save(params),
1137                modify_state: ModifyState::save(modify_state),
1138                reserved_state: reserved_state.as_ref().map(ReservedState::save),
1139            },
1140            super::ChannelState::Closing {
1141                params,
1142                reserved_state,
1143            } => ChannelState::Closing {
1144                params: OpenRequest::save(params),
1145                reserved_state: reserved_state.as_ref().map(ReservedState::save),
1146            },
1147
1148            super::ChannelState::Revoked => ChannelState::Revoked,
1149            super::ChannelState::Reoffered => ChannelState::Revoked,
1150            super::ChannelState::ClientReleased
1151            | super::ChannelState::ClosingClientRelease
1152            | super::ChannelState::OpeningClientRelease => return None,
1153        })
1154    }
1155
1156    fn restore(
1157        &self,
1158        max_version: Option<MaxVersionInfo>,
1159    ) -> Result<super::ChannelState, RestoreError> {
1160        Ok(match self {
1161            ChannelState::Closed => super::ChannelState::Closed,
1162            ChannelState::Opening {
1163                request,
1164                reserved_state,
1165            } => super::ChannelState::Opening {
1166                request: request.restore(),
1167                reserved_state: reserved_state
1168                    .as_ref()
1169                    .map(|r| r.restore(max_version))
1170                    .transpose()?,
1171            },
1172            ChannelState::ClosingReopen { params, request } => super::ChannelState::ClosingReopen {
1173                params: params.restore(),
1174                request: request.restore(),
1175            },
1176            ChannelState::Open {
1177                params,
1178                modify_state,
1179                reserved_state,
1180            } => super::ChannelState::Open {
1181                params: params.restore(),
1182                modify_state: modify_state.restore(),
1183                reserved_state: reserved_state
1184                    .as_ref()
1185                    .map(|r| r.restore(max_version))
1186                    .transpose()?,
1187            },
1188            ChannelState::Closing {
1189                params,
1190                reserved_state,
1191            } => super::ChannelState::Closing {
1192                params: params.restore(),
1193                reserved_state: reserved_state
1194                    .as_ref()
1195                    .map(|r| r.restore(max_version))
1196                    .transpose()?,
1197            },
1198            ChannelState::Revoked => {
1199                // Mark it reoffered for now. This may transition back to revoked in post_restore.
1200                super::ChannelState::Reoffered
1201            }
1202        })
1203    }
1204}
1205
1206impl Display for ChannelState {
1207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1208        let state = match self {
1209            Self::Closed => "Closed",
1210            Self::Opening { .. } => "Opening",
1211            Self::Open { .. } => "Open",
1212            Self::Closing { .. } => "Closing",
1213            Self::ClosingReopen { .. } => "ClosingReopen",
1214            Self::Revoked => "Revoked",
1215        };
1216        write!(f, "{}", state)
1217    }
1218}
1219
1220#[derive(Debug, Clone, Protobuf)]
1221#[mesh(package = "vmbus.server.channels")]
1222pub struct Gpadl {
1223    #[mesh(1)]
1224    pub id: u32,
1225    #[mesh(2)]
1226    pub channel_id: u32,
1227    #[mesh(3)]
1228    pub count: u16,
1229    #[mesh(4)]
1230    pub buf: Vec<u64>,
1231    #[mesh(5)]
1232    pub state: GpadlState,
1233}
1234
1235impl Gpadl {
1236    fn save(gpadl_id: GpadlId, channel_id: ChannelId, gpadl: &super::Gpadl) -> Option<Self> {
1237        tracing::trace!(id = %gpadl_id.0, channel_id = %channel_id.0, "gpadl saved");
1238        Some(Gpadl {
1239            id: gpadl_id.0,
1240            channel_id: channel_id.0,
1241            count: gpadl.count,
1242            buf: gpadl.buf.clone(),
1243            state: match gpadl.state {
1244                super::GpadlState::InProgress => GpadlState::InProgress,
1245                super::GpadlState::Offered => GpadlState::Offered,
1246                super::GpadlState::Accepted => GpadlState::Accepted,
1247                super::GpadlState::TearingDown => GpadlState::TearingDown,
1248                super::GpadlState::OfferedTearingDown => return None,
1249            },
1250        })
1251    }
1252
1253    fn restore(self, channel: &super::Channel) -> Result<super::Gpadl, RestoreError> {
1254        if self.state != GpadlState::InProgress {
1255            // Validate the range.
1256            gparange::validate_gpa_ranges(self.count.into(), &self.buf)?;
1257        }
1258        let (state, allow_revoked) = match self.state {
1259            GpadlState::InProgress => (super::GpadlState::InProgress, true),
1260            GpadlState::Offered => (super::GpadlState::Offered, false),
1261            GpadlState::Accepted => {
1262                // It is assumed the device already knows about this GPADL.
1263                (super::GpadlState::Accepted, true)
1264            }
1265            GpadlState::TearingDown => (super::GpadlState::TearingDown, false),
1266        };
1267
1268        if !allow_revoked && channel.state.is_revoked() {
1269            return Err(RestoreError::GpadlForRevokedChannel(
1270                GpadlId(self.id),
1271                ChannelId(self.channel_id),
1272            ));
1273        }
1274
1275        Ok(super::Gpadl {
1276            count: self.count,
1277            buf: self.buf,
1278            state,
1279        })
1280    }
1281
1282    pub fn is_tearing_down(&self) -> bool {
1283        self.state == GpadlState::TearingDown
1284    }
1285}
1286
1287#[derive(Debug, Clone, Protobuf, PartialEq, Eq)]
1288#[mesh(package = "vmbus.server.channels")]
1289pub enum GpadlState {
1290    #[mesh(1)]
1291    InProgress,
1292    #[mesh(2)]
1293    Offered,
1294    #[mesh(3)]
1295    Accepted,
1296    #[mesh(4)]
1297    TearingDown,
1298}
1299
1300#[derive(Debug, Clone, Protobuf, PartialEq, Eq)]
1301#[mesh(package = "vmbus.server.channels")]
1302pub struct OutgoingMessage(pub Vec<u8>);
1303
1304impl OutgoingMessage {
1305    fn save(value: &vmbus_core::OutgoingMessage) -> Self {
1306        Self(value.data().to_vec())
1307    }
1308
1309    fn restore(self) -> Result<vmbus_core::OutgoingMessage, RestoreError> {
1310        vmbus_core::OutgoingMessage::from_message(&self.0)
1311            .map_err(|_| RestoreError::MessageTooLarge)
1312    }
1313}