vmbus_server/channels/
saved_state.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use super::MnfUsage;
5use super::Notifier;
6use super::OfferError;
7use super::OfferParamsInternal;
8use super::OfferedInfo;
9use super::RestoreState;
10use super::SUPPORTED_FEATURE_FLAGS;
11use guid::Guid;
12pub use inner::SavedState;
13use mesh::payload::Protobuf;
14use std::fmt::Display;
15use thiserror::Error;
16use vmbus_channel::bus::OfferKey;
17use vmbus_core::protocol::ChannelId;
18use vmbus_core::protocol::FeatureFlags;
19use vmbus_core::protocol::GpadlId;
20use vmbus_core::protocol::Version;
21use vmbus_ring::gparange;
22use vmcore::monitor::MonitorId;
23
24impl super::Server {
25    fn restore_one_channel(&mut self, saved_channel: Channel) -> Result<(), RestoreError> {
26        let (info, stub_offer, state) = saved_channel.restore()?;
27        if let Some((offer_id, channel)) = self.channels.get_by_key_mut(&saved_channel.key) {
28            // There is an existing channel. Restore on top of it.
29
30            if !matches!(channel.state, super::ChannelState::ClientReleased)
31                || channel.restore_state != RestoreState::New
32            {
33                return Err(RestoreError::AlreadyRestored(saved_channel.key));
34            }
35
36            // The channel's monitor ID can be already set if it was set by the device, which is
37            // the case with relay channels. In that case, it must match the saved ID.
38            if let MnfUsage::Relayed { monitor_id } = channel.offer.use_mnf {
39                if info.monitor_id != Some(MonitorId(monitor_id)) {
40                    return Err(RestoreError::MismatchedMonitorId(
41                        monitor_id,
42                        saved_channel.monitor_id,
43                    ));
44                }
45            }
46
47            self.assigned_channels
48                .set(info.channel_id)?
49                .insert(offer_id);
50
51            channel.state = state;
52            channel.restore_state = RestoreState::Restoring;
53            channel.info = Some(info);
54        } else {
55            // There is no existing channel.
56
57            let entry = self
58                .assigned_channels
59                .set(ChannelId(saved_channel.channel_id))?;
60
61            let channel = super::Channel {
62                info: Some(info),
63                offer: stub_offer,
64                state,
65                restore_state: RestoreState::Unmatched,
66            };
67
68            let offer_id = self.channels.offer(channel);
69            entry.insert(offer_id);
70        }
71        Ok(())
72    }
73
74    fn restore_one_gpadl(&mut self, saved_gpadl: Gpadl) -> Result<(), RestoreError> {
75        let gpadl_id = GpadlId(saved_gpadl.id);
76        let channel_id = ChannelId(saved_gpadl.channel_id);
77        let (offer_id, channel) = self
78            .channels
79            .get_by_channel_id(&self.assigned_channels, channel_id)
80            .map_err(|_| RestoreError::MissingGpadlChannel(gpadl_id, channel_id))?;
81
82        if channel.restore_state == RestoreState::New || channel.state.is_released() {
83            return Err(RestoreError::MissingGpadlChannel(gpadl_id, channel_id));
84        }
85
86        let gpadl = saved_gpadl.restore(channel)?;
87        let state = gpadl.state;
88        if self.gpadls.insert((gpadl_id, offer_id), gpadl).is_some() {
89            return Err(RestoreError::GpadlIdInUse(gpadl_id, channel_id));
90        }
91
92        if state == super::GpadlState::InProgress
93            && self.incomplete_gpadls.insert(gpadl_id, offer_id).is_some()
94        {
95            unreachable!("gpadl ID validated above");
96        }
97
98        Ok(())
99    }
100
101    /// Saves state.
102    pub fn save(&self) -> SavedState {
103        SavedStateData {
104            state: if let Some(state) = self.save_connected_state() {
105                SavedConnectionState::Connected(state)
106            } else {
107                SavedConnectionState::Disconnected(self.save_disconnected_state())
108            },
109            pending_messages: self.save_pending_messages(),
110        }
111        .into()
112    }
113
114    fn save_connected_state(&self) -> Option<ConnectedState> {
115        let connection = Connection::save(&self.state)?;
116        let channels = self
117            .channels
118            .iter()
119            .filter_map(|(_, channel)| Channel::save(channel))
120            .collect();
121
122        let gpadls = self.save_gpadls();
123        Some(ConnectedState {
124            connection,
125            channels,
126            gpadls,
127        })
128    }
129
130    fn save_gpadls(&self) -> Vec<Gpadl> {
131        self.gpadls
132            .iter()
133            .filter_map(|((gpadl_id, offer_id), gpadl)| {
134                Gpadl::save(*gpadl_id, self.channels[*offer_id].info?.channel_id, gpadl)
135            })
136            .collect()
137    }
138
139    fn save_disconnected_state(&self) -> DisconnectedState {
140        // Save reserved channels only.
141        let channels = self
142            .channels
143            .iter()
144            .filter_map(|(_, channel)| {
145                channel
146                    .state
147                    .is_reserved()
148                    .then(|| Channel::save(channel))
149                    .flatten()
150            })
151            .collect();
152
153        // Save the GPADLs for reserved channels.
154        // N.B. There cannot be any other GPADLs while disconnected.
155        let gpadls = self.save_gpadls();
156        DisconnectedState {
157            reserved_channels: channels,
158            reserved_gpadls: gpadls,
159        }
160    }
161
162    fn save_pending_messages(&self) -> Vec<OutgoingMessage> {
163        self.pending_messages
164            .0
165            .iter()
166            .map(OutgoingMessage::save)
167            .collect()
168    }
169}
170
171impl<'a, N: 'a + Notifier> super::ServerWithNotifier<'a, N> {
172    /// Restores state.
173    ///
174    /// This may be called before or after channels have been offered. After
175    /// calling this routine, [`restore_channel`] should be
176    /// called for each channel to be restored, possibly interleaved with
177    /// additional calls to offer or revoke channels.
178    ///
179    /// Once all channels are in the appropriate state,
180    /// [`revoke_unclaimed_channels`] should be called. This will revoke
181    /// any channels that were in the saved state but were not restored via
182    /// [`restore_channel`].
183    ///
184    /// [`revoke_unclaimed_channels`]: super::ServerWithNotifier::revoke_unclaimed_channels
185    /// [`restore_channel`]: super::ServerWithNotifier::restore_channel
186    pub fn restore(&mut self, saved: SavedState) -> Result<(), RestoreError> {
187        tracing::trace!(?saved, "restoring channel state");
188
189        let saved = SavedStateData::from(saved);
190        match saved.state {
191            SavedConnectionState::Connected(saved) => {
192                self.inner.state = saved.connection.restore()?;
193
194                // Restore server state, and resend server notifications if needed. If these notifications
195                // were processed before the save, it's harmless as the values will be the same.
196                let request = match self.inner.state {
197                    super::ConnectionState::Connecting {
198                        info,
199                        next_action: _,
200                    } => Some(super::ModifyConnectionRequest {
201                        version: Some(info.version),
202                        interrupt_page: info.interrupt_page.into(),
203                        monitor_page: info.monitor_page.map(|mp| mp.gpas).into(),
204                        target_message_vp: Some(info.target_message_vp),
205                        notify_relay: true,
206                    }),
207                    super::ConnectionState::Connected(info) => {
208                        Some(super::ModifyConnectionRequest {
209                            version: None,
210                            monitor_page: info.monitor_page.map(|mp| mp.gpas).into(),
211                            interrupt_page: info.interrupt_page.into(),
212                            target_message_vp: Some(info.target_message_vp),
213                            // If the save didn't happen while modifying, the relay doesn't need to be notified
214                            // of this info as it doesn't constitute a change, we're just restoring existing
215                            // connection state.
216                            notify_relay: info.modifying,
217                        })
218                    }
219                    // No action needed for these states; if disconnecting, check_disconnected will resend
220                    // the reset request if needed.
221                    super::ConnectionState::Disconnected
222                    | super::ConnectionState::Disconnecting { .. } => None,
223                };
224
225                if let Some(request) = request {
226                    self.notifier.modify_connection(request)?;
227                }
228
229                for saved_channel in saved.channels {
230                    self.inner.restore_one_channel(saved_channel)?;
231                }
232
233                for saved_gpadl in saved.gpadls {
234                    self.inner.restore_one_gpadl(saved_gpadl)?;
235                }
236            }
237            SavedConnectionState::Disconnected(saved) => {
238                self.inner.state = super::ConnectionState::Disconnected;
239                for saved_channel in saved.reserved_channels {
240                    self.inner.restore_one_channel(saved_channel)?;
241                }
242
243                for saved_gpadl in saved.reserved_gpadls {
244                    self.inner.restore_one_gpadl(saved_gpadl)?;
245                }
246            }
247        }
248
249        self.inner
250            .pending_messages
251            .0
252            .reserve(saved.pending_messages.len());
253
254        for message in saved.pending_messages {
255            self.inner.pending_messages.0.push_back(message.restore()?);
256        }
257
258        Ok(())
259    }
260}
261
262#[derive(Debug, Error)]
263pub enum RestoreError {
264    #[error(transparent)]
265    Offer(#[from] OfferError),
266
267    #[error("channel {0} has already been restored")]
268    AlreadyRestored(OfferKey),
269
270    #[error("gpadl {} is for missing channel {}", (.0).0, (.1).0)]
271    MissingGpadlChannel(GpadlId, ChannelId),
272
273    #[error("gpadl {} is for revoked channel {}", (.0).0, (.1).0)]
274    GpadlForRevokedChannel(GpadlId, ChannelId),
275
276    #[error("gpadl {} is already restored", (.0).0)]
277    GpadlIdInUse(GpadlId, ChannelId),
278
279    #[error("unsupported protocol version {0:#x}")]
280    UnsupportedVersion(u32),
281
282    #[error("invalid gpadl")]
283    InvalidGpadl(#[from] gparange::Error),
284
285    #[error("unsupported feature flags {0:#x}")]
286    UnsupportedFeatureFlags(u32),
287
288    #[error("channel {0} has a mismatched open state")]
289    MismatchedOpenState(OfferKey),
290
291    #[error("channel {0} is missing from the saved state")]
292    MissingChannel(OfferKey),
293
294    #[error("unsupported reserved channel protocol version {0:#x}")]
295    UnsupportedReserveVersion(u32),
296
297    #[error("unsupported reserved channel feature flags {0:#x}")]
298    UnsupportedReserveFeatureFlags(u32),
299
300    #[error("mismatched monitor id; expected {0}, actual {1:?}")]
301    MismatchedMonitorId(u8, Option<u8>),
302
303    #[error("monitor ID used by multiple channels in the saved state")]
304    DuplicateMonitorId(u8),
305
306    #[error(transparent)]
307    ServerError(#[from] anyhow::Error),
308
309    #[error(
310        "reserved channel with ID {0} has a pending message but is missing from the saved state"
311    )]
312    MissingReservedChannel(u32),
313    #[error("a saved pending message is larger than the maximum message size")]
314    MessageTooLarge,
315}
316
317mod inner {
318    use super::*;
319
320    /// The top-level saved state for the VMBus channels library. It is placed in its own module to
321    /// keep the internals private, and the only thing you can do with it is convert to/from
322    /// `SavedStateData`. This enforces that users always consider both the connected and
323    /// disconnected states.
324    #[derive(Debug, Protobuf, Clone)]
325    #[mesh(package = "vmbus.server.channels")]
326    pub struct SavedState {
327        #[mesh(1)]
328        state: Option<ConnectedState>,
329        // Disconnected state is used to save any open reserved channels while the guest is
330        // disconnected. It is mutually exclusive with `state`, but is separate to maintain saved
331        // state compatibility.
332        // N.B. In a saved state created by the current version, either state or disconnected_state
333        //      is always `Some`, but for older versions, it is possible that both are `None`. They
334        //      can never both be `Some`.
335        #[mesh(2)]
336        disconnected_state: Option<DisconnectedState>,
337        #[mesh(3)]
338        pending_messages: Vec<OutgoingMessage>,
339    }
340
341    impl From<SavedStateData> for SavedState {
342        fn from(value: SavedStateData) -> Self {
343            let (state, disconnected_state) = match value.state {
344                SavedConnectionState::Connected(connected) => (Some(connected), None),
345                SavedConnectionState::Disconnected(disconnected) => (None, Some(disconnected)),
346            };
347
348            Self {
349                state,
350                disconnected_state,
351                pending_messages: value.pending_messages,
352            }
353        }
354    }
355
356    impl From<SavedState> for SavedStateData {
357        fn from(value: SavedState) -> Self {
358            Self {
359                state: if let Some(connected) = value.state {
360                    SavedConnectionState::Connected(connected)
361                } else {
362                    // Older saved state versions may not have a disconnected state, in which case
363                    // we use an empty value which has no channels or gpadls.
364                    SavedConnectionState::Disconnected(value.disconnected_state.unwrap_or_default())
365                },
366                pending_messages: value.pending_messages,
367            }
368        }
369    }
370}
371
372/// Represents either connected or disconnected saved state.
373#[derive(Debug, Clone)]
374pub enum SavedConnectionState {
375    Connected(ConnectedState),
376    Disconnected(DisconnectedState),
377}
378
379/// Alternative representation of the saved state that ensures that all code paths deal with either
380/// the connected or disconnected state, and cannot neglect one.
381#[derive(Debug, Clone)]
382pub struct SavedStateData {
383    pub state: SavedConnectionState,
384    pub pending_messages: Vec<OutgoingMessage>,
385}
386
387impl SavedStateData {
388    /// Finds a channel in the saved state.
389    pub fn find_channel(&self, offer: OfferKey) -> Option<&Channel> {
390        let (channels, _) = self.channels_and_gpadls();
391        channels.iter().find(|c| c.key == offer)
392    }
393
394    /// Retrieves all the channels and GPADLs from the saved state.
395    /// If disconnected, returns any reserved channels and their GPADLs.
396    pub fn channels_and_gpadls(&self) -> (&[Channel], &[Gpadl]) {
397        match &self.state {
398            SavedConnectionState::Connected(connected) => (&connected.channels, &connected.gpadls),
399            SavedConnectionState::Disconnected(disconnected) => (
400                &disconnected.reserved_channels,
401                &disconnected.reserved_gpadls,
402            ),
403        }
404    }
405}
406
407#[derive(Debug, Clone, Protobuf)]
408#[mesh(package = "vmbus.server.channels")]
409pub struct ConnectedState {
410    #[mesh(1)]
411    pub connection: Connection,
412    #[mesh(2)]
413    pub channels: Vec<Channel>,
414    #[mesh(3)]
415    pub gpadls: Vec<Gpadl>,
416}
417
418#[derive(Default, Debug, Clone, Protobuf)]
419#[mesh(package = "vmbus.server.channels")]
420pub struct DisconnectedState {
421    #[mesh(1)]
422    pub reserved_channels: Vec<Channel>,
423    #[mesh(2)]
424    pub reserved_gpadls: Vec<Gpadl>,
425}
426
427#[derive(Debug, Clone, Protobuf)]
428#[mesh(package = "vmbus.server.channels")]
429pub struct VersionInfo {
430    #[mesh(1)]
431    pub version: u32,
432    #[mesh(2)]
433    pub feature_flags: u32,
434}
435
436impl VersionInfo {
437    fn save(value: &super::VersionInfo) -> Self {
438        Self {
439            version: value.version as u32,
440            feature_flags: value.feature_flags.into(),
441        }
442    }
443
444    fn restore(self, trusted: bool) -> Result<vmbus_core::VersionInfo, RestoreError> {
445        let version = super::SUPPORTED_VERSIONS
446            .iter()
447            .find(|v| self.version == **v as u32)
448            .copied()
449            .ok_or(RestoreError::UnsupportedVersion(self.version))?;
450
451        let feature_flags = FeatureFlags::from(self.feature_flags);
452        let supported_flags = SUPPORTED_FEATURE_FLAGS.with_confidential_channels(trusted);
453        if !supported_flags.contains(feature_flags) {
454            return Err(RestoreError::UnsupportedFeatureFlags(feature_flags.into()));
455        }
456
457        Ok(super::VersionInfo {
458            version,
459            feature_flags,
460        })
461    }
462}
463
464#[derive(Debug, Clone, Protobuf)]
465#[mesh(package = "vmbus.server.channels")]
466pub enum Connection {
467    #[mesh(1)]
468    Disconnecting {
469        #[mesh(1)]
470        next_action: ConnectionAction,
471    },
472    #[mesh(2)]
473    Connecting {
474        #[mesh(1)]
475        version: VersionInfo,
476        #[mesh(2)]
477        interrupt_page: Option<u64>,
478        #[mesh(3)]
479        monitor_page: Option<MonitorPageGpas>,
480        #[mesh(4)]
481        target_message_vp: u32,
482        #[mesh(5)]
483        next_action: ConnectionAction,
484        #[mesh(6)]
485        client_id: Option<Guid>,
486        #[mesh(7)]
487        trusted: bool,
488    },
489    #[mesh(3)]
490    Connected {
491        #[mesh(1)]
492        version: VersionInfo,
493        #[mesh(2)]
494        offers_sent: bool,
495        #[mesh(3)]
496        interrupt_page: Option<u64>,
497        #[mesh(4)]
498        monitor_page: Option<MonitorPageGpas>,
499        #[mesh(5)]
500        target_message_vp: u32,
501        #[mesh(6)]
502        modifying: bool,
503        #[mesh(7)]
504        client_id: Option<Guid>,
505        #[mesh(8)]
506        trusted: bool,
507        #[mesh(9)]
508        paused: bool,
509    },
510}
511
512impl Connection {
513    fn save(value: &super::ConnectionState) -> Option<Self> {
514        match value {
515            super::ConnectionState::Disconnected => {
516                // No state to save.
517                None
518            }
519            super::ConnectionState::Connecting { info, next_action } => {
520                Some(Connection::Connecting {
521                    version: VersionInfo::save(&info.version),
522                    interrupt_page: info.interrupt_page,
523                    monitor_page: info.monitor_page.map(MonitorPageGpas::save),
524                    target_message_vp: info.target_message_vp,
525                    next_action: ConnectionAction::save(next_action),
526                    client_id: Some(info.client_id),
527                    trusted: info.trusted,
528                })
529            }
530            super::ConnectionState::Connected(info) => Some(Connection::Connected {
531                version: VersionInfo::save(&info.version),
532                offers_sent: info.offers_sent,
533                interrupt_page: info.interrupt_page,
534                monitor_page: info.monitor_page.map(MonitorPageGpas::save),
535                target_message_vp: info.target_message_vp,
536                modifying: info.modifying,
537                client_id: Some(info.client_id),
538                trusted: info.trusted,
539                paused: info.paused,
540            }),
541            super::ConnectionState::Disconnecting {
542                next_action,
543                modify_sent: _,
544            } => Some(Connection::Disconnecting {
545                next_action: ConnectionAction::save(next_action),
546            }),
547        }
548    }
549
550    fn restore(self) -> Result<super::ConnectionState, RestoreError> {
551        Ok(match self {
552            Connection::Connecting {
553                version,
554                interrupt_page,
555                monitor_page,
556                target_message_vp,
557                next_action,
558                client_id,
559                trusted,
560            } => super::ConnectionState::Connecting {
561                info: super::ConnectionInfo {
562                    version: version.restore(trusted)?,
563                    trusted,
564                    interrupt_page,
565                    monitor_page: monitor_page.map(MonitorPageGpas::restore),
566                    target_message_vp,
567                    offers_sent: false,
568                    modifying: false,
569                    client_id: client_id.unwrap_or(Guid::ZERO),
570                    paused: false,
571                },
572                next_action: next_action.restore(),
573            },
574            Connection::Connected {
575                version,
576                offers_sent,
577                interrupt_page,
578                monitor_page,
579                target_message_vp,
580                modifying,
581                client_id,
582                trusted,
583                paused,
584            } => super::ConnectionState::Connected(super::ConnectionInfo {
585                version: version.restore(trusted)?,
586                trusted,
587                offers_sent,
588                interrupt_page,
589                monitor_page: monitor_page.map(MonitorPageGpas::restore),
590                target_message_vp,
591                modifying,
592                client_id: client_id.unwrap_or(Guid::ZERO),
593                paused,
594            }),
595            Connection::Disconnecting { next_action } => super::ConnectionState::Disconnecting {
596                next_action: next_action.restore(),
597                // If the modify request was sent, it will be resent.
598                modify_sent: false,
599            },
600        })
601    }
602}
603
604#[derive(Debug, Clone, Protobuf)]
605#[mesh(package = "vmbus.server.channels")]
606pub enum ConnectionAction {
607    #[mesh(1)]
608    None,
609    #[mesh(2)]
610    SendUnloadComplete,
611    #[mesh(3)]
612    Reconnect {
613        #[mesh(1)]
614        initiate_contact: InitiateContactRequest,
615    },
616    #[mesh(4)]
617    SendFailedVersionResponse,
618}
619
620impl ConnectionAction {
621    fn save(value: &super::ConnectionAction) -> Self {
622        match value {
623            super::ConnectionAction::Reset | super::ConnectionAction::None => {
624                // The caller is responsible for remembering that a
625                // reset was in progress and reissuing it.
626                Self::None
627            }
628            super::ConnectionAction::SendUnloadComplete => Self::SendUnloadComplete,
629            super::ConnectionAction::Reconnect { initiate_contact } => Self::Reconnect {
630                initiate_contact: InitiateContactRequest::save(initiate_contact),
631            },
632            super::ConnectionAction::SendFailedVersionResponse => Self::SendFailedVersionResponse,
633        }
634    }
635
636    fn restore(self) -> super::ConnectionAction {
637        match self {
638            Self::None => super::ConnectionAction::None,
639            Self::SendUnloadComplete => super::ConnectionAction::SendUnloadComplete,
640            Self::Reconnect { initiate_contact } => super::ConnectionAction::Reconnect {
641                initiate_contact: initiate_contact.restore(),
642            },
643            Self::SendFailedVersionResponse => super::ConnectionAction::SendFailedVersionResponse,
644        }
645    }
646}
647
648#[derive(Debug, Clone, Protobuf)]
649#[mesh(package = "vmbus.server.channels")]
650pub struct Channel {
651    #[mesh(1)]
652    pub key: OfferKey,
653    #[mesh(2)]
654    pub channel_id: u32,
655    #[mesh(3)]
656    pub offered_connection_id: u32,
657    #[mesh(4)]
658    pub state: ChannelState,
659    #[mesh(5)]
660    pub monitor_id: Option<u8>,
661}
662
663impl Channel {
664    fn save(value: &super::Channel) -> Option<Self> {
665        let info = value.info.as_ref()?;
666        let key = value.offer.key();
667        if let Some(state) = ChannelState::save(&value.state) {
668            tracing::trace!(%key, %state, "channel saved");
669            Some(Channel {
670                channel_id: info.channel_id.0,
671                offered_connection_id: info.connection_id,
672                key,
673                state,
674                monitor_id: info.monitor_id.map(|id| id.0),
675            })
676        } else {
677            tracing::info!(%key, state = %value.state, "skipping channel save");
678            None
679        }
680    }
681
682    fn restore(
683        &self,
684    ) -> Result<(OfferedInfo, OfferParamsInternal, super::ChannelState), RestoreError> {
685        let info = OfferedInfo {
686            channel_id: ChannelId(self.channel_id),
687            connection_id: self.offered_connection_id,
688            monitor_id: self.monitor_id.map(MonitorId),
689        };
690
691        let stub_offer = OfferParamsInternal {
692            instance_id: self.key.instance_id,
693            interface_id: self.key.interface_id,
694            subchannel_index: self.key.subchannel_index,
695            ..Default::default()
696        };
697
698        let state = self.state.restore()?;
699        tracing::info!(key = %self.key, %state, "channel restored");
700        Ok((info, stub_offer, state))
701    }
702
703    pub fn channel_id(&self) -> u32 {
704        self.channel_id
705    }
706
707    pub fn key(&self) -> OfferKey {
708        self.key
709    }
710
711    pub fn open_request(&self) -> Option<OpenRequest> {
712        match self.state {
713            ChannelState::Closed => None,
714            ChannelState::Opening { request, .. } => Some(request),
715            ChannelState::Open { params, .. } => Some(params),
716            ChannelState::Closing { params, .. } => Some(params),
717            ChannelState::ClosingReopen { params, .. } => Some(params),
718            ChannelState::Revoked => None,
719        }
720    }
721}
722
723#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
724#[mesh(package = "vmbus.server.channels")]
725pub struct InitiateContactRequest {
726    #[mesh(1)]
727    pub version_requested: u32,
728    #[mesh(2)]
729    pub target_message_vp: u32,
730    #[mesh(3)]
731    pub monitor_page: MonitorPageRequest,
732    #[mesh(4)]
733    pub target_sint: u8,
734    #[mesh(5)]
735    pub target_vtl: u8,
736    #[mesh(6)]
737    pub feature_flags: u32,
738    #[mesh(7)]
739    pub interrupt_page: Option<u64>,
740    #[mesh(8)]
741    pub client_id: Guid,
742    #[mesh(9)]
743    pub trusted: bool,
744}
745
746impl InitiateContactRequest {
747    fn save(value: &super::InitiateContactRequest) -> Self {
748        Self {
749            version_requested: value.version_requested,
750            target_message_vp: value.target_message_vp,
751            monitor_page: MonitorPageRequest::save(value.monitor_page),
752            target_sint: value.target_sint,
753            target_vtl: value.target_vtl,
754            feature_flags: value.feature_flags,
755            interrupt_page: value.interrupt_page,
756            client_id: value.client_id,
757            trusted: value.trusted,
758        }
759    }
760
761    fn restore(self) -> super::InitiateContactRequest {
762        super::InitiateContactRequest {
763            version_requested: self.version_requested,
764            target_message_vp: self.target_message_vp,
765            monitor_page: self.monitor_page.restore(),
766            target_sint: self.target_sint,
767            target_vtl: self.target_vtl,
768            feature_flags: self.feature_flags,
769            interrupt_page: self.interrupt_page,
770            client_id: self.client_id,
771            trusted: self.trusted,
772        }
773    }
774}
775
776#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
777#[mesh(package = "vmbus.server.channels")]
778pub struct MonitorPageGpas {
779    #[mesh(1)]
780    pub parent_to_child: u64,
781    #[mesh(2)]
782    pub child_to_parent: u64,
783}
784
785impl MonitorPageGpas {
786    fn save(value: super::MonitorPageGpaInfo) -> Self {
787        assert!(
788            !value.server_allocated,
789            "cannot save with server-allocated monitor pages"
790        );
791        Self {
792            child_to_parent: value.gpas.child_to_parent,
793            parent_to_child: value.gpas.parent_to_child,
794        }
795    }
796
797    fn restore(self) -> super::MonitorPageGpaInfo {
798        super::MonitorPageGpaInfo::from_guest_gpas(super::MonitorPageGpas {
799            child_to_parent: self.child_to_parent,
800            parent_to_child: self.parent_to_child,
801        })
802    }
803}
804
805#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
806#[mesh(package = "vmbus.server.channels")]
807pub enum MonitorPageRequest {
808    #[mesh(1)]
809    None,
810    #[mesh(2)]
811    Some(#[mesh(1)] MonitorPageGpas),
812    #[mesh(3)]
813    Invalid,
814}
815
816impl MonitorPageRequest {
817    fn save(value: super::MonitorPageRequest) -> Self {
818        match value {
819            super::MonitorPageRequest::None => MonitorPageRequest::None,
820            super::MonitorPageRequest::Some(mp) => MonitorPageRequest::Some(MonitorPageGpas::save(
821                super::MonitorPageGpaInfo::from_guest_gpas(mp),
822            )),
823            super::MonitorPageRequest::Invalid => MonitorPageRequest::Invalid,
824        }
825    }
826
827    fn restore(self) -> super::MonitorPageRequest {
828        match self {
829            MonitorPageRequest::None => super::MonitorPageRequest::None,
830            MonitorPageRequest::Some(mp) => super::MonitorPageRequest::Some(mp.restore().gpas),
831            MonitorPageRequest::Invalid => super::MonitorPageRequest::Invalid,
832        }
833    }
834}
835
836#[derive(Debug, Copy, Clone, Protobuf)]
837#[mesh(package = "vmbus.server.channels")]
838pub struct SignalInfo {
839    #[mesh(1)]
840    pub event_flag: u16,
841    #[mesh(2)]
842    pub connection_id: u32,
843}
844
845impl SignalInfo {
846    fn save(value: &super::SignalInfo) -> Self {
847        Self {
848            event_flag: value.event_flag,
849            connection_id: value.connection_id,
850        }
851    }
852
853    fn restore(self) -> super::SignalInfo {
854        super::SignalInfo {
855            event_flag: self.event_flag,
856            connection_id: self.connection_id,
857        }
858    }
859}
860
/// Saved-state form of `super::OpenRequest`.
///
/// Fields map one-to-one onto the runtime type; see `OpenRequest::save` and
/// `OpenRequest::restore` for the conversions.
#[derive(Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct OpenRequest {
    /// Copy of the runtime request's `open_id`.
    #[mesh(1)]
    pub open_id: u32,
    /// Copy of the runtime request's `ring_buffer_gpadl_id`.
    #[mesh(2)]
    pub ring_buffer_gpadl_id: GpadlId,
    /// Copy of the runtime request's `target_vp`.
    #[mesh(3)]
    pub target_vp: u32,
    /// Copy of the runtime request's `downstream_ring_buffer_page_offset`.
    #[mesh(4)]
    pub downstream_ring_buffer_page_offset: u32,
    /// Raw bytes of the runtime `user_data`, obtained through `.into()` on save and converted
    /// back through `.into()` on restore.
    #[mesh(5)]
    pub user_data: [u8; 120],
    /// Saved form of the optional `guest_specified_interrupt_info`.
    #[mesh(6)]
    pub guest_specified_interrupt_info: Option<SignalInfo>,
    /// Raw `u16` value of the runtime `flags`, converted with `.into()` in both directions.
    #[mesh(7)]
    pub flags: u16,
}
879
880impl OpenRequest {
881    fn save(value: &super::OpenRequest) -> Self {
882        Self {
883            open_id: value.open_id,
884            ring_buffer_gpadl_id: value.ring_buffer_gpadl_id,
885            target_vp: value.target_vp,
886            downstream_ring_buffer_page_offset: value.downstream_ring_buffer_page_offset,
887            user_data: value.user_data.into(),
888            guest_specified_interrupt_info: value
889                .guest_specified_interrupt_info
890                .as_ref()
891                .map(SignalInfo::save),
892            flags: value.flags.into(),
893        }
894    }
895
896    fn restore(self) -> super::OpenRequest {
897        super::OpenRequest {
898            open_id: self.open_id,
899            ring_buffer_gpadl_id: self.ring_buffer_gpadl_id,
900            target_vp: self.target_vp,
901            downstream_ring_buffer_page_offset: self.downstream_ring_buffer_page_offset,
902            user_data: self.user_data.into(),
903            guest_specified_interrupt_info: self
904                .guest_specified_interrupt_info
905                .map(SignalInfo::restore),
906            flags: self.flags.into(),
907        }
908    }
909}
910
/// Saved-state form of `super::ModifyState`.
///
/// The variants mirror the runtime enum one-to-one (see `ModifyState::save` and
/// `ModifyState::restore`).
#[derive(Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum ModifyState {
    /// Saved form of `super::ModifyState::NotModifying`.
    #[mesh(1)]
    NotModifying,
    /// Saved form of `super::ModifyState::Modifying`.
    #[mesh(2)]
    Modifying {
        /// Copy of the runtime variant's `pending_target_vp`.
        #[mesh(1)]
        pending_target_vp: Option<u32>,
    },
}
922
923impl ModifyState {
924    fn save(value: &super::ModifyState) -> Self {
925        match value {
926            super::ModifyState::NotModifying => Self::NotModifying,
927            super::ModifyState::Modifying { pending_target_vp } => Self::Modifying {
928                pending_target_vp: *pending_target_vp,
929            },
930        }
931    }
932
933    fn restore(self) -> super::ModifyState {
934        match self {
935            ModifyState::NotModifying => super::ModifyState::NotModifying,
936            ModifyState::Modifying { pending_target_vp } => {
937                super::ModifyState::Modifying { pending_target_vp }
938            }
939        }
940    }
941}
942
/// Saved-state form of `super::ReservedState`.
///
/// The runtime type's `target` (a `super::ConnectionTarget`) is flattened into the `vp` and
/// `sint` fields here.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct ReservedState {
    /// Saved form of the runtime `version`. On restore, versions older than `Version::Win10`
    /// are rejected.
    #[mesh(1)]
    pub version: VersionInfo,
    /// Copy of `target.vp` from the runtime state.
    #[mesh(2)]
    pub vp: u32,
    /// Copy of `target.sint` from the runtime state.
    #[mesh(3)]
    pub sint: u8,
}
953
954impl ReservedState {
955    fn save(reserved_state: &super::ReservedState) -> Self {
956        Self {
957            version: VersionInfo::save(&reserved_state.version),
958            vp: reserved_state.target.vp,
959            sint: reserved_state.target.sint,
960        }
961    }
962
963    fn restore(&self) -> Result<super::ReservedState, RestoreError> {
964        // We don't know if the connection when the channel was reserved was trusted, so assume it
965        // was for what feature flags are accepted here; it doesn't affect any actual behavior.
966        let version = self.version.clone().restore(true).map_err(|e| match e {
967            RestoreError::UnsupportedVersion(v) => RestoreError::UnsupportedReserveVersion(v),
968            RestoreError::UnsupportedFeatureFlags(f) => {
969                RestoreError::UnsupportedReserveFeatureFlags(f)
970            }
971            err => err,
972        })?;
973
974        if version.version < Version::Win10 {
975            return Err(RestoreError::UnsupportedReserveVersion(
976                version.version as u32,
977            ));
978        }
979
980        Ok(super::ReservedState {
981            version,
982            target: super::ConnectionTarget {
983                vp: self.vp,
984                sint: self.sint,
985            },
986        })
987    }
988}
989
/// Saved-state form of `super::ChannelState`.
///
/// Not every runtime state has a saved counterpart: `ChannelState::save` returns `None` for
/// `ClientReleased`, `ClosingClientRelease` and `OpeningClientRelease`, and stores the runtime
/// `Reoffered` state as `Revoked`.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum ChannelState {
    /// Saved form of `super::ChannelState::Closed`.
    #[mesh(1)]
    Closed,
    /// Saved form of `super::ChannelState::Opening`.
    #[mesh(2)]
    Opening {
        /// The saved open request.
        #[mesh(1)]
        request: OpenRequest,
        /// Saved reserved state, if the runtime state carried one.
        #[mesh(2)]
        reserved_state: Option<ReservedState>,
    },
    /// Saved form of `super::ChannelState::Open`.
    #[mesh(3)]
    Open {
        /// The saved open parameters.
        #[mesh(1)]
        params: OpenRequest,
        /// The saved modify state.
        #[mesh(2)]
        modify_state: ModifyState,
        /// Saved reserved state, if the runtime state carried one.
        #[mesh(3)]
        reserved_state: Option<ReservedState>,
    },
    /// Saved form of `super::ChannelState::Closing`.
    #[mesh(4)]
    Closing {
        /// The saved open parameters.
        #[mesh(1)]
        params: OpenRequest,
        /// Saved reserved state, if the runtime state carried one.
        #[mesh(2)]
        reserved_state: Option<ReservedState>,
    },
    /// Saved form of `super::ChannelState::ClosingReopen`.
    #[mesh(5)]
    ClosingReopen {
        /// The saved `params` open request of the runtime state.
        #[mesh(1)]
        params: OpenRequest,
        /// The saved `request` open request of the runtime state.
        #[mesh(2)]
        request: OpenRequest,
    },
    /// Saved for both the runtime `Revoked` and `Reoffered` states; restored as
    /// `super::ChannelState::Reoffered`.
    #[mesh(6)]
    Revoked,
}
1028
1029impl ChannelState {
1030    fn save(value: &super::ChannelState) -> Option<Self> {
1031        Some(match value {
1032            super::ChannelState::Closed => ChannelState::Closed,
1033            super::ChannelState::Opening {
1034                request,
1035                reserved_state,
1036            } => ChannelState::Opening {
1037                request: OpenRequest::save(request),
1038                reserved_state: reserved_state.as_ref().map(ReservedState::save),
1039            },
1040            super::ChannelState::ClosingReopen { params, request } => ChannelState::ClosingReopen {
1041                params: OpenRequest::save(params),
1042                request: OpenRequest::save(request),
1043            },
1044            super::ChannelState::Open {
1045                params,
1046                modify_state,
1047                reserved_state,
1048            } => ChannelState::Open {
1049                params: OpenRequest::save(params),
1050                modify_state: ModifyState::save(modify_state),
1051                reserved_state: reserved_state.as_ref().map(ReservedState::save),
1052            },
1053            super::ChannelState::Closing {
1054                params,
1055                reserved_state,
1056            } => ChannelState::Closing {
1057                params: OpenRequest::save(params),
1058                reserved_state: reserved_state.as_ref().map(ReservedState::save),
1059            },
1060
1061            super::ChannelState::Revoked => ChannelState::Revoked,
1062            super::ChannelState::Reoffered => ChannelState::Revoked,
1063            super::ChannelState::ClientReleased
1064            | super::ChannelState::ClosingClientRelease
1065            | super::ChannelState::OpeningClientRelease => return None,
1066        })
1067    }
1068
1069    fn restore(&self) -> Result<super::ChannelState, RestoreError> {
1070        Ok(match self {
1071            ChannelState::Closed => super::ChannelState::Closed,
1072            ChannelState::Opening {
1073                request,
1074                reserved_state,
1075            } => super::ChannelState::Opening {
1076                request: request.restore(),
1077                reserved_state: reserved_state
1078                    .as_ref()
1079                    .map(ReservedState::restore)
1080                    .transpose()?,
1081            },
1082            ChannelState::ClosingReopen { params, request } => super::ChannelState::ClosingReopen {
1083                params: params.restore(),
1084                request: request.restore(),
1085            },
1086            ChannelState::Open {
1087                params,
1088                modify_state,
1089                reserved_state,
1090            } => super::ChannelState::Open {
1091                params: params.restore(),
1092                modify_state: modify_state.restore(),
1093                reserved_state: reserved_state
1094                    .as_ref()
1095                    .map(ReservedState::restore)
1096                    .transpose()?,
1097            },
1098            ChannelState::Closing {
1099                params,
1100                reserved_state,
1101            } => super::ChannelState::Closing {
1102                params: params.restore(),
1103                reserved_state: reserved_state
1104                    .as_ref()
1105                    .map(ReservedState::restore)
1106                    .transpose()?,
1107            },
1108            ChannelState::Revoked => {
1109                // Mark it reoffered for now. This may transition back to revoked in post_restore.
1110                super::ChannelState::Reoffered
1111            }
1112        })
1113    }
1114}
1115
1116impl Display for ChannelState {
1117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1118        let state = match self {
1119            Self::Closed => "Closed",
1120            Self::Opening { .. } => "Opening",
1121            Self::Open { .. } => "Open",
1122            Self::Closing { .. } => "Closing",
1123            Self::ClosingReopen { .. } => "ClosingReopen",
1124            Self::Revoked => "Revoked",
1125        };
1126        write!(f, "{}", state)
1127    }
1128}
1129
/// Saved-state form of a GPADL (`super::Gpadl`) together with the IDs that identify it.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct Gpadl {
    /// Raw value of the `GpadlId` this GPADL was saved under.
    #[mesh(1)]
    pub id: u32,
    /// Raw value of the `ChannelId` this GPADL was saved under.
    #[mesh(2)]
    pub channel_id: u32,
    /// Copy of the runtime GPADL's `count`; passed to `gparange::validate_gpa_ranges` together
    /// with `buf` on restore.
    #[mesh(3)]
    pub count: u16,
    /// Copy of the runtime GPADL's `buf`. Validated on restore unless the state is
    /// `GpadlState::InProgress`.
    #[mesh(4)]
    pub buf: Vec<u64>,
    /// The saved GPADL state.
    #[mesh(5)]
    pub state: GpadlState,
}
1144
1145impl Gpadl {
1146    fn save(gpadl_id: GpadlId, channel_id: ChannelId, gpadl: &super::Gpadl) -> Option<Self> {
1147        tracing::trace!(id = %gpadl_id.0, channel_id = %channel_id.0, "gpadl saved");
1148        Some(Gpadl {
1149            id: gpadl_id.0,
1150            channel_id: channel_id.0,
1151            count: gpadl.count,
1152            buf: gpadl.buf.clone(),
1153            state: match gpadl.state {
1154                super::GpadlState::InProgress => GpadlState::InProgress,
1155                super::GpadlState::Offered => GpadlState::Offered,
1156                super::GpadlState::Accepted => GpadlState::Accepted,
1157                super::GpadlState::TearingDown => GpadlState::TearingDown,
1158                super::GpadlState::OfferedTearingDown => return None,
1159            },
1160        })
1161    }
1162
1163    fn restore(self, channel: &super::Channel) -> Result<super::Gpadl, RestoreError> {
1164        if self.state != GpadlState::InProgress {
1165            // Validate the range.
1166            gparange::validate_gpa_ranges(self.count.into(), &self.buf)?;
1167        }
1168        let (state, allow_revoked) = match self.state {
1169            GpadlState::InProgress => (super::GpadlState::InProgress, true),
1170            GpadlState::Offered => (super::GpadlState::Offered, false),
1171            GpadlState::Accepted => {
1172                // It is assumed the device already knows about this GPADL.
1173                (super::GpadlState::Accepted, true)
1174            }
1175            GpadlState::TearingDown => (super::GpadlState::TearingDown, false),
1176        };
1177
1178        if !allow_revoked && channel.state.is_revoked() {
1179            return Err(RestoreError::GpadlForRevokedChannel(
1180                GpadlId(self.id),
1181                ChannelId(self.channel_id),
1182            ));
1183        }
1184
1185        Ok(super::Gpadl {
1186            count: self.count,
1187            buf: self.buf,
1188            state,
1189        })
1190    }
1191
1192    pub fn is_tearing_down(&self) -> bool {
1193        self.state == GpadlState::TearingDown
1194    }
1195}
1196
/// Saved-state form of `super::GpadlState`.
///
/// The runtime `OfferedTearingDown` state has no counterpart here; `Gpadl::save` returns `None`
/// for GPADLs in that state.
#[derive(Debug, Clone, Protobuf, PartialEq, Eq)]
#[mesh(package = "vmbus.server.channels")]
pub enum GpadlState {
    /// Saved form of `super::GpadlState::InProgress`. GPA ranges are not validated on restore.
    #[mesh(1)]
    InProgress,
    /// Saved form of `super::GpadlState::Offered`. Rejected on restore if the channel is revoked.
    #[mesh(2)]
    Offered,
    /// Saved form of `super::GpadlState::Accepted`.
    #[mesh(3)]
    Accepted,
    /// Saved form of `super::GpadlState::TearingDown`. Rejected on restore if the channel is
    /// revoked.
    #[mesh(4)]
    TearingDown,
}
1209
/// Saved-state form of a `vmbus_core::OutgoingMessage`: the bytes returned by its `data()`
/// accessor, turned back into a message with `from_message` on restore.
#[derive(Debug, Clone, Protobuf, PartialEq, Eq)]
#[mesh(package = "vmbus.server.channels")]
pub struct OutgoingMessage(pub Vec<u8>);
1213
1214impl OutgoingMessage {
1215    fn save(value: &vmbus_core::OutgoingMessage) -> Self {
1216        Self(value.data().to_vec())
1217    }
1218
1219    fn restore(self) -> Result<vmbus_core::OutgoingMessage, RestoreError> {
1220        vmbus_core::OutgoingMessage::from_message(&self.0)
1221            .map_err(|_| RestoreError::MessageTooLarge)
1222    }
1223}