1use super::MnfUsage;
5use super::Notifier;
6use super::OfferError;
7use super::OfferParamsInternal;
8use super::OfferedInfo;
9use super::RestoreState;
10use super::SUPPORTED_FEATURE_FLAGS;
11use guid::Guid;
12pub use inner::SavedState;
13use mesh::payload::Protobuf;
14use std::fmt::Display;
15use thiserror::Error;
16use vmbus_channel::bus::OfferKey;
17use vmbus_core::protocol::ChannelId;
18use vmbus_core::protocol::FeatureFlags;
19use vmbus_core::protocol::GpadlId;
20use vmbus_core::protocol::Version;
21use vmbus_ring::gparange;
22use vmbus_ring::gparange::MultiPagedRangeBuf;
23use vmcore::monitor::MonitorId;
24
25impl super::Server {
26 fn restore_one_channel(&mut self, saved_channel: Channel) -> Result<(), RestoreError> {
27 let (info, stub_offer, state) = saved_channel.restore()?;
28 if let Some((offer_id, channel)) = self.channels.get_by_key_mut(&saved_channel.key) {
29 if !matches!(channel.state, super::ChannelState::ClientReleased)
32 || channel.restore_state != RestoreState::New
33 {
34 return Err(RestoreError::AlreadyRestored(saved_channel.key));
35 }
36
37 if let MnfUsage::Relayed { monitor_id } = channel.offer.use_mnf {
40 if info.monitor_id != Some(MonitorId(monitor_id)) {
41 return Err(RestoreError::MismatchedMonitorId(
42 monitor_id,
43 saved_channel.monitor_id,
44 ));
45 }
46 }
47
48 self.assigned_channels
49 .set(info.channel_id)?
50 .insert(offer_id);
51
52 channel.state = state;
53 channel.restore_state = RestoreState::Restoring;
54 channel.info = Some(info);
55 } else {
56 let entry = self
59 .assigned_channels
60 .set(ChannelId(saved_channel.channel_id))?;
61
62 let channel = super::Channel {
63 info: Some(info),
64 offer: stub_offer,
65 state,
66 restore_state: RestoreState::Unmatched,
67 };
68
69 let offer_id = self.channels.offer(channel);
70 entry.insert(offer_id);
71 }
72 Ok(())
73 }
74
75 fn restore_one_gpadl(&mut self, saved_gpadl: Gpadl) -> Result<(), RestoreError> {
76 let gpadl_id = GpadlId(saved_gpadl.id);
77 let channel_id = ChannelId(saved_gpadl.channel_id);
78 let (offer_id, channel) = self
79 .channels
80 .get_by_channel_id(&self.assigned_channels, channel_id)
81 .map_err(|_| RestoreError::MissingGpadlChannel(gpadl_id, channel_id))?;
82
83 if channel.restore_state == RestoreState::New || channel.state.is_released() {
84 return Err(RestoreError::MissingGpadlChannel(gpadl_id, channel_id));
85 }
86
87 let gpadl = saved_gpadl.restore(channel)?;
88 let state = gpadl.state;
89 if self.gpadls.insert((gpadl_id, offer_id), gpadl).is_some() {
90 return Err(RestoreError::GpadlIdInUse(gpadl_id, channel_id));
91 }
92
93 if state == super::GpadlState::InProgress
94 && self.incomplete_gpadls.insert(gpadl_id, offer_id).is_some()
95 {
96 unreachable!("gpadl ID validated above");
97 }
98
99 Ok(())
100 }
101
102 pub fn save(&self) -> SavedState {
104 SavedStateData {
105 state: if let Some(state) = self.save_connected_state() {
106 SavedConnectionState::Connected(state)
107 } else {
108 SavedConnectionState::Disconnected(self.save_disconnected_state())
109 },
110 pending_messages: self.save_pending_messages(),
111 }
112 .into()
113 }
114
115 fn save_connected_state(&self) -> Option<ConnectedState> {
116 let connection = Connection::save(&self.state)?;
117 let channels = self
118 .channels
119 .iter()
120 .filter_map(|(_, channel)| Channel::save(channel))
121 .collect();
122
123 let gpadls = self.save_gpadls();
124 Some(ConnectedState {
125 connection,
126 channels,
127 gpadls,
128 })
129 }
130
131 fn save_gpadls(&self) -> Vec<Gpadl> {
132 self.gpadls
133 .iter()
134 .filter_map(|((gpadl_id, offer_id), gpadl)| {
135 Gpadl::save(*gpadl_id, self.channels[*offer_id].info?.channel_id, gpadl)
136 })
137 .collect()
138 }
139
140 fn save_disconnected_state(&self) -> DisconnectedState {
141 let channels = self
143 .channels
144 .iter()
145 .filter_map(|(_, channel)| {
146 channel
147 .state
148 .is_reserved()
149 .then(|| Channel::save(channel))
150 .flatten()
151 })
152 .collect();
153
154 let gpadls = self.save_gpadls();
157 DisconnectedState {
158 reserved_channels: channels,
159 reserved_gpadls: gpadls,
160 }
161 }
162
163 fn save_pending_messages(&self) -> Vec<OutgoingMessage> {
164 self.pending_messages
165 .0
166 .iter()
167 .map(OutgoingMessage::save)
168 .collect()
169 }
170}
171
172impl<'a, N: 'a + Notifier> super::ServerWithNotifier<'a, N> {
173 pub fn restore(&mut self, saved: SavedState) -> Result<(), RestoreError> {
188 tracing::trace!(?saved, "restoring channel state");
189
190 let saved = SavedStateData::from(saved);
191 match saved.state {
192 SavedConnectionState::Connected(saved) => {
193 self.inner.state = saved.connection.restore()?;
194
195 let request = match self.inner.state {
198 super::ConnectionState::Connecting {
199 info,
200 next_action: _,
201 } => Some(super::ModifyConnectionRequest {
202 version: Some(info.version.version as u32),
203 interrupt_page: info.interrupt_page.into(),
204 monitor_page: info.monitor_page.into(),
205 target_message_vp: Some(info.target_message_vp),
206 notify_relay: true,
207 }),
208 super::ConnectionState::Connected(info) => {
209 Some(super::ModifyConnectionRequest {
210 version: None,
211 monitor_page: info.monitor_page.into(),
212 interrupt_page: info.interrupt_page.into(),
213 target_message_vp: Some(info.target_message_vp),
214 notify_relay: info.modifying,
218 })
219 }
220 super::ConnectionState::Disconnected
223 | super::ConnectionState::Disconnecting { .. } => None,
224 };
225
226 if let Some(request) = request {
227 self.notifier.modify_connection(request)?;
228 }
229
230 for saved_channel in saved.channels {
231 self.inner.restore_one_channel(saved_channel)?;
232 }
233
234 for saved_gpadl in saved.gpadls {
235 self.inner.restore_one_gpadl(saved_gpadl)?;
236 }
237 }
238 SavedConnectionState::Disconnected(saved) => {
239 self.inner.state = super::ConnectionState::Disconnected;
240 for saved_channel in saved.reserved_channels {
241 self.inner.restore_one_channel(saved_channel)?;
242 }
243
244 for saved_gpadl in saved.reserved_gpadls {
245 self.inner.restore_one_gpadl(saved_gpadl)?;
246 }
247 }
248 }
249
250 self.inner
251 .pending_messages
252 .0
253 .reserve(saved.pending_messages.len());
254
255 for message in saved.pending_messages {
256 self.inner.pending_messages.0.push_back(message.restore()?);
257 }
258
259 Ok(())
260 }
261}
262
263#[derive(Debug, Error)]
264pub enum RestoreError {
265 #[error(transparent)]
266 Offer(#[from] OfferError),
267
268 #[error("channel {0} has already been restored")]
269 AlreadyRestored(OfferKey),
270
271 #[error("gpadl {} is for missing channel {}", (.0).0, (.1).0)]
272 MissingGpadlChannel(GpadlId, ChannelId),
273
274 #[error("gpadl {} is for revoked channel {}", (.0).0, (.1).0)]
275 GpadlForRevokedChannel(GpadlId, ChannelId),
276
277 #[error("gpadl {} is already restored", (.0).0)]
278 GpadlIdInUse(GpadlId, ChannelId),
279
280 #[error("unsupported protocol version {0:#x}")]
281 UnsupportedVersion(u32),
282
283 #[error("invalid gpadl")]
284 InvalidGpadl(#[from] gparange::Error),
285
286 #[error("unsupported feature flags {0:#x}")]
287 UnsupportedFeatureFlags(u32),
288
289 #[error("channel {0} has a mismatched open state")]
290 MismatchedOpenState(OfferKey),
291
292 #[error("channel {0} is missing from the saved state")]
293 MissingChannel(OfferKey),
294
295 #[error("unsupported reserved channel protocol version {0:#x}")]
296 UnsupportedReserveVersion(u32),
297
298 #[error("unsupported reserved channel feature flags {0:#x}")]
299 UnsupportedReserveFeatureFlags(u32),
300
301 #[error("mismatched monitor id; expected {0}, actual {1:?}")]
302 MismatchedMonitorId(u8, Option<u8>),
303
304 #[error("monitor ID used by multiple channels in the saved state")]
305 DuplicateMonitorId(u8),
306
307 #[error(transparent)]
308 ServerError(#[from] anyhow::Error),
309
310 #[error(
311 "reserved channel with ID {0} has a pending message but is missing from the saved state"
312 )]
313 MissingReservedChannel(u32),
314 #[error("a saved pending message is larger than the maximum message size")]
315 MessageTooLarge,
316}
317
318mod inner {
319 use super::*;
320
321 #[derive(Debug, Protobuf, Clone)]
326 #[mesh(package = "vmbus.server.channels")]
327 pub struct SavedState {
328 #[mesh(1)]
329 state: Option<ConnectedState>,
330 #[mesh(2)]
337 disconnected_state: Option<DisconnectedState>,
338 #[mesh(3)]
339 pending_messages: Vec<OutgoingMessage>,
340 }
341
342 impl From<SavedStateData> for SavedState {
343 fn from(value: SavedStateData) -> Self {
344 let (state, disconnected_state) = match value.state {
345 SavedConnectionState::Connected(connected) => (Some(connected), None),
346 SavedConnectionState::Disconnected(disconnected) => (None, Some(disconnected)),
347 };
348
349 Self {
350 state,
351 disconnected_state,
352 pending_messages: value.pending_messages,
353 }
354 }
355 }
356
357 impl From<SavedState> for SavedStateData {
358 fn from(value: SavedState) -> Self {
359 Self {
360 state: if let Some(connected) = value.state {
361 SavedConnectionState::Connected(connected)
362 } else {
363 SavedConnectionState::Disconnected(value.disconnected_state.unwrap_or_default())
366 },
367 pending_messages: value.pending_messages,
368 }
369 }
370 }
371}
372
/// The connection-dependent part of [`SavedStateData`].
enum SavedConnectionState {
    /// Saved while a connection existed (connecting, connected or disconnecting).
    Connected(ConnectedState),
    /// Saved while disconnected; holds only reserved channels and their GPADLs.
    Disconnected(DisconnectedState),
}
378
/// In-memory view of [`SavedState`].
///
/// The two optional connection states of the serialized form are collapsed into a
/// single [`SavedConnectionState`].
pub struct SavedStateData {
    // Connection-dependent channel and GPADL state.
    state: SavedConnectionState,
    // Outgoing messages that were pending at save time.
    pending_messages: Vec<OutgoingMessage>,
}
385
386impl SavedStateData {
387 pub fn find_channel(&self, offer: OfferKey) -> Option<&Channel> {
389 let (channels, _) = self.channels_and_gpadls();
390 channels.iter().find(|c| c.key == offer)
391 }
392
393 pub fn channels_and_gpadls(&self) -> (&[Channel], &[Gpadl]) {
396 match &self.state {
397 SavedConnectionState::Connected(connected) => (&connected.channels, &connected.gpadls),
398 SavedConnectionState::Disconnected(disconnected) => (
399 &disconnected.reserved_channels,
400 &disconnected.reserved_gpadls,
401 ),
402 }
403 }
404}
405
/// Saved state captured while a connection existed.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
struct ConnectedState {
    /// The saved connection state.
    #[mesh(1)]
    connection: Connection,
    /// Channels that `Channel::save` chose to save.
    #[mesh(2)]
    channels: Vec<Channel>,
    /// GPADLs belonging to saved channels.
    #[mesh(3)]
    gpadls: Vec<Gpadl>,
}
416
/// Saved state captured while disconnected: reserved channels and their GPADLs.
#[derive(Default, Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
struct DisconnectedState {
    /// Channels in a reserved state.
    #[mesh(1)]
    reserved_channels: Vec<Channel>,
    /// GPADLs saved alongside the reserved channels.
    #[mesh(2)]
    reserved_gpadls: Vec<Gpadl>,
}
425
/// Saved protocol version and negotiated feature flags.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
struct VersionInfo {
    /// Numeric protocol version (the `Version` enum cast to `u32`).
    #[mesh(1)]
    version: u32,
    /// Raw feature-flag bits.
    #[mesh(2)]
    feature_flags: u32,
}
434
435impl VersionInfo {
436 fn save(value: &super::VersionInfo) -> Self {
437 Self {
438 version: value.version as u32,
439 feature_flags: value.feature_flags.into(),
440 }
441 }
442
443 fn restore(self, trusted: bool) -> Result<vmbus_core::VersionInfo, RestoreError> {
444 let version = super::SUPPORTED_VERSIONS
445 .iter()
446 .find(|v| self.version == **v as u32)
447 .copied()
448 .ok_or(RestoreError::UnsupportedVersion(self.version))?;
449
450 let feature_flags = FeatureFlags::from(self.feature_flags);
451 let supported_flags = SUPPORTED_FEATURE_FLAGS.with_confidential_channels(trusted);
452 if !supported_flags.contains(feature_flags) {
453 return Err(RestoreError::UnsupportedFeatureFlags(feature_flags.into()));
454 }
455
456 Ok(super::VersionInfo {
457 version,
458 feature_flags,
459 })
460 }
461}
462
/// Saved connection state. A `Disconnected` live state is not represented (it saves as
/// `None`).
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
enum Connection {
    /// The connection was being torn down.
    #[mesh(1)]
    Disconnecting {
        /// Action to take once disconnected.
        #[mesh(1)]
        next_action: ConnectionAction,
    },
    /// A connection was being established.
    #[mesh(2)]
    Connecting {
        #[mesh(1)]
        version: VersionInfo,
        #[mesh(2)]
        interrupt_page: Option<u64>,
        #[mesh(3)]
        monitor_page: Option<MonitorPageGpas>,
        #[mesh(4)]
        target_message_vp: u32,
        #[mesh(5)]
        next_action: ConnectionAction,
        /// Always `Some` when written by `Connection::save`; `None` restores as `Guid::ZERO`.
        #[mesh(6)]
        client_id: Option<Guid>,
        /// Also controls whether confidential-channel feature flags are accepted on restore.
        #[mesh(7)]
        trusted: bool,
    },
    /// A connection was established.
    #[mesh(3)]
    Connected {
        #[mesh(1)]
        version: VersionInfo,
        #[mesh(2)]
        offers_sent: bool,
        #[mesh(3)]
        interrupt_page: Option<u64>,
        #[mesh(4)]
        monitor_page: Option<MonitorPageGpas>,
        #[mesh(5)]
        target_message_vp: u32,
        #[mesh(6)]
        modifying: bool,
        /// Always `Some` when written by `Connection::save`; `None` restores as `Guid::ZERO`.
        #[mesh(7)]
        client_id: Option<Guid>,
        /// Also controls whether confidential-channel feature flags are accepted on restore.
        #[mesh(8)]
        trusted: bool,
        #[mesh(9)]
        paused: bool,
    },
}
510
511impl Connection {
512 fn save(value: &super::ConnectionState) -> Option<Self> {
513 match value {
514 super::ConnectionState::Disconnected => {
515 None
517 }
518 super::ConnectionState::Connecting { info, next_action } => {
519 Some(Connection::Connecting {
520 version: VersionInfo::save(&info.version),
521 interrupt_page: info.interrupt_page,
522 monitor_page: info.monitor_page.map(MonitorPageGpas::save),
523 target_message_vp: info.target_message_vp,
524 next_action: ConnectionAction::save(next_action),
525 client_id: Some(info.client_id),
526 trusted: info.trusted,
527 })
528 }
529 super::ConnectionState::Connected(info) => Some(Connection::Connected {
530 version: VersionInfo::save(&info.version),
531 offers_sent: info.offers_sent,
532 interrupt_page: info.interrupt_page,
533 monitor_page: info.monitor_page.map(MonitorPageGpas::save),
534 target_message_vp: info.target_message_vp,
535 modifying: info.modifying,
536 client_id: Some(info.client_id),
537 trusted: info.trusted,
538 paused: info.paused,
539 }),
540 super::ConnectionState::Disconnecting {
541 next_action,
542 modify_sent: _,
543 } => Some(Connection::Disconnecting {
544 next_action: ConnectionAction::save(next_action),
545 }),
546 }
547 }
548
549 fn restore(self) -> Result<super::ConnectionState, RestoreError> {
550 Ok(match self {
551 Connection::Connecting {
552 version,
553 interrupt_page,
554 monitor_page,
555 target_message_vp,
556 next_action,
557 client_id,
558 trusted,
559 } => super::ConnectionState::Connecting {
560 info: super::ConnectionInfo {
561 version: version.restore(trusted)?,
562 trusted,
563 interrupt_page,
564 monitor_page: monitor_page.map(MonitorPageGpas::restore),
565 target_message_vp,
566 offers_sent: false,
567 modifying: false,
568 client_id: client_id.unwrap_or(Guid::ZERO),
569 paused: false,
570 },
571 next_action: next_action.restore(),
572 },
573 Connection::Connected {
574 version,
575 offers_sent,
576 interrupt_page,
577 monitor_page,
578 target_message_vp,
579 modifying,
580 client_id,
581 trusted,
582 paused,
583 } => super::ConnectionState::Connected(super::ConnectionInfo {
584 version: version.restore(trusted)?,
585 trusted,
586 offers_sent,
587 interrupt_page,
588 monitor_page: monitor_page.map(MonitorPageGpas::restore),
589 target_message_vp,
590 modifying,
591 client_id: client_id.unwrap_or(Guid::ZERO),
592 paused,
593 }),
594 Connection::Disconnecting { next_action } => super::ConnectionState::Disconnecting {
595 next_action: next_action.restore(),
596 modify_sent: false,
598 },
599 })
600 }
601}
602
/// Saved form of the action queued to run after the current connection transition.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
enum ConnectionAction {
    /// No pending action; the live `Reset` action is also saved as this.
    #[mesh(1)]
    None,
    #[mesh(2)]
    SendUnloadComplete,
    /// Reconnect using the saved initiate-contact request.
    #[mesh(3)]
    Reconnect {
        #[mesh(1)]
        initiate_contact: InitiateContactRequest,
    },
    #[mesh(4)]
    SendFailedVersionResponse,
}
618
619impl ConnectionAction {
620 fn save(value: &super::ConnectionAction) -> Self {
621 match value {
622 super::ConnectionAction::Reset | super::ConnectionAction::None => {
623 Self::None
626 }
627 super::ConnectionAction::SendUnloadComplete => Self::SendUnloadComplete,
628 super::ConnectionAction::Reconnect { initiate_contact } => Self::Reconnect {
629 initiate_contact: InitiateContactRequest::save(initiate_contact),
630 },
631 super::ConnectionAction::SendFailedVersionResponse => Self::SendFailedVersionResponse,
632 }
633 }
634
635 fn restore(self) -> super::ConnectionAction {
636 match self {
637 Self::None => super::ConnectionAction::None,
638 Self::SendUnloadComplete => super::ConnectionAction::SendUnloadComplete,
639 Self::Reconnect { initiate_contact } => super::ConnectionAction::Reconnect {
640 initiate_contact: initiate_contact.restore(),
641 },
642 Self::SendFailedVersionResponse => super::ConnectionAction::SendFailedVersionResponse,
643 }
644 }
645}
646
/// Saved state of a single channel.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct Channel {
    /// Offer key identifying the channel (instance ID, interface ID, subchannel index).
    #[mesh(1)]
    key: OfferKey,
    /// Channel ID that was assigned to the offer.
    #[mesh(2)]
    channel_id: u32,
    /// Connection ID from the channel's offered info.
    #[mesh(3)]
    offered_connection_id: u32,
    /// Saved channel state.
    #[mesh(4)]
    state: ChannelState,
    /// Monitor ID from the channel's offered info, if one was assigned.
    #[mesh(5)]
    monitor_id: Option<u8>,
}
661
662impl Channel {
663 fn save(value: &super::Channel) -> Option<Self> {
664 let info = value.info.as_ref()?;
665 let key = value.offer.key();
666 if let Some(state) = ChannelState::save(&value.state) {
667 tracing::trace!(%key, %state, "channel saved");
668 Some(Channel {
669 channel_id: info.channel_id.0,
670 offered_connection_id: info.connection_id,
671 key,
672 state,
673 monitor_id: info.monitor_id.map(|id| id.0),
674 })
675 } else {
676 tracing::info!(%key, state = %value.state, "skipping channel save");
677 None
678 }
679 }
680
681 fn restore(
682 &self,
683 ) -> Result<(OfferedInfo, OfferParamsInternal, super::ChannelState), RestoreError> {
684 let info = OfferedInfo {
685 channel_id: ChannelId(self.channel_id),
686 connection_id: self.offered_connection_id,
687 monitor_id: self.monitor_id.map(MonitorId),
688 };
689
690 let stub_offer = OfferParamsInternal {
691 instance_id: self.key.instance_id,
692 interface_id: self.key.interface_id,
693 subchannel_index: self.key.subchannel_index,
694 ..Default::default()
695 };
696
697 let state = self.state.restore()?;
698 tracing::info!(key = %self.key, %state, "channel restored");
699 Ok((info, stub_offer, state))
700 }
701
702 pub fn channel_id(&self) -> u32 {
703 self.channel_id
704 }
705
706 pub fn key(&self) -> OfferKey {
707 self.key
708 }
709
710 pub fn open_request(&self) -> Option<OpenRequest> {
711 match self.state {
712 ChannelState::Closed => None,
713 ChannelState::Opening { request, .. } => Some(request),
714 ChannelState::Open { params, .. } => Some(params),
715 ChannelState::Closing { params, .. } => Some(params),
716 ChannelState::ClosingReopen { params, .. } => Some(params),
717 ChannelState::Revoked => None,
718 }
719 }
720}
721
/// Saved copy of an initiate-contact request, kept for a pending reconnect.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
struct InitiateContactRequest {
    #[mesh(1)]
    version_requested: u32,
    #[mesh(2)]
    target_message_vp: u32,
    #[mesh(3)]
    monitor_page: MonitorPageRequest,
    #[mesh(4)]
    target_sint: u8,
    #[mesh(5)]
    target_vtl: u8,
    /// Raw feature-flag bits as requested.
    #[mesh(6)]
    feature_flags: u32,
    #[mesh(7)]
    interrupt_page: Option<u64>,
    #[mesh(8)]
    client_id: Guid,
    #[mesh(9)]
    trusted: bool,
}
744
745impl InitiateContactRequest {
746 fn save(value: &super::InitiateContactRequest) -> Self {
747 Self {
748 version_requested: value.version_requested,
749 target_message_vp: value.target_message_vp,
750 monitor_page: MonitorPageRequest::save(value.monitor_page),
751 target_sint: value.target_sint,
752 target_vtl: value.target_vtl,
753 feature_flags: value.feature_flags,
754 interrupt_page: value.interrupt_page,
755 client_id: value.client_id,
756 trusted: value.trusted,
757 }
758 }
759
760 fn restore(self) -> super::InitiateContactRequest {
761 super::InitiateContactRequest {
762 version_requested: self.version_requested,
763 target_message_vp: self.target_message_vp,
764 monitor_page: self.monitor_page.restore(),
765 target_sint: self.target_sint,
766 target_vtl: self.target_vtl,
767 feature_flags: self.feature_flags,
768 interrupt_page: self.interrupt_page,
769 client_id: self.client_id,
770 trusted: self.trusted,
771 }
772 }
773}
774
/// Saved monitor page addresses for both directions.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
struct MonitorPageGpas {
    /// Address of the parent-to-child monitor page (presumably a guest physical address,
    /// per the type name — confirm).
    #[mesh(1)]
    parent_to_child: u64,
    /// Address of the child-to-parent monitor page.
    #[mesh(2)]
    child_to_parent: u64,
}
783
784impl MonitorPageGpas {
785 fn save(value: super::MonitorPageGpas) -> Self {
786 Self {
787 child_to_parent: value.child_to_parent,
788 parent_to_child: value.parent_to_child,
789 }
790 }
791
792 fn restore(self) -> super::MonitorPageGpas {
793 super::MonitorPageGpas {
794 child_to_parent: self.child_to_parent,
795 parent_to_child: self.parent_to_child,
796 }
797 }
798}
799
/// Saved monitor page portion of an initiate-contact request.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
enum MonitorPageRequest {
    /// No monitor pages were requested.
    #[mesh(1)]
    None,
    /// Monitor pages were requested at these addresses.
    #[mesh(2)]
    Some(#[mesh(1)] MonitorPageGpas),
    /// The request's monitor page information was invalid.
    #[mesh(3)]
    Invalid,
}
810
811impl MonitorPageRequest {
812 fn save(value: super::MonitorPageRequest) -> Self {
813 match value {
814 super::MonitorPageRequest::None => MonitorPageRequest::None,
815 super::MonitorPageRequest::Some(mp) => {
816 MonitorPageRequest::Some(MonitorPageGpas::save(mp))
817 }
818 super::MonitorPageRequest::Invalid => MonitorPageRequest::Invalid,
819 }
820 }
821
822 fn restore(self) -> super::MonitorPageRequest {
823 match self {
824 MonitorPageRequest::None => super::MonitorPageRequest::None,
825 MonitorPageRequest::Some(mp) => super::MonitorPageRequest::Some(mp.restore()),
826 MonitorPageRequest::Invalid => super::MonitorPageRequest::Invalid,
827 }
828 }
829}
830
/// Saved guest-specified interrupt signalling information.
#[derive(Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
struct SignalInfo {
    #[mesh(1)]
    event_flag: u16,
    #[mesh(2)]
    connection_id: u32,
}
839
840impl SignalInfo {
841 fn save(value: &super::SignalInfo) -> Self {
842 Self {
843 event_flag: value.event_flag,
844 connection_id: value.connection_id,
845 }
846 }
847
848 fn restore(self) -> super::SignalInfo {
849 super::SignalInfo {
850 event_flag: self.event_flag,
851 connection_id: self.connection_id,
852 }
853 }
854}
855
/// Saved parameters of a channel open request.
#[derive(Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct OpenRequest {
    #[mesh(1)]
    open_id: u32,
    /// GPADL ID of the channel's ring buffer.
    #[mesh(2)]
    pub ring_buffer_gpadl_id: GpadlId,
    #[mesh(3)]
    target_vp: u32,
    /// Page offset of the downstream ring within the ring buffer GPADL (per field name;
    /// confirm).
    #[mesh(4)]
    pub downstream_ring_buffer_page_offset: u32,
    /// Opaque user data from the open request.
    #[mesh(5)]
    user_data: [u8; 120],
    /// Interrupt signalling information, if the guest specified one.
    #[mesh(6)]
    guest_specified_interrupt_info: Option<SignalInfo>,
    /// Raw open flags.
    #[mesh(7)]
    flags: u16,
}
874
875impl OpenRequest {
876 fn save(value: &super::OpenRequest) -> Self {
877 Self {
878 open_id: value.open_id,
879 ring_buffer_gpadl_id: value.ring_buffer_gpadl_id,
880 target_vp: value.target_vp,
881 downstream_ring_buffer_page_offset: value.downstream_ring_buffer_page_offset,
882 user_data: value.user_data.into(),
883 guest_specified_interrupt_info: value
884 .guest_specified_interrupt_info
885 .as_ref()
886 .map(SignalInfo::save),
887 flags: value.flags.into(),
888 }
889 }
890
891 fn restore(self) -> super::OpenRequest {
892 super::OpenRequest {
893 open_id: self.open_id,
894 ring_buffer_gpadl_id: self.ring_buffer_gpadl_id,
895 target_vp: self.target_vp,
896 downstream_ring_buffer_page_offset: self.downstream_ring_buffer_page_offset,
897 user_data: self.user_data.into(),
898 guest_specified_interrupt_info: self
899 .guest_specified_interrupt_info
900 .map(SignalInfo::restore),
901 flags: self.flags.into(),
902 }
903 }
904}
905
/// Saved modify-channel status of an open channel.
#[derive(Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
enum ModifyState {
    #[mesh(1)]
    NotModifying,
    /// A modify was in progress, optionally carrying a pending target VP.
    #[mesh(2)]
    Modifying {
        #[mesh(1)]
        pending_target_vp: Option<u32>,
    },
}
917
918impl ModifyState {
919 fn save(value: &super::ModifyState) -> Self {
920 match value {
921 super::ModifyState::NotModifying => Self::NotModifying,
922 super::ModifyState::Modifying { pending_target_vp } => Self::Modifying {
923 pending_target_vp: *pending_target_vp,
924 },
925 }
926 }
927
928 fn restore(self) -> super::ModifyState {
929 match self {
930 ModifyState::NotModifying => super::ModifyState::NotModifying,
931 ModifyState::Modifying { pending_target_vp } => {
932 super::ModifyState::Modifying { pending_target_vp }
933 }
934 }
935 }
936}
937
/// Saved state of a reserved channel: its protocol version and message target.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
struct ReservedState {
    #[mesh(1)]
    version: VersionInfo,
    /// Target VP for the reserved channel's messages.
    #[mesh(2)]
    vp: u32,
    /// Target SINT for the reserved channel's messages.
    #[mesh(3)]
    sint: u8,
}
948
949impl ReservedState {
950 fn save(reserved_state: &super::ReservedState) -> Self {
951 Self {
952 version: VersionInfo::save(&reserved_state.version),
953 vp: reserved_state.target.vp,
954 sint: reserved_state.target.sint,
955 }
956 }
957
958 fn restore(&self) -> Result<super::ReservedState, RestoreError> {
959 let version = self.version.clone().restore(true).map_err(|e| match e {
962 RestoreError::UnsupportedVersion(v) => RestoreError::UnsupportedReserveVersion(v),
963 RestoreError::UnsupportedFeatureFlags(f) => {
964 RestoreError::UnsupportedReserveFeatureFlags(f)
965 }
966 err => err,
967 })?;
968
969 if version.version < Version::Win10 {
970 return Err(RestoreError::UnsupportedReserveVersion(
971 version.version as u32,
972 ));
973 }
974
975 Ok(super::ReservedState {
976 version,
977 target: super::ConnectionTarget {
978 vp: self.vp,
979 sint: self.sint,
980 },
981 })
982 }
983}
984
/// Saved channel state.
///
/// Client-release transitional states are never saved.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
enum ChannelState {
    #[mesh(1)]
    Closed,
    /// An open request was pending.
    #[mesh(2)]
    Opening {
        #[mesh(1)]
        request: OpenRequest,
        #[mesh(2)]
        reserved_state: Option<ReservedState>,
    },
    /// The channel was open with these parameters.
    #[mesh(3)]
    Open {
        #[mesh(1)]
        params: OpenRequest,
        #[mesh(2)]
        modify_state: ModifyState,
        #[mesh(3)]
        reserved_state: Option<ReservedState>,
    },
    /// The channel was closing; `params` are the parameters it was open with.
    #[mesh(4)]
    Closing {
        #[mesh(1)]
        params: OpenRequest,
        #[mesh(2)]
        reserved_state: Option<ReservedState>,
    },
    /// The channel was closing (`params`) with a reopen (`request`) queued behind it.
    #[mesh(5)]
    ClosingReopen {
        #[mesh(1)]
        params: OpenRequest,
        #[mesh(2)]
        request: OpenRequest,
    },
    /// Saved for both live `Revoked` and `Reoffered` channels; restores as `Reoffered`.
    #[mesh(6)]
    Revoked,
}
1023
1024impl ChannelState {
1025 fn save(value: &super::ChannelState) -> Option<Self> {
1026 Some(match value {
1027 super::ChannelState::Closed => ChannelState::Closed,
1028 super::ChannelState::Opening {
1029 request,
1030 reserved_state,
1031 } => ChannelState::Opening {
1032 request: OpenRequest::save(request),
1033 reserved_state: reserved_state.as_ref().map(ReservedState::save),
1034 },
1035 super::ChannelState::ClosingReopen { params, request } => ChannelState::ClosingReopen {
1036 params: OpenRequest::save(params),
1037 request: OpenRequest::save(request),
1038 },
1039 super::ChannelState::Open {
1040 params,
1041 modify_state,
1042 reserved_state,
1043 } => ChannelState::Open {
1044 params: OpenRequest::save(params),
1045 modify_state: ModifyState::save(modify_state),
1046 reserved_state: reserved_state.as_ref().map(ReservedState::save),
1047 },
1048 super::ChannelState::Closing {
1049 params,
1050 reserved_state,
1051 } => ChannelState::Closing {
1052 params: OpenRequest::save(params),
1053 reserved_state: reserved_state.as_ref().map(ReservedState::save),
1054 },
1055
1056 super::ChannelState::Revoked => ChannelState::Revoked,
1057 super::ChannelState::Reoffered => ChannelState::Revoked,
1058 super::ChannelState::ClientReleased
1059 | super::ChannelState::ClosingClientRelease
1060 | super::ChannelState::OpeningClientRelease => return None,
1061 })
1062 }
1063
1064 fn restore(&self) -> Result<super::ChannelState, RestoreError> {
1065 Ok(match self {
1066 ChannelState::Closed => super::ChannelState::Closed,
1067 ChannelState::Opening {
1068 request,
1069 reserved_state,
1070 } => super::ChannelState::Opening {
1071 request: request.restore(),
1072 reserved_state: reserved_state
1073 .as_ref()
1074 .map(ReservedState::restore)
1075 .transpose()?,
1076 },
1077 ChannelState::ClosingReopen { params, request } => super::ChannelState::ClosingReopen {
1078 params: params.restore(),
1079 request: request.restore(),
1080 },
1081 ChannelState::Open {
1082 params,
1083 modify_state,
1084 reserved_state,
1085 } => super::ChannelState::Open {
1086 params: params.restore(),
1087 modify_state: modify_state.restore(),
1088 reserved_state: reserved_state
1089 .as_ref()
1090 .map(ReservedState::restore)
1091 .transpose()?,
1092 },
1093 ChannelState::Closing {
1094 params,
1095 reserved_state,
1096 } => super::ChannelState::Closing {
1097 params: params.restore(),
1098 reserved_state: reserved_state
1099 .as_ref()
1100 .map(ReservedState::restore)
1101 .transpose()?,
1102 },
1103 ChannelState::Revoked => {
1104 super::ChannelState::Reoffered
1106 }
1107 })
1108 }
1109}
1110
1111impl Display for ChannelState {
1112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1113 let state = match self {
1114 Self::Closed => "Closed",
1115 Self::Opening { .. } => "Opening",
1116 Self::Open { .. } => "Open",
1117 Self::Closing { .. } => "Closing",
1118 Self::ClosingReopen { .. } => "ClosingReopen",
1119 Self::Revoked => "Revoked",
1120 };
1121 write!(f, "{}", state)
1122 }
1123}
1124
/// Saved state of a single GPADL.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct Gpadl {
    /// GPADL ID.
    #[mesh(1)]
    pub id: u32,
    /// ID of the channel that owns the GPADL.
    #[mesh(2)]
    pub channel_id: u32,
    /// Range count passed to `MultiPagedRangeBuf::new` on restore.
    #[mesh(3)]
    pub count: u16,
    /// Raw range buffer; may be incomplete while the GPADL is in progress.
    #[mesh(4)]
    pub buf: Vec<u64>,
    #[mesh(5)]
    state: GpadlState,
}
1139
1140impl Gpadl {
1141 fn save(gpadl_id: GpadlId, channel_id: ChannelId, gpadl: &super::Gpadl) -> Option<Self> {
1142 tracing::trace!(id = %gpadl_id.0, channel_id = %channel_id.0, "gpadl saved");
1143 Some(Gpadl {
1144 id: gpadl_id.0,
1145 channel_id: channel_id.0,
1146 count: gpadl.count,
1147 buf: gpadl.buf.clone(),
1148 state: match gpadl.state {
1149 super::GpadlState::InProgress => GpadlState::InProgress,
1150 super::GpadlState::Offered => GpadlState::Offered,
1151 super::GpadlState::Accepted => GpadlState::Accepted,
1152 super::GpadlState::TearingDown => GpadlState::TearingDown,
1153 super::GpadlState::OfferedTearingDown => return None,
1154 },
1155 })
1156 }
1157
1158 fn restore(self, channel: &super::Channel) -> Result<super::Gpadl, RestoreError> {
1159 let mut buf = self.buf;
1160 if self.state != GpadlState::InProgress {
1161 buf = MultiPagedRangeBuf::new(self.count.into(), buf)?.into_buffer();
1163 }
1164 let (state, allow_revoked) = match self.state {
1165 GpadlState::InProgress => (super::GpadlState::InProgress, true),
1166 GpadlState::Offered => (super::GpadlState::Offered, false),
1167 GpadlState::Accepted => {
1168 (super::GpadlState::Accepted, true)
1170 }
1171 GpadlState::TearingDown => (super::GpadlState::TearingDown, false),
1172 };
1173
1174 if !allow_revoked && channel.state.is_revoked() {
1175 return Err(RestoreError::GpadlForRevokedChannel(
1176 GpadlId(self.id),
1177 ChannelId(self.channel_id),
1178 ));
1179 }
1180
1181 Ok(super::Gpadl {
1182 count: self.count,
1183 buf,
1184 state,
1185 })
1186 }
1187
1188 pub fn is_tearing_down(&self) -> bool {
1189 self.state == GpadlState::TearingDown
1190 }
1191}
1192
/// Saved GPADL state. Live `OfferedTearingDown` GPADLs are not saved.
#[derive(Debug, Clone, Protobuf, PartialEq, Eq)]
#[mesh(package = "vmbus.server.channels")]
pub enum GpadlState {
    /// The GPADL's range buffer was still being received; not validated on restore.
    #[mesh(1)]
    InProgress,
    #[mesh(2)]
    Offered,
    #[mesh(3)]
    Accepted,
    #[mesh(4)]
    TearingDown,
}
1205
/// Saved raw bytes of a pending outgoing message.
#[derive(Debug, Clone, Protobuf, PartialEq, Eq)]
#[mesh(package = "vmbus.server.channels")]
struct OutgoingMessage(Vec<u8>);
1209
1210impl OutgoingMessage {
1211 fn save(value: &vmbus_core::OutgoingMessage) -> Self {
1212 Self(value.data().to_vec())
1213 }
1214
1215 fn restore(self) -> Result<vmbus_core::OutgoingMessage, RestoreError> {
1216 vmbus_core::OutgoingMessage::from_message(&self.0)
1217 .map_err(|_| RestoreError::MessageTooLarge)
1218 }
1219}