vmbus_server/channels/
saved_state.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use super::MnfUsage;
5use super::Notifier;
6use super::OfferError;
7use super::OfferParamsInternal;
8use super::OfferedInfo;
9use super::RestoreState;
10use super::SUPPORTED_FEATURE_FLAGS;
11use guid::Guid;
12pub use inner::SavedState;
13use mesh::payload::Protobuf;
14use std::fmt::Display;
15use thiserror::Error;
16use vmbus_channel::bus::OfferKey;
17use vmbus_core::protocol;
18use vmbus_core::protocol::ChannelId;
19use vmbus_core::protocol::FeatureFlags;
20use vmbus_core::protocol::GpadlId;
21use vmbus_core::protocol::Version;
22use vmbus_ring::gparange;
23use vmcore::monitor::MonitorId;
24
25impl super::Server {
26    fn restore_one_channel(&mut self, saved_channel: Channel) -> Result<(), RestoreError> {
27        let (info, stub_offer, state) = saved_channel.restore()?;
28        if let Some((offer_id, channel)) = self.channels.get_by_key_mut(&saved_channel.key) {
29            // There is an existing channel. Restore on top of it.
30
31            if !matches!(channel.state, super::ChannelState::ClientReleased)
32                || channel.restore_state != RestoreState::New
33            {
34                return Err(RestoreError::AlreadyRestored(saved_channel.key));
35            }
36
37            // The channel's monitor ID can be already set if it was set by the device, which is
38            // the case with relay channels. In that case, it must match the saved ID.
39            if let MnfUsage::Relayed { monitor_id } = channel.offer.use_mnf {
40                if info.monitor_id != Some(MonitorId(monitor_id)) {
41                    return Err(RestoreError::MismatchedMonitorId(
42                        monitor_id,
43                        saved_channel.monitor_id,
44                    ));
45                }
46            }
47
48            self.assigned_channels
49                .set(info.channel_id)?
50                .insert(offer_id);
51
52            channel.state = state;
53            channel.restore_state = RestoreState::Restoring;
54            channel.info = Some(info);
55        } else {
56            // There is no existing channel.
57
58            let entry = self
59                .assigned_channels
60                .set(ChannelId(saved_channel.channel_id))?;
61
62            let channel = super::Channel {
63                info: Some(info),
64                offer: stub_offer,
65                state,
66                restore_state: RestoreState::Unmatched,
67            };
68
69            let offer_id = self.channels.offer(channel);
70            entry.insert(offer_id);
71        }
72        Ok(())
73    }
74
75    fn restore_one_gpadl(&mut self, saved_gpadl: Gpadl) -> Result<(), RestoreError> {
76        let gpadl_id = GpadlId(saved_gpadl.id);
77        let channel_id = ChannelId(saved_gpadl.channel_id);
78        let (offer_id, channel) = self
79            .channels
80            .get_by_channel_id(&self.assigned_channels, channel_id)
81            .map_err(|_| RestoreError::MissingGpadlChannel(gpadl_id, channel_id))?;
82
83        if channel.restore_state == RestoreState::New || channel.state.is_released() {
84            return Err(RestoreError::MissingGpadlChannel(gpadl_id, channel_id));
85        }
86
87        let gpadl = saved_gpadl.restore(channel)?;
88        let state = gpadl.state;
89        if self.gpadls.insert((gpadl_id, offer_id), gpadl).is_some() {
90            return Err(RestoreError::GpadlIdInUse(gpadl_id, channel_id));
91        }
92
93        if state == super::GpadlState::InProgress
94            && self.incomplete_gpadls.insert(gpadl_id, offer_id).is_some()
95        {
96            unreachable!("gpadl ID validated above");
97        }
98
99        Ok(())
100    }
101
102    /// Saves state.
103    pub fn save(&self) -> SavedState {
104        SavedStateData {
105            state: if let Some(state) = self.save_connected_state() {
106                SavedConnectionState::Connected(state)
107            } else {
108                SavedConnectionState::Disconnected(self.save_disconnected_state())
109            },
110            pending_messages: self.save_pending_messages(),
111        }
112        .into()
113    }
114
115    fn save_connected_state(&self) -> Option<ConnectedState> {
116        let connection = Connection::save(&self.state)?;
117        let channels = self
118            .channels
119            .iter()
120            .filter_map(|(_, channel)| Channel::save(channel))
121            .collect();
122
123        let gpadls = self.save_gpadls();
124        Some(ConnectedState {
125            connection,
126            channels,
127            gpadls,
128        })
129    }
130
131    fn save_gpadls(&self) -> Vec<Gpadl> {
132        self.gpadls
133            .iter()
134            .filter_map(|((gpadl_id, offer_id), gpadl)| {
135                Gpadl::save(*gpadl_id, self.channels[*offer_id].info?.channel_id, gpadl)
136            })
137            .collect()
138    }
139
140    fn save_disconnected_state(&self) -> DisconnectedState {
141        // Save reserved channels only.
142        let channels = self
143            .channels
144            .iter()
145            .filter_map(|(_, channel)| {
146                channel
147                    .state
148                    .is_reserved()
149                    .then(|| Channel::save(channel))
150                    .flatten()
151            })
152            .collect();
153
154        // Save the GPADLs for reserved channels.
155        // N.B. There cannot be any other GPADLs while disconnected.
156        let gpadls = self.save_gpadls();
157        DisconnectedState {
158            reserved_channels: channels,
159            reserved_gpadls: gpadls,
160        }
161    }
162
163    fn save_pending_messages(&self) -> Vec<OutgoingMessage> {
164        self.pending_messages
165            .0
166            .iter()
167            .map(OutgoingMessage::save)
168            .collect()
169    }
170}
171
impl<'a, N: 'a + Notifier> super::ServerWithNotifier<'a, N> {
    /// Restores state.
    ///
    /// This may be called before or after channels have been offered. After
    /// calling this routine, [`restore_channel`] should be
    /// called for each channel to be restored, possibly interleaved with
    /// additional calls to offer or revoke channels.
    ///
    /// Once all channels are in the appropriate state,
    /// [`revoke_unclaimed_channels`] should be called. This will revoke
    /// any channels that were in the saved state but were not restored via
    /// [`restore_channel`].
    ///
    /// # Errors
    ///
    /// Returns a [`RestoreError`] if the saved connection, a channel, a GPADL, or a pending
    /// message cannot be restored, or if the notifier rejects the connection request. Nothing
    /// is rolled back on failure, so state restored before the failing step remains in place.
    ///
    /// [`revoke_unclaimed_channels`]: super::ServerWithNotifier::revoke_unclaimed_channels
    /// [`restore_channel`]: super::ServerWithNotifier::restore_channel
    pub fn restore(&mut self, saved: SavedState) -> Result<(), RestoreError> {
        tracing::trace!(?saved, "restoring channel state");

        // Convert to the enum representation so that exactly one of the connected or
        // disconnected states is handled.
        let saved = SavedStateData::from(saved);
        match saved.state {
            SavedConnectionState::Connected(saved) => {
                self.inner.state = saved.connection.restore()?;

                // Restore server state, and resend server notifications if needed. If these notifications
                // were processed before the save, it's harmless as the values will be the same.
                let request = match self.inner.state {
                    super::ConnectionState::Connecting {
                        info,
                        next_action: _,
                    } => Some(super::ModifyConnectionRequest {
                        version: Some(info.version),
                        interrupt_page: info.interrupt_page.into(),
                        monitor_page: info.monitor_page.map(|mp| mp.gpas).into(),
                        target_message_vp: Some(info.target_message_vp),
                        notify_relay: true,
                    }),
                    super::ConnectionState::Connected(info) => {
                        Some(super::ModifyConnectionRequest {
                            version: None,
                            monitor_page: info.monitor_page.map(|mp| mp.gpas).into(),
                            interrupt_page: info.interrupt_page.into(),
                            target_message_vp: Some(info.target_message_vp),
                            // If the save didn't happen while modifying, the relay doesn't need to be notified
                            // of this info as it doesn't constitute a change, we're just restoring existing
                            // connection state.
                            notify_relay: info.modifying,
                        })
                    }
                    // No action needed for these states; if disconnecting, check_disconnected will resend
                    // the reset request if needed.
                    super::ConnectionState::Disconnected
                    | super::ConnectionState::Disconnecting { .. } => None,
                };

                if let Some(request) = request {
                    self.notifier.modify_connection(request)?;
                }

                // Channels are restored before GPADLs because each GPADL is looked up by the
                // channel ID that its channel claims during restore.
                for saved_channel in saved.channels {
                    self.inner.restore_one_channel(saved_channel)?;
                }

                for saved_gpadl in saved.gpadls {
                    self.inner.restore_one_gpadl(saved_gpadl)?;
                }
            }
            SavedConnectionState::Disconnected(saved) => {
                // Only reserved channels and their GPADLs are present in this state; no
                // connection request is sent to the notifier.
                self.inner.state = super::ConnectionState::Disconnected;
                for saved_channel in saved.reserved_channels {
                    self.inner.restore_one_channel(saved_channel)?;
                }

                for saved_gpadl in saved.reserved_gpadls {
                    self.inner.restore_one_gpadl(saved_gpadl)?;
                }
            }
        }

        // Re-queue the saved pending messages in their saved order, reserving space up front.
        self.inner
            .pending_messages
            .0
            .reserve(saved.pending_messages.len());

        for message in saved.pending_messages {
            self.inner.pending_messages.0.push_back(message.restore()?);
        }

        Ok(())
    }
}
262
/// Errors that can occur while restoring the server from a saved state.
#[derive(Debug, Error)]
pub enum RestoreError {
    /// An offer-related error, converted from [`OfferError`] via `?`.
    #[error(transparent)]
    Offer(#[from] OfferError),

    /// The existing channel with this key is not in a state that can be restored over.
    #[error("channel {0} has already been restored")]
    AlreadyRestored(OfferKey),

    /// A saved GPADL refers to a channel that cannot be found or was not restored.
    #[error("gpadl {} is for missing channel {}", (.0).0, (.1).0)]
    MissingGpadlChannel(GpadlId, ChannelId),

    /// A saved GPADL refers to a revoked channel.
    #[error("gpadl {} is for revoked channel {}", (.0).0, (.1).0)]
    GpadlForRevokedChannel(GpadlId, ChannelId),

    /// The GPADL ID was already restored. Only the GPADL ID is rendered in the message.
    #[error("gpadl {} is already restored", (.0).0)]
    GpadlIdInUse(GpadlId, ChannelId),

    /// The saved protocol version is not one of the supported versions.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// A saved GPADL failed validation; converted from [`gparange::Error`].
    #[error("invalid gpadl")]
    InvalidGpadl(#[from] gparange::Error),

    /// The saved feature flags contain flags outside the supported set.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    #[error("channel {0} has a mismatched open state")]
    MismatchedOpenState(OfferKey),

    #[error("channel {0} is missing from the saved state")]
    MissingChannel(OfferKey),

    #[error("unsupported reserved channel protocol version {0:#x}")]
    UnsupportedReserveVersion(u32),

    #[error("unsupported reserved channel feature flags {0:#x}")]
    UnsupportedReserveFeatureFlags(u32),

    /// A relayed channel's monitor ID (first field) differs from the saved monitor ID
    /// (second field).
    #[error("mismatched monitor id; expected {0}, actual {1:?}")]
    MismatchedMonitorId(u8, Option<u8>),

    /// NOTE(review): the carried monitor ID is not rendered in the message.
    #[error("monitor ID used by multiple channels in the saved state")]
    DuplicateMonitorId(u8),

    /// Any other error, e.g. one returned by the notifier; converted from `anyhow::Error`.
    #[error(transparent)]
    ServerError(#[from] anyhow::Error),

    #[error(
        "reserved channel with ID {0} has a pending message but is missing from the saved state"
    )]
    MissingReservedChannel(u32),
    #[error("a saved pending message is larger than the maximum message size")]
    MessageTooLarge,
}
317
318mod inner {
319    use super::*;
320
321    /// The top-level saved state for the VMBus channels library. It is placed in its own module to
322    /// keep the internals private, and the only thing you can do with it is convert to/from
323    /// `SavedStateData`. This enforces that users always consider both the connected and
324    /// disconnected states.
325    #[derive(Debug, Protobuf, Clone)]
326    #[mesh(package = "vmbus.server.channels")]
327    pub struct SavedState {
328        #[mesh(1)]
329        state: Option<ConnectedState>,
330        // Disconnected state is used to save any open reserved channels while the guest is
331        // disconnected. It is mutually exclusive with `state`, but is separate to maintain saved
332        // state compatibility.
333        // N.B. In a saved state created by the current version, either state or disconnected_state
334        //      is always `Some`, but for older versions, it is possible that both are `None`. They
335        //      can never both be `Some`.
336        #[mesh(2)]
337        disconnected_state: Option<DisconnectedState>,
338        #[mesh(3)]
339        pending_messages: Vec<OutgoingMessage>,
340    }
341
342    impl From<SavedStateData> for SavedState {
343        fn from(value: SavedStateData) -> Self {
344            let (state, disconnected_state) = match value.state {
345                SavedConnectionState::Connected(connected) => (Some(connected), None),
346                SavedConnectionState::Disconnected(disconnected) => (None, Some(disconnected)),
347            };
348
349            Self {
350                state,
351                disconnected_state,
352                pending_messages: value.pending_messages,
353            }
354        }
355    }
356
357    impl From<SavedState> for SavedStateData {
358        fn from(value: SavedState) -> Self {
359            Self {
360                state: if let Some(connected) = value.state {
361                    SavedConnectionState::Connected(connected)
362                } else {
363                    // Older saved state versions may not have a disconnected state, in which case
364                    // we use an empty value which has no channels or gpadls.
365                    SavedConnectionState::Disconnected(value.disconnected_state.unwrap_or_default())
366                },
367                pending_messages: value.pending_messages,
368            }
369        }
370    }
371}
372
/// Represents either connected or disconnected saved state.
#[derive(Debug, Clone)]
pub enum SavedConnectionState {
    /// Connection state plus all saved channels and GPADLs.
    Connected(ConnectedState),
    /// No connection; only reserved channels and their GPADLs are present.
    Disconnected(DisconnectedState),
}
379
/// Alternative representation of the saved state that ensures that all code paths deal with either
/// the connected or disconnected state, and cannot neglect one.
///
/// It converts to and from [`SavedState`] via `From`.
#[derive(Debug, Clone)]
pub struct SavedStateData {
    /// The connected or disconnected portion of the saved state.
    pub state: SavedConnectionState,
    /// Messages that were queued when the state was saved; present in both states.
    pub pending_messages: Vec<OutgoingMessage>,
}
387
388impl SavedStateData {
389    /// Finds a channel in the saved state.
390    pub fn find_channel(&self, offer: OfferKey) -> Option<&Channel> {
391        let (channels, _) = self.channels_and_gpadls();
392        channels.iter().find(|c| c.key == offer)
393    }
394
395    /// Retrieves all the channels and GPADLs from the saved state.
396    /// If disconnected, returns any reserved channels and their GPADLs.
397    pub fn channels_and_gpadls(&self) -> (&[Channel], &[Gpadl]) {
398        match &self.state {
399            SavedConnectionState::Connected(connected) => (&connected.channels, &connected.gpadls),
400            SavedConnectionState::Disconnected(disconnected) => (
401                &disconnected.reserved_channels,
402                &disconnected.reserved_gpadls,
403            ),
404        }
405    }
406}
407
/// Saved state for a server whose connection state is anything other than disconnected.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct ConnectedState {
    /// The saved connection state.
    #[mesh(1)]
    pub connection: Connection,
    /// The saved channels; channels that could not be saved are omitted.
    #[mesh(2)]
    pub channels: Vec<Channel>,
    /// The saved GPADLs, each referring to its channel by channel ID.
    #[mesh(3)]
    pub gpadls: Vec<Gpadl>,
}
418
/// Saved state for a disconnected server.
///
/// The `Default` value has no channels or GPADLs and is used when a saved state contains
/// neither a connected nor a disconnected state.
#[derive(Default, Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct DisconnectedState {
    /// Channels that were in a reserved state at save time.
    #[mesh(1)]
    pub reserved_channels: Vec<Channel>,
    /// GPADLs belonging to the reserved channels.
    #[mesh(2)]
    pub reserved_gpadls: Vec<Gpadl>,
}
427
/// Saved form of the negotiated protocol version, stored as raw integers and validated
/// on restore.
#[derive(Debug, PartialEq, Eq, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct VersionInfo {
    /// Raw protocol version value; must match a supported version on restore.
    #[mesh(1)]
    pub version: u32,
    /// Raw feature flag bits; must be within the supported flags on restore.
    #[mesh(2)]
    pub feature_flags: u32,
}
436
437impl VersionInfo {
438    fn save(value: &super::VersionInfo) -> Self {
439        Self {
440            version: value.version as u32,
441            feature_flags: value.feature_flags.into(),
442        }
443    }
444
445    fn restore(self, trusted: bool) -> Result<vmbus_core::VersionInfo, RestoreError> {
446        let version = super::SUPPORTED_VERSIONS
447            .iter()
448            .find(|v| self.version == **v as u32)
449            .copied()
450            .ok_or(RestoreError::UnsupportedVersion(self.version))?;
451
452        let feature_flags = FeatureFlags::from(self.feature_flags);
453        let supported_flags = SUPPORTED_FEATURE_FLAGS.with_confidential_channels(trusted);
454        if !supported_flags.contains(feature_flags) {
455            return Err(RestoreError::UnsupportedFeatureFlags(feature_flags.into()));
456        }
457
458        Ok(super::VersionInfo {
459            version,
460            feature_flags,
461        })
462    }
463}
464
/// Saved form of the server's connection state.
///
/// There is no variant for the disconnected state: `Connection::save` returns `None` for it.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum Connection {
    /// Saved from the `Disconnecting` connection state. Whether a modify request had been sent
    /// is not saved; it is restored as `false`.
    #[mesh(1)]
    Disconnecting {
        /// The saved follow-up action.
        #[mesh(1)]
        next_action: ConnectionAction,
    },
    /// Saved from the `Connecting` connection state.
    #[mesh(2)]
    Connecting {
        #[mesh(1)]
        version: VersionInfo,
        #[mesh(2)]
        interrupt_page: Option<u64>,
        #[mesh(3)]
        monitor_page: Option<MonitorPageGpas>,
        #[mesh(4)]
        target_message_vp: u32,
        /// The saved follow-up action.
        #[mesh(5)]
        next_action: ConnectionAction,
        /// Restored as `Guid::ZERO` when absent.
        #[mesh(6)]
        client_id: Option<Guid>,
        /// Also determines whether confidential channels are an accepted feature flag when
        /// `version` is restored.
        #[mesh(7)]
        trusted: bool,
    },
    /// Saved from the `Connected` connection state.
    #[mesh(3)]
    Connected {
        #[mesh(1)]
        version: VersionInfo,
        #[mesh(2)]
        offers_sent: bool,
        #[mesh(3)]
        interrupt_page: Option<u64>,
        #[mesh(4)]
        monitor_page: Option<MonitorPageGpas>,
        #[mesh(5)]
        target_message_vp: u32,
        /// When set, restore asks the notifier to notify the relay of the restored
        /// connection settings.
        #[mesh(6)]
        modifying: bool,
        /// Restored as `Guid::ZERO` when absent.
        #[mesh(7)]
        client_id: Option<Guid>,
        /// Also determines whether confidential channels are an accepted feature flag when
        /// `version` is restored.
        #[mesh(8)]
        trusted: bool,
        #[mesh(9)]
        paused: bool,
    },
}
512
513impl Connection {
514    fn save(value: &super::ConnectionState) -> Option<Self> {
515        match value {
516            super::ConnectionState::Disconnected => {
517                // No state to save.
518                None
519            }
520            super::ConnectionState::Connecting { info, next_action } => {
521                Some(Connection::Connecting {
522                    version: VersionInfo::save(&info.version),
523                    interrupt_page: info.interrupt_page,
524                    monitor_page: info.monitor_page.map(MonitorPageGpas::save),
525                    target_message_vp: info.target_message_vp,
526                    next_action: ConnectionAction::save(next_action),
527                    client_id: Some(info.client_id),
528                    trusted: info.trusted,
529                })
530            }
531            super::ConnectionState::Connected(info) => Some(Connection::Connected {
532                version: VersionInfo::save(&info.version),
533                offers_sent: info.offers_sent,
534                interrupt_page: info.interrupt_page,
535                monitor_page: info.monitor_page.map(MonitorPageGpas::save),
536                target_message_vp: info.target_message_vp,
537                modifying: info.modifying,
538                client_id: Some(info.client_id),
539                trusted: info.trusted,
540                paused: info.paused,
541            }),
542            super::ConnectionState::Disconnecting {
543                next_action,
544                modify_sent: _,
545            } => Some(Connection::Disconnecting {
546                next_action: ConnectionAction::save(next_action),
547            }),
548        }
549    }
550
551    fn restore(self) -> Result<super::ConnectionState, RestoreError> {
552        Ok(match self {
553            Connection::Connecting {
554                version,
555                interrupt_page,
556                monitor_page,
557                target_message_vp,
558                next_action,
559                client_id,
560                trusted,
561            } => super::ConnectionState::Connecting {
562                info: super::ConnectionInfo {
563                    version: version.restore(trusted)?,
564                    trusted,
565                    interrupt_page,
566                    monitor_page: monitor_page.map(MonitorPageGpas::restore),
567                    target_message_vp,
568                    offers_sent: false,
569                    modifying: false,
570                    client_id: client_id.unwrap_or(Guid::ZERO),
571                    paused: false,
572                },
573                next_action: next_action.restore(),
574            },
575            Connection::Connected {
576                version,
577                offers_sent,
578                interrupt_page,
579                monitor_page,
580                target_message_vp,
581                modifying,
582                client_id,
583                trusted,
584                paused,
585            } => super::ConnectionState::Connected(super::ConnectionInfo {
586                version: version.restore(trusted)?,
587                trusted,
588                offers_sent,
589                interrupt_page,
590                monitor_page: monitor_page.map(MonitorPageGpas::restore),
591                target_message_vp,
592                modifying,
593                client_id: client_id.unwrap_or(Guid::ZERO),
594                paused,
595            }),
596            Connection::Disconnecting { next_action } => super::ConnectionState::Disconnecting {
597                next_action: next_action.restore(),
598                // If the modify request was sent, it will be resent.
599                modify_sent: false,
600            },
601        })
602    }
603}
604
/// Saved form of the server's pending connection action.
///
/// There is no variant for a reset: a pending reset is saved as `None`.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum ConnectionAction {
    /// No action, or a reset that the caller is responsible for reissuing.
    #[mesh(1)]
    None,
    #[mesh(2)]
    SendUnloadComplete,
    /// Reconnect using the saved initiate contact request.
    #[mesh(3)]
    Reconnect {
        #[mesh(1)]
        initiate_contact: InitiateContactRequest,
    },
    #[mesh(4)]
    SendFailedVersionResponse,
}
620
621impl ConnectionAction {
622    fn save(value: &super::ConnectionAction) -> Self {
623        match value {
624            super::ConnectionAction::Reset | super::ConnectionAction::None => {
625                // The caller is responsible for remembering that a
626                // reset was in progress and reissuing it.
627                Self::None
628            }
629            super::ConnectionAction::SendUnloadComplete => Self::SendUnloadComplete,
630            super::ConnectionAction::Reconnect { initiate_contact } => Self::Reconnect {
631                initiate_contact: InitiateContactRequest::save(initiate_contact),
632            },
633            super::ConnectionAction::SendFailedVersionResponse => Self::SendFailedVersionResponse,
634        }
635    }
636
637    fn restore(self) -> super::ConnectionAction {
638        match self {
639            Self::None => super::ConnectionAction::None,
640            Self::SendUnloadComplete => super::ConnectionAction::SendUnloadComplete,
641            Self::Reconnect { initiate_contact } => super::ConnectionAction::Reconnect {
642                initiate_contact: initiate_contact.restore(),
643            },
644            Self::SendFailedVersionResponse => super::ConnectionAction::SendFailedVersionResponse,
645        }
646    }
647}
648
/// Saved form of a single channel.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct Channel {
    /// The offer key (instance ID, interface ID and subchannel index) used to match the saved
    /// channel against an offered channel on restore.
    #[mesh(1)]
    pub key: OfferKey,
    /// The raw channel ID that was assigned to the channel.
    #[mesh(2)]
    pub channel_id: u32,
    /// The connection ID from the channel's offer info.
    #[mesh(3)]
    pub offered_connection_id: u32,
    /// The saved channel state.
    #[mesh(4)]
    pub state: ChannelState,
    /// The raw monitor ID from the channel's offer info, if any.
    #[mesh(5)]
    pub monitor_id: Option<u8>,
}
663
664impl Channel {
665    fn save(value: &super::Channel) -> Option<Self> {
666        let info = value.info.as_ref()?;
667        let key = value.offer.key();
668        if let Some(state) = ChannelState::save(&value.state) {
669            tracing::trace!(%key, %state, "channel saved");
670            Some(Channel {
671                channel_id: info.channel_id.0,
672                offered_connection_id: info.connection_id,
673                key,
674                state,
675                monitor_id: info.monitor_id.map(|id| id.0),
676            })
677        } else {
678            tracing::info!(%key, state = %value.state, "skipping channel save");
679            None
680        }
681    }
682
683    fn restore(
684        &self,
685    ) -> Result<(OfferedInfo, OfferParamsInternal, super::ChannelState), RestoreError> {
686        let info = OfferedInfo {
687            channel_id: ChannelId(self.channel_id),
688            connection_id: self.offered_connection_id,
689            monitor_id: self.monitor_id.map(MonitorId),
690        };
691
692        let stub_offer = OfferParamsInternal {
693            instance_id: self.key.instance_id,
694            interface_id: self.key.interface_id,
695            subchannel_index: self.key.subchannel_index,
696            ..Default::default()
697        };
698
699        let state = self.state.restore()?;
700        tracing::info!(key = %self.key, %state, "channel restored");
701        Ok((info, stub_offer, state))
702    }
703
704    pub fn channel_id(&self) -> u32 {
705        self.channel_id
706    }
707
708    pub fn key(&self) -> OfferKey {
709        self.key
710    }
711
712    pub fn open_request(&self) -> Option<OpenRequest> {
713        match self.state {
714            ChannelState::Closed => None,
715            ChannelState::Opening { request, .. } => Some(request),
716            ChannelState::Open { params, .. } => Some(params),
717            ChannelState::Closing { params, .. } => Some(params),
718            ChannelState::ClosingReopen { params, .. } => Some(params),
719            ChannelState::Revoked => None,
720        }
721    }
722}
723
/// Saved form of an initiate contact request, stored as part of
/// `ConnectionAction::Reconnect`.
///
/// Each field is copied to and from the field of the same name on the live request type.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct InitiateContactRequest {
    #[mesh(1)]
    pub version_requested: u32,
    #[mesh(2)]
    pub target_message_vp: u32,
    #[mesh(3)]
    pub monitor_page: MonitorPageRequest,
    #[mesh(4)]
    pub target_sint: u8,
    #[mesh(5)]
    pub target_vtl: u8,
    #[mesh(6)]
    pub feature_flags: u32,
    #[mesh(7)]
    pub interrupt_page: Option<u64>,
    #[mesh(8)]
    pub client_id: Guid,
    #[mesh(9)]
    pub trusted: bool,
}
746
747impl InitiateContactRequest {
748    fn save(value: &super::InitiateContactRequest) -> Self {
749        Self {
750            version_requested: value.version_requested,
751            target_message_vp: value.target_message_vp,
752            monitor_page: MonitorPageRequest::save(value.monitor_page),
753            target_sint: value.target_sint,
754            target_vtl: value.target_vtl,
755            feature_flags: value.feature_flags,
756            interrupt_page: value.interrupt_page,
757            client_id: value.client_id,
758            trusted: value.trusted,
759        }
760    }
761
762    fn restore(self) -> super::InitiateContactRequest {
763        super::InitiateContactRequest {
764            version_requested: self.version_requested,
765            target_message_vp: self.target_message_vp,
766            monitor_page: self.monitor_page.restore(),
767            target_sint: self.target_sint,
768            target_vtl: self.target_vtl,
769            feature_flags: self.feature_flags,
770            interrupt_page: self.interrupt_page,
771            client_id: self.client_id,
772            trusted: self.trusted,
773        }
774    }
775}
776
/// Saved form of the monitor page addresses.
///
/// Only guest-provided monitor pages can be saved; saving server-allocated pages panics.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct MonitorPageGpas {
    /// Address of the parent-to-child monitor page.
    #[mesh(1)]
    pub parent_to_child: u64,
    /// Address of the child-to-parent monitor page.
    #[mesh(2)]
    pub child_to_parent: u64,
}
785
786impl MonitorPageGpas {
787    fn save(value: super::MonitorPageGpaInfo) -> Self {
788        assert!(
789            !value.server_allocated,
790            "cannot save with server-allocated monitor pages"
791        );
792        Self {
793            child_to_parent: value.gpas.child_to_parent,
794            parent_to_child: value.gpas.parent_to_child,
795        }
796    }
797
798    fn restore(self) -> super::MonitorPageGpaInfo {
799        super::MonitorPageGpaInfo::from_guest_gpas(super::MonitorPageGpas {
800            child_to_parent: self.child_to_parent,
801            parent_to_child: self.parent_to_child,
802        })
803    }
804}
805
/// Saved form of the monitor page portion of an initiate contact request.
///
/// Each variant maps to the variant of the same name on the live request type.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum MonitorPageRequest {
    #[mesh(1)]
    None,
    /// The monitor page addresses carried by the request.
    #[mesh(2)]
    Some(#[mesh(1)] MonitorPageGpas),
    #[mesh(3)]
    Invalid,
}
816
817impl MonitorPageRequest {
818    fn save(value: super::MonitorPageRequest) -> Self {
819        match value {
820            super::MonitorPageRequest::None => MonitorPageRequest::None,
821            super::MonitorPageRequest::Some(mp) => MonitorPageRequest::Some(MonitorPageGpas::save(
822                super::MonitorPageGpaInfo::from_guest_gpas(mp),
823            )),
824            super::MonitorPageRequest::Invalid => MonitorPageRequest::Invalid,
825        }
826    }
827
828    fn restore(self) -> super::MonitorPageRequest {
829        match self {
830            MonitorPageRequest::None => super::MonitorPageRequest::None,
831            MonitorPageRequest::Some(mp) => super::MonitorPageRequest::Some(mp.restore().gpas),
832            MonitorPageRequest::Invalid => super::MonitorPageRequest::Invalid,
833        }
834    }
835}
836
/// Saved form of the interrupt signal info, stored in `OpenRequest` as the optional
/// guest-specified interrupt info.
///
/// Both fields are copied to and from the fields of the same name on the live type.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct SignalInfo {
    #[mesh(1)]
    pub event_flag: u16,
    #[mesh(2)]
    pub connection_id: u32,
}
845
846impl SignalInfo {
847    fn save(value: &super::SignalInfo) -> Self {
848        Self {
849            event_flag: value.event_flag,
850            connection_id: value.connection_id,
851        }
852    }
853
854    fn restore(self) -> super::SignalInfo {
855        super::SignalInfo {
856            event_flag: self.event_flag,
857            connection_id: self.connection_id,
858        }
859    }
860}
861
/// Saved form of a channel open request.
///
/// `Debug` is implemented by hand rather than derived so that `user_data` is printed
/// compactly.
#[derive(Copy, PartialEq, Eq, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct OpenRequest {
    #[mesh(1)]
    pub open_id: u32,
    /// ID of the GPADL used for the ring buffer.
    #[mesh(2)]
    pub ring_buffer_gpadl_id: GpadlId,
    #[mesh(3)]
    pub target_vp: u32,
    #[mesh(4)]
    pub downstream_ring_buffer_page_offset: u32,
    /// Opaque 120-byte user data block.
    #[mesh(5)]
    pub user_data: [u8; 120],
    #[mesh(6)]
    pub guest_specified_interrupt_info: Option<SignalInfo>,
    #[mesh(7)]
    pub flags: u16,
}
880
881impl std::fmt::Debug for OpenRequest {
882    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
883        let Self {
884            open_id,
885            ring_buffer_gpadl_id,
886            target_vp,
887            downstream_ring_buffer_page_offset,
888            user_data,
889            guest_specified_interrupt_info,
890            flags,
891        } = self;
892
893        let user_data_display: &dyn std::fmt::Debug = if self.user_data.iter().all(|&b| b == 0) {
894            &"[<all-zeroes>]"
895        } else {
896            struct HexDisplay<'a>(&'a [u8; 120]);
897            impl std::fmt::Debug for HexDisplay<'_> {
898                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
899                    write!(f, "[")?;
900                    for byte in self.0 {
901                        write!(f, "{:02X}", byte)?;
902                    }
903                    write!(f, "]")
904                }
905            }
906            &HexDisplay(user_data)
907        };
908
909        f.debug_struct("OpenRequest")
910            .field("open_id", open_id)
911            .field("ring_buffer_gpadl_id", ring_buffer_gpadl_id)
912            .field("target_vp", target_vp)
913            .field(
914                "downstream_ring_buffer_page_offset",
915                downstream_ring_buffer_page_offset,
916            )
917            .field("user_data", user_data_display)
918            .field(
919                "guest_specified_interrupt_info",
920                guest_specified_interrupt_info,
921            )
922            .field("flags", flags)
923            .finish()
924    }
925}
926
927impl OpenRequest {
928    fn save(value: &super::OpenRequest) -> Self {
929        Self {
930            open_id: value.open_id,
931            ring_buffer_gpadl_id: value.ring_buffer_gpadl_id,
932            target_vp: value
933                .target_vp
934                .unwrap_or(protocol::VP_INDEX_DISABLE_INTERRUPT),
935            downstream_ring_buffer_page_offset: value.downstream_ring_buffer_page_offset,
936            user_data: value.user_data.into(),
937            guest_specified_interrupt_info: value
938                .guest_specified_interrupt_info
939                .as_ref()
940                .map(SignalInfo::save),
941            flags: value.flags.into(),
942        }
943    }
944
945    fn restore(self) -> super::OpenRequest {
946        super::OpenRequest {
947            open_id: self.open_id,
948            ring_buffer_gpadl_id: self.ring_buffer_gpadl_id,
949            target_vp: protocol::vp_index_if_enabled(self.target_vp),
950            downstream_ring_buffer_page_offset: self.downstream_ring_buffer_page_offset,
951            user_data: self.user_data.into(),
952            guest_specified_interrupt_info: self
953                .guest_specified_interrupt_info
954                .map(SignalInfo::restore),
955            flags: self.flags.into(),
956        }
957    }
958}
959
960#[derive(Debug, Copy, Clone, Protobuf)]
961#[mesh(package = "vmbus.server.channels")]
962pub enum ModifyState {
963    #[mesh(1)]
964    NotModifying,
965    #[mesh(2)]
966    Modifying {
967        #[mesh(1)]
968        pending_target_vp: Option<u32>,
969    },
970}
971
972impl ModifyState {
973    fn save(value: &super::ModifyState) -> Self {
974        match value {
975            super::ModifyState::NotModifying => Self::NotModifying,
976            super::ModifyState::Modifying { pending_target_vp } => Self::Modifying {
977                pending_target_vp: *pending_target_vp,
978            },
979        }
980    }
981
982    fn restore(self) -> super::ModifyState {
983        match self {
984            ModifyState::NotModifying => super::ModifyState::NotModifying,
985            ModifyState::Modifying { pending_target_vp } => {
986                super::ModifyState::Modifying { pending_target_vp }
987            }
988        }
989    }
990}
991
992#[derive(Debug, PartialEq, Eq, Clone, Protobuf)]
993#[mesh(package = "vmbus.server.channels")]
994pub struct ReservedState {
995    #[mesh(1)]
996    pub version: VersionInfo,
997    #[mesh(2)]
998    pub vp: u32,
999    #[mesh(3)]
1000    pub sint: u8,
1001}
1002
1003impl ReservedState {
1004    fn save(reserved_state: &super::ReservedState) -> Self {
1005        Self {
1006            version: VersionInfo::save(&reserved_state.version),
1007            vp: reserved_state.target.vp,
1008            sint: reserved_state.target.sint,
1009        }
1010    }
1011
1012    fn restore(&self) -> Result<super::ReservedState, RestoreError> {
1013        // We don't know if the connection when the channel was reserved was trusted, so assume it
1014        // was for what feature flags are accepted here; it doesn't affect any actual behavior.
1015        let version = self.version.clone().restore(true).map_err(|e| match e {
1016            RestoreError::UnsupportedVersion(v) => RestoreError::UnsupportedReserveVersion(v),
1017            RestoreError::UnsupportedFeatureFlags(f) => {
1018                RestoreError::UnsupportedReserveFeatureFlags(f)
1019            }
1020            err => err,
1021        })?;
1022
1023        if version.version < Version::Win10 {
1024            return Err(RestoreError::UnsupportedReserveVersion(
1025                version.version as u32,
1026            ));
1027        }
1028
1029        Ok(super::ReservedState {
1030            version,
1031            target: super::ConnectionTarget {
1032                vp: self.vp,
1033                sint: self.sint,
1034            },
1035        })
1036    }
1037}
1038
1039#[derive(Debug, Clone, Protobuf)]
1040#[mesh(package = "vmbus.server.channels")]
1041pub enum ChannelState {
1042    #[mesh(1)]
1043    Closed,
1044    #[mesh(2)]
1045    Opening {
1046        #[mesh(1)]
1047        request: OpenRequest,
1048        #[mesh(2)]
1049        reserved_state: Option<ReservedState>,
1050    },
1051    #[mesh(3)]
1052    Open {
1053        #[mesh(1)]
1054        params: OpenRequest,
1055        #[mesh(2)]
1056        modify_state: ModifyState,
1057        #[mesh(3)]
1058        reserved_state: Option<ReservedState>,
1059    },
1060    #[mesh(4)]
1061    Closing {
1062        #[mesh(1)]
1063        params: OpenRequest,
1064        #[mesh(2)]
1065        reserved_state: Option<ReservedState>,
1066    },
1067    #[mesh(5)]
1068    ClosingReopen {
1069        #[mesh(1)]
1070        params: OpenRequest,
1071        #[mesh(2)]
1072        request: OpenRequest,
1073    },
1074    #[mesh(6)]
1075    Revoked,
1076}
1077
1078impl ChannelState {
1079    fn save(value: &super::ChannelState) -> Option<Self> {
1080        Some(match value {
1081            super::ChannelState::Closed => ChannelState::Closed,
1082            super::ChannelState::Opening {
1083                request,
1084                reserved_state,
1085            } => ChannelState::Opening {
1086                request: OpenRequest::save(request),
1087                reserved_state: reserved_state.as_ref().map(ReservedState::save),
1088            },
1089            super::ChannelState::ClosingReopen { params, request } => ChannelState::ClosingReopen {
1090                params: OpenRequest::save(params),
1091                request: OpenRequest::save(request),
1092            },
1093            super::ChannelState::Open {
1094                params,
1095                modify_state,
1096                reserved_state,
1097            } => ChannelState::Open {
1098                params: OpenRequest::save(params),
1099                modify_state: ModifyState::save(modify_state),
1100                reserved_state: reserved_state.as_ref().map(ReservedState::save),
1101            },
1102            super::ChannelState::Closing {
1103                params,
1104                reserved_state,
1105            } => ChannelState::Closing {
1106                params: OpenRequest::save(params),
1107                reserved_state: reserved_state.as_ref().map(ReservedState::save),
1108            },
1109
1110            super::ChannelState::Revoked => ChannelState::Revoked,
1111            super::ChannelState::Reoffered => ChannelState::Revoked,
1112            super::ChannelState::ClientReleased
1113            | super::ChannelState::ClosingClientRelease
1114            | super::ChannelState::OpeningClientRelease => return None,
1115        })
1116    }
1117
1118    fn restore(&self) -> Result<super::ChannelState, RestoreError> {
1119        Ok(match self {
1120            ChannelState::Closed => super::ChannelState::Closed,
1121            ChannelState::Opening {
1122                request,
1123                reserved_state,
1124            } => super::ChannelState::Opening {
1125                request: request.restore(),
1126                reserved_state: reserved_state
1127                    .as_ref()
1128                    .map(ReservedState::restore)
1129                    .transpose()?,
1130            },
1131            ChannelState::ClosingReopen { params, request } => super::ChannelState::ClosingReopen {
1132                params: params.restore(),
1133                request: request.restore(),
1134            },
1135            ChannelState::Open {
1136                params,
1137                modify_state,
1138                reserved_state,
1139            } => super::ChannelState::Open {
1140                params: params.restore(),
1141                modify_state: modify_state.restore(),
1142                reserved_state: reserved_state
1143                    .as_ref()
1144                    .map(ReservedState::restore)
1145                    .transpose()?,
1146            },
1147            ChannelState::Closing {
1148                params,
1149                reserved_state,
1150            } => super::ChannelState::Closing {
1151                params: params.restore(),
1152                reserved_state: reserved_state
1153                    .as_ref()
1154                    .map(ReservedState::restore)
1155                    .transpose()?,
1156            },
1157            ChannelState::Revoked => {
1158                // Mark it reoffered for now. This may transition back to revoked in post_restore.
1159                super::ChannelState::Reoffered
1160            }
1161        })
1162    }
1163}
1164
1165impl Display for ChannelState {
1166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1167        let state = match self {
1168            Self::Closed => "Closed",
1169            Self::Opening { .. } => "Opening",
1170            Self::Open { .. } => "Open",
1171            Self::Closing { .. } => "Closing",
1172            Self::ClosingReopen { .. } => "ClosingReopen",
1173            Self::Revoked => "Revoked",
1174        };
1175        write!(f, "{}", state)
1176    }
1177}
1178
1179#[derive(Debug, Clone, Protobuf)]
1180#[mesh(package = "vmbus.server.channels")]
1181pub struct Gpadl {
1182    #[mesh(1)]
1183    pub id: u32,
1184    #[mesh(2)]
1185    pub channel_id: u32,
1186    #[mesh(3)]
1187    pub count: u16,
1188    #[mesh(4)]
1189    pub buf: Vec<u64>,
1190    #[mesh(5)]
1191    pub state: GpadlState,
1192}
1193
1194impl Gpadl {
1195    fn save(gpadl_id: GpadlId, channel_id: ChannelId, gpadl: &super::Gpadl) -> Option<Self> {
1196        tracing::trace!(id = %gpadl_id.0, channel_id = %channel_id.0, "gpadl saved");
1197        Some(Gpadl {
1198            id: gpadl_id.0,
1199            channel_id: channel_id.0,
1200            count: gpadl.count,
1201            buf: gpadl.buf.clone(),
1202            state: match gpadl.state {
1203                super::GpadlState::InProgress => GpadlState::InProgress,
1204                super::GpadlState::Offered => GpadlState::Offered,
1205                super::GpadlState::Accepted => GpadlState::Accepted,
1206                super::GpadlState::TearingDown => GpadlState::TearingDown,
1207                super::GpadlState::OfferedTearingDown => return None,
1208            },
1209        })
1210    }
1211
1212    fn restore(self, channel: &super::Channel) -> Result<super::Gpadl, RestoreError> {
1213        if self.state != GpadlState::InProgress {
1214            // Validate the range.
1215            gparange::validate_gpa_ranges(self.count.into(), &self.buf)?;
1216        }
1217        let (state, allow_revoked) = match self.state {
1218            GpadlState::InProgress => (super::GpadlState::InProgress, true),
1219            GpadlState::Offered => (super::GpadlState::Offered, false),
1220            GpadlState::Accepted => {
1221                // It is assumed the device already knows about this GPADL.
1222                (super::GpadlState::Accepted, true)
1223            }
1224            GpadlState::TearingDown => (super::GpadlState::TearingDown, false),
1225        };
1226
1227        if !allow_revoked && channel.state.is_revoked() {
1228            return Err(RestoreError::GpadlForRevokedChannel(
1229                GpadlId(self.id),
1230                ChannelId(self.channel_id),
1231            ));
1232        }
1233
1234        Ok(super::Gpadl {
1235            count: self.count,
1236            buf: self.buf,
1237            state,
1238        })
1239    }
1240
1241    pub fn is_tearing_down(&self) -> bool {
1242        self.state == GpadlState::TearingDown
1243    }
1244}
1245
1246#[derive(Debug, Clone, Protobuf, PartialEq, Eq)]
1247#[mesh(package = "vmbus.server.channels")]
1248pub enum GpadlState {
1249    #[mesh(1)]
1250    InProgress,
1251    #[mesh(2)]
1252    Offered,
1253    #[mesh(3)]
1254    Accepted,
1255    #[mesh(4)]
1256    TearingDown,
1257}
1258
1259#[derive(Debug, Clone, Protobuf, PartialEq, Eq)]
1260#[mesh(package = "vmbus.server.channels")]
1261pub struct OutgoingMessage(pub Vec<u8>);
1262
1263impl OutgoingMessage {
1264    fn save(value: &vmbus_core::OutgoingMessage) -> Self {
1265        Self(value.data().to_vec())
1266    }
1267
1268    fn restore(self) -> Result<vmbus_core::OutgoingMessage, RestoreError> {
1269        vmbus_core::OutgoingMessage::from_message(&self.0)
1270            .map_err(|_| RestoreError::MessageTooLarge)
1271    }
1272}