1use super::MnfUsage;
5use super::Notifier;
6use super::OfferError;
7use super::OfferParamsInternal;
8use super::OfferedInfo;
9use super::RestoreState;
10use super::SUPPORTED_FEATURE_FLAGS;
11use guid::Guid;
12pub use inner::SavedState;
13use mesh::payload::Protobuf;
14use std::fmt::Display;
15use thiserror::Error;
16use vmbus_channel::bus::OfferKey;
17use vmbus_core::protocol;
18use vmbus_core::protocol::ChannelId;
19use vmbus_core::protocol::FeatureFlags;
20use vmbus_core::protocol::GpadlId;
21use vmbus_core::protocol::Version;
22use vmbus_ring::gparange;
23use vmcore::monitor::MonitorId;
24
25impl super::Server {
26 fn restore_one_channel(&mut self, saved_channel: Channel) -> Result<(), RestoreError> {
27 let (info, stub_offer, state) = saved_channel.restore()?;
28 if let Some((offer_id, channel)) = self.channels.get_by_key_mut(&saved_channel.key) {
29 if !matches!(channel.state, super::ChannelState::ClientReleased)
32 || channel.restore_state != RestoreState::New
33 {
34 return Err(RestoreError::AlreadyRestored(saved_channel.key));
35 }
36
37 if let MnfUsage::Relayed { monitor_id } = channel.offer.use_mnf {
40 if info.monitor_id != Some(MonitorId(monitor_id)) {
41 return Err(RestoreError::MismatchedMonitorId(
42 monitor_id,
43 saved_channel.monitor_id,
44 ));
45 }
46 }
47
48 self.assigned_channels
49 .set(info.channel_id)?
50 .insert(offer_id);
51
52 channel.state = state;
53 channel.restore_state = RestoreState::Restoring;
54 channel.info = Some(info);
55 } else {
56 let entry = self
59 .assigned_channels
60 .set(ChannelId(saved_channel.channel_id))?;
61
62 let channel = super::Channel {
63 info: Some(info),
64 offer: stub_offer,
65 state,
66 restore_state: RestoreState::Unmatched,
67 };
68
69 let offer_id = self.channels.offer(channel);
70 entry.insert(offer_id);
71 }
72 Ok(())
73 }
74
75 fn restore_one_gpadl(&mut self, saved_gpadl: Gpadl) -> Result<(), RestoreError> {
76 let gpadl_id = GpadlId(saved_gpadl.id);
77 let channel_id = ChannelId(saved_gpadl.channel_id);
78 let (offer_id, channel) = self
79 .channels
80 .get_by_channel_id(&self.assigned_channels, channel_id)
81 .map_err(|_| RestoreError::MissingGpadlChannel(gpadl_id, channel_id))?;
82
83 if channel.restore_state == RestoreState::New || channel.state.is_released() {
84 return Err(RestoreError::MissingGpadlChannel(gpadl_id, channel_id));
85 }
86
87 let gpadl = saved_gpadl.restore(channel)?;
88 let state = gpadl.state;
89 if self.gpadls.insert((gpadl_id, offer_id), gpadl).is_some() {
90 return Err(RestoreError::GpadlIdInUse(gpadl_id, channel_id));
91 }
92
93 if state == super::GpadlState::InProgress
94 && self.incomplete_gpadls.insert(gpadl_id, offer_id).is_some()
95 {
96 unreachable!("gpadl ID validated above");
97 }
98
99 Ok(())
100 }
101
102 pub fn save(&self) -> SavedState {
104 SavedStateData {
105 state: if let Some(state) = self.save_connected_state() {
106 SavedConnectionState::Connected(state)
107 } else {
108 SavedConnectionState::Disconnected(self.save_disconnected_state())
109 },
110 pending_messages: self.save_pending_messages(),
111 }
112 .into()
113 }
114
115 fn save_connected_state(&self) -> Option<ConnectedState> {
116 let connection = Connection::save(&self.state)?;
117 let channels = self
118 .channels
119 .iter()
120 .filter_map(|(_, channel)| Channel::save(channel))
121 .collect();
122
123 let gpadls = self.save_gpadls();
124 Some(ConnectedState {
125 connection,
126 channels,
127 gpadls,
128 })
129 }
130
131 fn save_gpadls(&self) -> Vec<Gpadl> {
132 self.gpadls
133 .iter()
134 .filter_map(|((gpadl_id, offer_id), gpadl)| {
135 Gpadl::save(*gpadl_id, self.channels[*offer_id].info?.channel_id, gpadl)
136 })
137 .collect()
138 }
139
140 fn save_disconnected_state(&self) -> DisconnectedState {
141 let channels = self
143 .channels
144 .iter()
145 .filter_map(|(_, channel)| {
146 channel
147 .state
148 .is_reserved()
149 .then(|| Channel::save(channel))
150 .flatten()
151 })
152 .collect();
153
154 let gpadls = self.save_gpadls();
157 DisconnectedState {
158 reserved_channels: channels,
159 reserved_gpadls: gpadls,
160 }
161 }
162
163 fn save_pending_messages(&self) -> Vec<OutgoingMessage> {
164 self.pending_messages
165 .0
166 .iter()
167 .map(OutgoingMessage::save)
168 .collect()
169 }
170}
171
172impl<'a, N: 'a + Notifier> super::ServerWithNotifier<'a, N> {
173 pub fn restore(&mut self, saved: SavedState) -> Result<(), RestoreError> {
188 tracing::trace!(?saved, "restoring channel state");
189
190 let saved = SavedStateData::from(saved);
191 match saved.state {
192 SavedConnectionState::Connected(saved) => {
193 self.inner.state = saved.connection.restore()?;
194
195 let request = match self.inner.state {
198 super::ConnectionState::Connecting {
199 info,
200 next_action: _,
201 } => Some(super::ModifyConnectionRequest {
202 version: Some(info.version),
203 interrupt_page: info.interrupt_page.into(),
204 monitor_page: info.monitor_page.map(|mp| mp.gpas).into(),
205 target_message_vp: Some(info.target_message_vp),
206 notify_relay: true,
207 }),
208 super::ConnectionState::Connected(info) => {
209 Some(super::ModifyConnectionRequest {
210 version: None,
211 monitor_page: info.monitor_page.map(|mp| mp.gpas).into(),
212 interrupt_page: info.interrupt_page.into(),
213 target_message_vp: Some(info.target_message_vp),
214 notify_relay: info.modifying,
218 })
219 }
220 super::ConnectionState::Disconnected
223 | super::ConnectionState::Disconnecting { .. } => None,
224 };
225
226 if let Some(request) = request {
227 self.notifier.modify_connection(request)?;
228 }
229
230 for saved_channel in saved.channels {
231 self.inner.restore_one_channel(saved_channel)?;
232 }
233
234 for saved_gpadl in saved.gpadls {
235 self.inner.restore_one_gpadl(saved_gpadl)?;
236 }
237 }
238 SavedConnectionState::Disconnected(saved) => {
239 self.inner.state = super::ConnectionState::Disconnected;
240 for saved_channel in saved.reserved_channels {
241 self.inner.restore_one_channel(saved_channel)?;
242 }
243
244 for saved_gpadl in saved.reserved_gpadls {
245 self.inner.restore_one_gpadl(saved_gpadl)?;
246 }
247 }
248 }
249
250 self.inner
251 .pending_messages
252 .0
253 .reserve(saved.pending_messages.len());
254
255 for message in saved.pending_messages {
256 self.inner.pending_messages.0.push_back(message.restore()?);
257 }
258
259 Ok(())
260 }
261}
262
/// Errors that can occur while restoring the channels saved state.
#[derive(Debug, Error)]
pub enum RestoreError {
    /// An offer bookkeeping error converted from [`OfferError`] (presumably from
    /// assigning a saved channel ID — confirm against `assigned_channels.set`).
    #[error(transparent)]
    Offer(#[from] OfferError),

    /// The channel with this offer key was not in a state that can accept saved
    /// state (already restored, or not freshly offered).
    #[error("channel {0} has already been restored")]
    AlreadyRestored(OfferKey),

    /// A saved GPADL references a channel ID that is unknown, was not restored,
    /// or belongs to a released channel.
    #[error("gpadl {} is for missing channel {}", (.0).0, (.1).0)]
    MissingGpadlChannel(GpadlId, ChannelId),

    /// A saved GPADL, in a state that does not allow it, references a revoked channel.
    #[error("gpadl {} is for revoked channel {}", (.0).0, (.1).0)]
    GpadlForRevokedChannel(GpadlId, ChannelId),

    /// A saved GPADL ID collides with one that was already restored.
    #[error("gpadl {} is already restored", (.0).0)]
    GpadlIdInUse(GpadlId, ChannelId),

    /// The saved connection's protocol version is not in the supported list.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),

    /// A saved GPADL's GPA ranges failed validation.
    #[error("invalid gpadl")]
    InvalidGpadl(#[from] gparange::Error),

    /// The saved connection uses feature flags this server does not support.
    #[error("unsupported feature flags {0:#x}")]
    UnsupportedFeatureFlags(u32),

    /// The channel's open state does not match the saved state.
    #[error("channel {0} has a mismatched open state")]
    MismatchedOpenState(OfferKey),

    /// An expected channel is not present in the saved state.
    #[error("channel {0} is missing from the saved state")]
    MissingChannel(OfferKey),

    /// A reserved channel's saved protocol version is unsupported (or older
    /// than the minimum allowed for reserved channels).
    #[error("unsupported reserved channel protocol version {0:#x}")]
    UnsupportedReserveVersion(u32),

    /// A reserved channel's saved feature flags are unsupported.
    #[error("unsupported reserved channel feature flags {0:#x}")]
    UnsupportedReserveFeatureFlags(u32),

    /// A relayed channel's monitor ID (first field) differs from the saved
    /// monitor ID (second field).
    #[error("mismatched monitor id; expected {0}, actual {1:?}")]
    MismatchedMonitorId(u8, Option<u8>),

    /// The same monitor ID appears on more than one channel in the saved state.
    #[error("monitor ID used by multiple channels in the saved state")]
    DuplicateMonitorId(u8),

    /// A generic server-side error (e.g. presumably from the notifier while
    /// re-sending connection parameters).
    #[error(transparent)]
    ServerError(#[from] anyhow::Error),

    /// A pending message refers to a reserved channel not found in the saved state.
    #[error(
        "reserved channel with ID {0} has a pending message but is missing from the saved state"
    )]
    MissingReservedChannel(u32),

    /// A saved pending message could not be rebuilt as an outgoing message.
    #[error("a saved pending message is larger than the maximum message size")]
    MessageTooLarge,
}
317
318mod inner {
319 use super::*;
320
321 #[derive(Debug, Protobuf, Clone)]
326 #[mesh(package = "vmbus.server.channels")]
327 pub struct SavedState {
328 #[mesh(1)]
329 state: Option<ConnectedState>,
330 #[mesh(2)]
337 disconnected_state: Option<DisconnectedState>,
338 #[mesh(3)]
339 pending_messages: Vec<OutgoingMessage>,
340 }
341
342 impl From<SavedStateData> for SavedState {
343 fn from(value: SavedStateData) -> Self {
344 let (state, disconnected_state) = match value.state {
345 SavedConnectionState::Connected(connected) => (Some(connected), None),
346 SavedConnectionState::Disconnected(disconnected) => (None, Some(disconnected)),
347 };
348
349 Self {
350 state,
351 disconnected_state,
352 pending_messages: value.pending_messages,
353 }
354 }
355 }
356
357 impl From<SavedState> for SavedStateData {
358 fn from(value: SavedState) -> Self {
359 Self {
360 state: if let Some(connected) = value.state {
361 SavedConnectionState::Connected(connected)
362 } else {
363 SavedConnectionState::Disconnected(value.disconnected_state.unwrap_or_default())
366 },
367 pending_messages: value.pending_messages,
368 }
369 }
370 }
371}
372
/// The connection-dependent portion of the channels saved state.
#[derive(Debug, Clone)]
pub enum SavedConnectionState {
    /// Saved while the server was connecting, connected, or disconnecting.
    Connected(ConnectedState),
    /// Saved while the server was disconnected; only reserved channels are kept.
    Disconnected(DisconnectedState),
}
379
/// An inspectable view of the channels saved state, convertible to and from
/// the opaque `SavedState`.
#[derive(Debug, Clone)]
pub struct SavedStateData {
    /// Connection-dependent channel and GPADL state.
    pub state: SavedConnectionState,
    /// Outgoing messages that were queued but not yet sent.
    pub pending_messages: Vec<OutgoingMessage>,
}
387
388impl SavedStateData {
389 pub fn find_channel(&self, offer: OfferKey) -> Option<&Channel> {
391 let (channels, _) = self.channels_and_gpadls();
392 channels.iter().find(|c| c.key == offer)
393 }
394
395 pub fn channels_and_gpadls(&self) -> (&[Channel], &[Gpadl]) {
398 match &self.state {
399 SavedConnectionState::Connected(connected) => (&connected.channels, &connected.gpadls),
400 SavedConnectionState::Disconnected(disconnected) => (
401 &disconnected.reserved_channels,
402 &disconnected.reserved_gpadls,
403 ),
404 }
405 }
406}
407
/// Saved state captured while the server had a (possibly in-flux) connection.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct ConnectedState {
    /// The connection itself.
    #[mesh(1)]
    pub connection: Connection,
    /// All channels that were saveable at save time.
    #[mesh(2)]
    pub channels: Vec<Channel>,
    /// All GPADLs that were saveable at save time.
    #[mesh(3)]
    pub gpadls: Vec<Gpadl>,
}
418
/// Saved state captured while the server was disconnected.
#[derive(Default, Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct DisconnectedState {
    /// Channels whose state was reserved at save time.
    #[mesh(1)]
    pub reserved_channels: Vec<Channel>,
    /// GPADLs saved alongside the reserved channels.
    #[mesh(2)]
    pub reserved_gpadls: Vec<Gpadl>,
}
427
/// Saved form of a negotiated protocol version and feature flags.
#[derive(Debug, PartialEq, Eq, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct VersionInfo {
    /// The protocol version as its raw `u32` value.
    #[mesh(1)]
    pub version: u32,
    /// The feature flags as their raw `u32` value.
    #[mesh(2)]
    pub feature_flags: u32,
}
436
437impl VersionInfo {
438 fn save(value: &super::VersionInfo) -> Self {
439 Self {
440 version: value.version as u32,
441 feature_flags: value.feature_flags.into(),
442 }
443 }
444
445 fn restore(self, trusted: bool) -> Result<vmbus_core::VersionInfo, RestoreError> {
446 let version = super::SUPPORTED_VERSIONS
447 .iter()
448 .find(|v| self.version == **v as u32)
449 .copied()
450 .ok_or(RestoreError::UnsupportedVersion(self.version))?;
451
452 let feature_flags = FeatureFlags::from(self.feature_flags);
453 let supported_flags = SUPPORTED_FEATURE_FLAGS.with_confidential_channels(trusted);
454 if !supported_flags.contains(feature_flags) {
455 return Err(RestoreError::UnsupportedFeatureFlags(feature_flags.into()));
456 }
457
458 Ok(super::VersionInfo {
459 version,
460 feature_flags,
461 })
462 }
463}
464
/// Saved form of a non-disconnected connection state.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum Connection {
    /// The connection was being torn down.
    #[mesh(1)]
    Disconnecting {
        /// What to do once disconnection completes.
        #[mesh(1)]
        next_action: ConnectionAction,
    },
    /// The connection was being established.
    #[mesh(2)]
    Connecting {
        #[mesh(1)]
        version: VersionInfo,
        #[mesh(2)]
        interrupt_page: Option<u64>,
        #[mesh(3)]
        monitor_page: Option<MonitorPageGpas>,
        #[mesh(4)]
        target_message_vp: u32,
        /// What to do once the connection attempt completes.
        #[mesh(5)]
        next_action: ConnectionAction,
        /// Restored as `Guid::ZERO` when absent.
        #[mesh(6)]
        client_id: Option<Guid>,
        /// Also controls whether confidential-channel feature flags are accepted
        /// when restoring `version`.
        #[mesh(7)]
        trusted: bool,
    },
    /// The connection was established.
    #[mesh(3)]
    Connected {
        #[mesh(1)]
        version: VersionInfo,
        #[mesh(2)]
        offers_sent: bool,
        #[mesh(3)]
        interrupt_page: Option<u64>,
        #[mesh(4)]
        monitor_page: Option<MonitorPageGpas>,
        #[mesh(5)]
        target_message_vp: u32,
        /// Whether a connection modification was in progress at save time.
        #[mesh(6)]
        modifying: bool,
        /// Restored as `Guid::ZERO` when absent.
        #[mesh(7)]
        client_id: Option<Guid>,
        #[mesh(8)]
        trusted: bool,
        #[mesh(9)]
        paused: bool,
    },
}
512
513impl Connection {
514 fn save(value: &super::ConnectionState) -> Option<Self> {
515 match value {
516 super::ConnectionState::Disconnected => {
517 None
519 }
520 super::ConnectionState::Connecting { info, next_action } => {
521 Some(Connection::Connecting {
522 version: VersionInfo::save(&info.version),
523 interrupt_page: info.interrupt_page,
524 monitor_page: info.monitor_page.map(MonitorPageGpas::save),
525 target_message_vp: info.target_message_vp,
526 next_action: ConnectionAction::save(next_action),
527 client_id: Some(info.client_id),
528 trusted: info.trusted,
529 })
530 }
531 super::ConnectionState::Connected(info) => Some(Connection::Connected {
532 version: VersionInfo::save(&info.version),
533 offers_sent: info.offers_sent,
534 interrupt_page: info.interrupt_page,
535 monitor_page: info.monitor_page.map(MonitorPageGpas::save),
536 target_message_vp: info.target_message_vp,
537 modifying: info.modifying,
538 client_id: Some(info.client_id),
539 trusted: info.trusted,
540 paused: info.paused,
541 }),
542 super::ConnectionState::Disconnecting {
543 next_action,
544 modify_sent: _,
545 } => Some(Connection::Disconnecting {
546 next_action: ConnectionAction::save(next_action),
547 }),
548 }
549 }
550
551 fn restore(self) -> Result<super::ConnectionState, RestoreError> {
552 Ok(match self {
553 Connection::Connecting {
554 version,
555 interrupt_page,
556 monitor_page,
557 target_message_vp,
558 next_action,
559 client_id,
560 trusted,
561 } => super::ConnectionState::Connecting {
562 info: super::ConnectionInfo {
563 version: version.restore(trusted)?,
564 trusted,
565 interrupt_page,
566 monitor_page: monitor_page.map(MonitorPageGpas::restore),
567 target_message_vp,
568 offers_sent: false,
569 modifying: false,
570 client_id: client_id.unwrap_or(Guid::ZERO),
571 paused: false,
572 },
573 next_action: next_action.restore(),
574 },
575 Connection::Connected {
576 version,
577 offers_sent,
578 interrupt_page,
579 monitor_page,
580 target_message_vp,
581 modifying,
582 client_id,
583 trusted,
584 paused,
585 } => super::ConnectionState::Connected(super::ConnectionInfo {
586 version: version.restore(trusted)?,
587 trusted,
588 offers_sent,
589 interrupt_page,
590 monitor_page: monitor_page.map(MonitorPageGpas::restore),
591 target_message_vp,
592 modifying,
593 client_id: client_id.unwrap_or(Guid::ZERO),
594 paused,
595 }),
596 Connection::Disconnecting { next_action } => super::ConnectionState::Disconnecting {
597 next_action: next_action.restore(),
598 modify_sent: false,
600 },
601 })
602 }
603}
604
/// Saved form of the action to take after a connection state transition completes.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum ConnectionAction {
    /// No follow-up action (a live `Reset` action is also saved as this).
    #[mesh(1)]
    None,
    #[mesh(2)]
    SendUnloadComplete,
    /// Reconnect using the saved initiate-contact request.
    #[mesh(3)]
    Reconnect {
        #[mesh(1)]
        initiate_contact: InitiateContactRequest,
    },
    #[mesh(4)]
    SendFailedVersionResponse,
}
620
621impl ConnectionAction {
622 fn save(value: &super::ConnectionAction) -> Self {
623 match value {
624 super::ConnectionAction::Reset | super::ConnectionAction::None => {
625 Self::None
628 }
629 super::ConnectionAction::SendUnloadComplete => Self::SendUnloadComplete,
630 super::ConnectionAction::Reconnect { initiate_contact } => Self::Reconnect {
631 initiate_contact: InitiateContactRequest::save(initiate_contact),
632 },
633 super::ConnectionAction::SendFailedVersionResponse => Self::SendFailedVersionResponse,
634 }
635 }
636
637 fn restore(self) -> super::ConnectionAction {
638 match self {
639 Self::None => super::ConnectionAction::None,
640 Self::SendUnloadComplete => super::ConnectionAction::SendUnloadComplete,
641 Self::Reconnect { initiate_contact } => super::ConnectionAction::Reconnect {
642 initiate_contact: initiate_contact.restore(),
643 },
644 Self::SendFailedVersionResponse => super::ConnectionAction::SendFailedVersionResponse,
645 }
646 }
647}
648
/// Saved state of a single channel.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct Channel {
    /// Identifies the offer this channel belongs to.
    #[mesh(1)]
    pub key: OfferKey,
    /// The channel ID that was assigned to the offer.
    #[mesh(2)]
    pub channel_id: u32,
    /// The connection ID that was offered for the channel.
    #[mesh(3)]
    pub offered_connection_id: u32,
    #[mesh(4)]
    pub state: ChannelState,
    /// The monitor ID assigned to the channel, if any.
    #[mesh(5)]
    pub monitor_id: Option<u8>,
}
663
664impl Channel {
665 fn save(value: &super::Channel) -> Option<Self> {
666 let info = value.info.as_ref()?;
667 let key = value.offer.key();
668 if let Some(state) = ChannelState::save(&value.state) {
669 tracing::trace!(%key, %state, "channel saved");
670 Some(Channel {
671 channel_id: info.channel_id.0,
672 offered_connection_id: info.connection_id,
673 key,
674 state,
675 monitor_id: info.monitor_id.map(|id| id.0),
676 })
677 } else {
678 tracing::info!(%key, state = %value.state, "skipping channel save");
679 None
680 }
681 }
682
683 fn restore(
684 &self,
685 ) -> Result<(OfferedInfo, OfferParamsInternal, super::ChannelState), RestoreError> {
686 let info = OfferedInfo {
687 channel_id: ChannelId(self.channel_id),
688 connection_id: self.offered_connection_id,
689 monitor_id: self.monitor_id.map(MonitorId),
690 };
691
692 let stub_offer = OfferParamsInternal {
693 instance_id: self.key.instance_id,
694 interface_id: self.key.interface_id,
695 subchannel_index: self.key.subchannel_index,
696 ..Default::default()
697 };
698
699 let state = self.state.restore()?;
700 tracing::info!(key = %self.key, %state, "channel restored");
701 Ok((info, stub_offer, state))
702 }
703
704 pub fn channel_id(&self) -> u32 {
705 self.channel_id
706 }
707
708 pub fn key(&self) -> OfferKey {
709 self.key
710 }
711
712 pub fn open_request(&self) -> Option<OpenRequest> {
713 match self.state {
714 ChannelState::Closed => None,
715 ChannelState::Opening { request, .. } => Some(request),
716 ChannelState::Open { params, .. } => Some(params),
717 ChannelState::Closing { params, .. } => Some(params),
718 ChannelState::ClosingReopen { params, .. } => Some(params),
719 ChannelState::Revoked => None,
720 }
721 }
722}
723
/// Saved form of an `InitiateContact` request, used for a pending reconnect.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct InitiateContactRequest {
    #[mesh(1)]
    pub version_requested: u32,
    #[mesh(2)]
    pub target_message_vp: u32,
    #[mesh(3)]
    pub monitor_page: MonitorPageRequest,
    #[mesh(4)]
    pub target_sint: u8,
    #[mesh(5)]
    pub target_vtl: u8,
    #[mesh(6)]
    pub feature_flags: u32,
    #[mesh(7)]
    pub interrupt_page: Option<u64>,
    #[mesh(8)]
    pub client_id: Guid,
    #[mesh(9)]
    pub trusted: bool,
}
746
747impl InitiateContactRequest {
748 fn save(value: &super::InitiateContactRequest) -> Self {
749 Self {
750 version_requested: value.version_requested,
751 target_message_vp: value.target_message_vp,
752 monitor_page: MonitorPageRequest::save(value.monitor_page),
753 target_sint: value.target_sint,
754 target_vtl: value.target_vtl,
755 feature_flags: value.feature_flags,
756 interrupt_page: value.interrupt_page,
757 client_id: value.client_id,
758 trusted: value.trusted,
759 }
760 }
761
762 fn restore(self) -> super::InitiateContactRequest {
763 super::InitiateContactRequest {
764 version_requested: self.version_requested,
765 target_message_vp: self.target_message_vp,
766 monitor_page: self.monitor_page.restore(),
767 target_sint: self.target_sint,
768 target_vtl: self.target_vtl,
769 feature_flags: self.feature_flags,
770 interrupt_page: self.interrupt_page,
771 client_id: self.client_id,
772 trusted: self.trusted,
773 }
774 }
775}
776
/// Saved guest physical addresses of the two monitor pages.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct MonitorPageGpas {
    #[mesh(1)]
    pub parent_to_child: u64,
    #[mesh(2)]
    pub child_to_parent: u64,
}
785
786impl MonitorPageGpas {
787 fn save(value: super::MonitorPageGpaInfo) -> Self {
788 assert!(
789 !value.server_allocated,
790 "cannot save with server-allocated monitor pages"
791 );
792 Self {
793 child_to_parent: value.gpas.child_to_parent,
794 parent_to_child: value.gpas.parent_to_child,
795 }
796 }
797
798 fn restore(self) -> super::MonitorPageGpaInfo {
799 super::MonitorPageGpaInfo::from_guest_gpas(super::MonitorPageGpas {
800 child_to_parent: self.child_to_parent,
801 parent_to_child: self.parent_to_child,
802 })
803 }
804}
805
/// Saved form of the monitor page portion of an initiate-contact request.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum MonitorPageRequest {
    /// No monitor pages were requested.
    #[mesh(1)]
    None,
    /// Monitor pages at the given GPAs were requested.
    #[mesh(2)]
    Some(#[mesh(1)] MonitorPageGpas),
    /// The request's monitor page specification was invalid.
    #[mesh(3)]
    Invalid,
}
816
817impl MonitorPageRequest {
818 fn save(value: super::MonitorPageRequest) -> Self {
819 match value {
820 super::MonitorPageRequest::None => MonitorPageRequest::None,
821 super::MonitorPageRequest::Some(mp) => MonitorPageRequest::Some(MonitorPageGpas::save(
822 super::MonitorPageGpaInfo::from_guest_gpas(mp),
823 )),
824 super::MonitorPageRequest::Invalid => MonitorPageRequest::Invalid,
825 }
826 }
827
828 fn restore(self) -> super::MonitorPageRequest {
829 match self {
830 MonitorPageRequest::None => super::MonitorPageRequest::None,
831 MonitorPageRequest::Some(mp) => super::MonitorPageRequest::Some(mp.restore().gpas),
832 MonitorPageRequest::Invalid => super::MonitorPageRequest::Invalid,
833 }
834 }
835}
836
/// Saved guest-specified interrupt signaling parameters.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct SignalInfo {
    #[mesh(1)]
    pub event_flag: u16,
    #[mesh(2)]
    pub connection_id: u32,
}
845
846impl SignalInfo {
847 fn save(value: &super::SignalInfo) -> Self {
848 Self {
849 event_flag: value.event_flag,
850 connection_id: value.connection_id,
851 }
852 }
853
854 fn restore(self) -> super::SignalInfo {
855 super::SignalInfo {
856 event_flag: self.event_flag,
857 connection_id: self.connection_id,
858 }
859 }
860}
861
/// Saved form of a channel open request.
///
/// `Debug` is implemented manually so the user data is printed compactly.
#[derive(Copy, PartialEq, Eq, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct OpenRequest {
    #[mesh(1)]
    pub open_id: u32,
    #[mesh(2)]
    pub ring_buffer_gpadl_id: GpadlId,
    /// The target VP, or `protocol::VP_INDEX_DISABLE_INTERRUPT` when the live
    /// request had no target VP.
    #[mesh(3)]
    pub target_vp: u32,
    #[mesh(4)]
    pub downstream_ring_buffer_page_offset: u32,
    #[mesh(5)]
    pub user_data: [u8; 120],
    #[mesh(6)]
    pub guest_specified_interrupt_info: Option<SignalInfo>,
    #[mesh(7)]
    pub flags: u16,
}
880
881impl std::fmt::Debug for OpenRequest {
882 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
883 let Self {
884 open_id,
885 ring_buffer_gpadl_id,
886 target_vp,
887 downstream_ring_buffer_page_offset,
888 user_data,
889 guest_specified_interrupt_info,
890 flags,
891 } = self;
892
893 let user_data_display: &dyn std::fmt::Debug = if self.user_data.iter().all(|&b| b == 0) {
894 &"[<all-zeroes>]"
895 } else {
896 struct HexDisplay<'a>(&'a [u8; 120]);
897 impl std::fmt::Debug for HexDisplay<'_> {
898 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
899 write!(f, "[")?;
900 for byte in self.0 {
901 write!(f, "{:02X}", byte)?;
902 }
903 write!(f, "]")
904 }
905 }
906 &HexDisplay(user_data)
907 };
908
909 f.debug_struct("OpenRequest")
910 .field("open_id", open_id)
911 .field("ring_buffer_gpadl_id", ring_buffer_gpadl_id)
912 .field("target_vp", target_vp)
913 .field(
914 "downstream_ring_buffer_page_offset",
915 downstream_ring_buffer_page_offset,
916 )
917 .field("user_data", user_data_display)
918 .field(
919 "guest_specified_interrupt_info",
920 guest_specified_interrupt_info,
921 )
922 .field("flags", flags)
923 .finish()
924 }
925}
926
927impl OpenRequest {
928 fn save(value: &super::OpenRequest) -> Self {
929 Self {
930 open_id: value.open_id,
931 ring_buffer_gpadl_id: value.ring_buffer_gpadl_id,
932 target_vp: value
933 .target_vp
934 .unwrap_or(protocol::VP_INDEX_DISABLE_INTERRUPT),
935 downstream_ring_buffer_page_offset: value.downstream_ring_buffer_page_offset,
936 user_data: value.user_data.into(),
937 guest_specified_interrupt_info: value
938 .guest_specified_interrupt_info
939 .as_ref()
940 .map(SignalInfo::save),
941 flags: value.flags.into(),
942 }
943 }
944
945 fn restore(self) -> super::OpenRequest {
946 super::OpenRequest {
947 open_id: self.open_id,
948 ring_buffer_gpadl_id: self.ring_buffer_gpadl_id,
949 target_vp: protocol::vp_index_if_enabled(self.target_vp),
950 downstream_ring_buffer_page_offset: self.downstream_ring_buffer_page_offset,
951 user_data: self.user_data.into(),
952 guest_specified_interrupt_info: self
953 .guest_specified_interrupt_info
954 .map(SignalInfo::restore),
955 flags: self.flags.into(),
956 }
957 }
958}
959
/// Saved state of an in-progress channel modification, if any.
#[derive(Debug, Copy, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum ModifyState {
    #[mesh(1)]
    NotModifying,
    /// A modification was in progress.
    #[mesh(2)]
    Modifying {
        #[mesh(1)]
        pending_target_vp: Option<u32>,
    },
}
971
972impl ModifyState {
973 fn save(value: &super::ModifyState) -> Self {
974 match value {
975 super::ModifyState::NotModifying => Self::NotModifying,
976 super::ModifyState::Modifying { pending_target_vp } => Self::Modifying {
977 pending_target_vp: *pending_target_vp,
978 },
979 }
980 }
981
982 fn restore(self) -> super::ModifyState {
983 match self {
984 ModifyState::NotModifying => super::ModifyState::NotModifying,
985 ModifyState::Modifying { pending_target_vp } => {
986 super::ModifyState::Modifying { pending_target_vp }
987 }
988 }
989 }
990}
991
/// Saved state of a reserved channel's connection parameters.
#[derive(Debug, PartialEq, Eq, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct ReservedState {
    #[mesh(1)]
    pub version: VersionInfo,
    /// Target VP of the reserved connection.
    #[mesh(2)]
    pub vp: u32,
    /// Target SINT of the reserved connection.
    #[mesh(3)]
    pub sint: u8,
}
1002
1003impl ReservedState {
1004 fn save(reserved_state: &super::ReservedState) -> Self {
1005 Self {
1006 version: VersionInfo::save(&reserved_state.version),
1007 vp: reserved_state.target.vp,
1008 sint: reserved_state.target.sint,
1009 }
1010 }
1011
1012 fn restore(&self) -> Result<super::ReservedState, RestoreError> {
1013 let version = self.version.clone().restore(true).map_err(|e| match e {
1016 RestoreError::UnsupportedVersion(v) => RestoreError::UnsupportedReserveVersion(v),
1017 RestoreError::UnsupportedFeatureFlags(f) => {
1018 RestoreError::UnsupportedReserveFeatureFlags(f)
1019 }
1020 err => err,
1021 })?;
1022
1023 if version.version < Version::Win10 {
1024 return Err(RestoreError::UnsupportedReserveVersion(
1025 version.version as u32,
1026 ));
1027 }
1028
1029 Ok(super::ReservedState {
1030 version,
1031 target: super::ConnectionTarget {
1032 vp: self.vp,
1033 sint: self.sint,
1034 },
1035 })
1036 }
1037}
1038
/// Saved state of a channel's open/close lifecycle.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub enum ChannelState {
    #[mesh(1)]
    Closed,
    /// An open request was pending.
    #[mesh(2)]
    Opening {
        #[mesh(1)]
        request: OpenRequest,
        #[mesh(2)]
        reserved_state: Option<ReservedState>,
    },
    /// The channel was open with `params`.
    #[mesh(3)]
    Open {
        #[mesh(1)]
        params: OpenRequest,
        #[mesh(2)]
        modify_state: ModifyState,
        #[mesh(3)]
        reserved_state: Option<ReservedState>,
    },
    /// The channel, opened with `params`, was closing.
    #[mesh(4)]
    Closing {
        #[mesh(1)]
        params: OpenRequest,
        #[mesh(2)]
        reserved_state: Option<ReservedState>,
    },
    /// The channel was closing and then to be reopened with `request`.
    #[mesh(5)]
    ClosingReopen {
        #[mesh(1)]
        params: OpenRequest,
        #[mesh(2)]
        request: OpenRequest,
    },
    /// Saved for live `Revoked` and `Reoffered` channels; restored as `Reoffered`.
    #[mesh(6)]
    Revoked,
}
1077
1078impl ChannelState {
1079 fn save(value: &super::ChannelState) -> Option<Self> {
1080 Some(match value {
1081 super::ChannelState::Closed => ChannelState::Closed,
1082 super::ChannelState::Opening {
1083 request,
1084 reserved_state,
1085 } => ChannelState::Opening {
1086 request: OpenRequest::save(request),
1087 reserved_state: reserved_state.as_ref().map(ReservedState::save),
1088 },
1089 super::ChannelState::ClosingReopen { params, request } => ChannelState::ClosingReopen {
1090 params: OpenRequest::save(params),
1091 request: OpenRequest::save(request),
1092 },
1093 super::ChannelState::Open {
1094 params,
1095 modify_state,
1096 reserved_state,
1097 } => ChannelState::Open {
1098 params: OpenRequest::save(params),
1099 modify_state: ModifyState::save(modify_state),
1100 reserved_state: reserved_state.as_ref().map(ReservedState::save),
1101 },
1102 super::ChannelState::Closing {
1103 params,
1104 reserved_state,
1105 } => ChannelState::Closing {
1106 params: OpenRequest::save(params),
1107 reserved_state: reserved_state.as_ref().map(ReservedState::save),
1108 },
1109
1110 super::ChannelState::Revoked => ChannelState::Revoked,
1111 super::ChannelState::Reoffered => ChannelState::Revoked,
1112 super::ChannelState::ClientReleased
1113 | super::ChannelState::ClosingClientRelease
1114 | super::ChannelState::OpeningClientRelease => return None,
1115 })
1116 }
1117
1118 fn restore(&self) -> Result<super::ChannelState, RestoreError> {
1119 Ok(match self {
1120 ChannelState::Closed => super::ChannelState::Closed,
1121 ChannelState::Opening {
1122 request,
1123 reserved_state,
1124 } => super::ChannelState::Opening {
1125 request: request.restore(),
1126 reserved_state: reserved_state
1127 .as_ref()
1128 .map(ReservedState::restore)
1129 .transpose()?,
1130 },
1131 ChannelState::ClosingReopen { params, request } => super::ChannelState::ClosingReopen {
1132 params: params.restore(),
1133 request: request.restore(),
1134 },
1135 ChannelState::Open {
1136 params,
1137 modify_state,
1138 reserved_state,
1139 } => super::ChannelState::Open {
1140 params: params.restore(),
1141 modify_state: modify_state.restore(),
1142 reserved_state: reserved_state
1143 .as_ref()
1144 .map(ReservedState::restore)
1145 .transpose()?,
1146 },
1147 ChannelState::Closing {
1148 params,
1149 reserved_state,
1150 } => super::ChannelState::Closing {
1151 params: params.restore(),
1152 reserved_state: reserved_state
1153 .as_ref()
1154 .map(ReservedState::restore)
1155 .transpose()?,
1156 },
1157 ChannelState::Revoked => {
1158 super::ChannelState::Reoffered
1160 }
1161 })
1162 }
1163}
1164
1165impl Display for ChannelState {
1166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1167 let state = match self {
1168 Self::Closed => "Closed",
1169 Self::Opening { .. } => "Opening",
1170 Self::Open { .. } => "Open",
1171 Self::Closing { .. } => "Closing",
1172 Self::ClosingReopen { .. } => "ClosingReopen",
1173 Self::Revoked => "Revoked",
1174 };
1175 write!(f, "{}", state)
1176 }
1177}
1178
/// Saved state of a single GPADL.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "vmbus.server.channels")]
pub struct Gpadl {
    /// The GPADL ID.
    #[mesh(1)]
    pub id: u32,
    /// The ID of the channel that owns the GPADL.
    #[mesh(2)]
    pub channel_id: u32,
    /// Range count, validated together with `buf` by `gparange::validate_gpa_ranges`.
    #[mesh(3)]
    pub count: u16,
    /// Raw GPA range data.
    #[mesh(4)]
    pub buf: Vec<u64>,
    #[mesh(5)]
    pub state: GpadlState,
}
1193
1194impl Gpadl {
1195 fn save(gpadl_id: GpadlId, channel_id: ChannelId, gpadl: &super::Gpadl) -> Option<Self> {
1196 tracing::trace!(id = %gpadl_id.0, channel_id = %channel_id.0, "gpadl saved");
1197 Some(Gpadl {
1198 id: gpadl_id.0,
1199 channel_id: channel_id.0,
1200 count: gpadl.count,
1201 buf: gpadl.buf.clone(),
1202 state: match gpadl.state {
1203 super::GpadlState::InProgress => GpadlState::InProgress,
1204 super::GpadlState::Offered => GpadlState::Offered,
1205 super::GpadlState::Accepted => GpadlState::Accepted,
1206 super::GpadlState::TearingDown => GpadlState::TearingDown,
1207 super::GpadlState::OfferedTearingDown => return None,
1208 },
1209 })
1210 }
1211
1212 fn restore(self, channel: &super::Channel) -> Result<super::Gpadl, RestoreError> {
1213 if self.state != GpadlState::InProgress {
1214 gparange::validate_gpa_ranges(self.count.into(), &self.buf)?;
1216 }
1217 let (state, allow_revoked) = match self.state {
1218 GpadlState::InProgress => (super::GpadlState::InProgress, true),
1219 GpadlState::Offered => (super::GpadlState::Offered, false),
1220 GpadlState::Accepted => {
1221 (super::GpadlState::Accepted, true)
1223 }
1224 GpadlState::TearingDown => (super::GpadlState::TearingDown, false),
1225 };
1226
1227 if !allow_revoked && channel.state.is_revoked() {
1228 return Err(RestoreError::GpadlForRevokedChannel(
1229 GpadlId(self.id),
1230 ChannelId(self.channel_id),
1231 ));
1232 }
1233
1234 Ok(super::Gpadl {
1235 count: self.count,
1236 buf: self.buf,
1237 state,
1238 })
1239 }
1240
1241 pub fn is_tearing_down(&self) -> bool {
1242 self.state == GpadlState::TearingDown
1243 }
1244}
1245
/// Saved lifecycle state of a GPADL.
///
/// The live `OfferedTearingDown` state has no saved counterpart; such GPADLs
/// are not saved.
#[derive(Debug, Clone, Protobuf, PartialEq, Eq)]
#[mesh(package = "vmbus.server.channels")]
pub enum GpadlState {
    #[mesh(1)]
    InProgress,
    #[mesh(2)]
    Offered,
    #[mesh(3)]
    Accepted,
    #[mesh(4)]
    TearingDown,
}
1258
/// Raw bytes of a queued outgoing message that had not yet been sent at save time.
#[derive(Debug, Clone, Protobuf, PartialEq, Eq)]
#[mesh(package = "vmbus.server.channels")]
pub struct OutgoingMessage(pub Vec<u8>);
1262
1263impl OutgoingMessage {
1264 fn save(value: &vmbus_core::OutgoingMessage) -> Self {
1265 Self(value.data().to_vec())
1266 }
1267
1268 fn restore(self) -> Result<vmbus_core::OutgoingMessage, RestoreError> {
1269 vmbus_core::OutgoingMessage::from_message(&self.0)
1270 .map_err(|_| RestoreError::MessageTooLarge)
1271 }
1272}