1use crate::ConnectResult;
7use crate::OfferInfo;
8use crate::RestoreError;
9use crate::SUPPORTED_FEATURE_FLAGS;
10use guid::Guid;
11use mesh::payload::Protobuf;
12use vmbus_channel::bus::OfferKey;
13use vmbus_core::OutgoingMessage;
14use vmbus_core::VersionInfo;
15use vmbus_core::protocol;
16use vmbus_core::protocol::ChannelId;
17use vmbus_core::protocol::FeatureFlags;
18use vmbus_core::protocol::GpadlId;
19
20impl super::ClientTask {
21 pub fn handle_save(&mut self) -> SavedState {
22 assert!(!self.running);
23
24 let mut pending_messages = self
25 .inner
26 .messages
27 .queued
28 .iter()
29 .map(|msg| PendingMessage {
30 data: msg.data().to_vec(),
31 })
32 .collect::<Vec<_>>();
33
34 SavedState {
37 client_state: match self.state {
38 super::ClientState::Disconnected => ClientState::Disconnected,
39 super::ClientState::Connecting { .. } => {
40 unreachable!("Cannot save in Connecting state.")
41 }
42 super::ClientState::Connected { version, .. } => ClientState::Connected {
43 version: version.version as u32,
44 feature_flags: version.feature_flags.into(),
45 },
46 super::ClientState::RequestingOffers { .. } => {
47 unreachable!("Cannot save in RequestingOffers state.")
48 }
49 super::ClientState::Disconnecting { .. } => {
50 unreachable!("Cannot save in Disconnecting state.")
51 }
52 },
53 channels: self
54 .channels
55 .0
56 .iter()
57 .filter_map(|(&id, v)| {
58 let Some(state) = ChannelState::save(&v.state) else {
59 if let Some(request) = v.pending_request() {
60 panic!("revoked channel {id} has pending request '{request}' that should be drained", id = id.0);
61 }
62 pending_messages.push(PendingMessage {
68 data: OutgoingMessage::new(&protocol::RelIdReleased { channel_id: id })
69 .data()
70 .to_vec(),
71 });
72 return None;
73 };
74 assert!(
75 v.modify_response_send.is_none(),
76 "Cannot save a channel that is being modified."
77 );
78 let key = offer_key(&v.offer);
79 tracing::info!(%key, %v.state, "channel saved");
80 Some(Channel {
81 id: id.0,
82 state,
83 offer: v.offer.into(),
84 })
85 })
86 .collect(),
87 gpadls: self
88 .channels
89 .0
90 .iter()
91 .flat_map(|(channel_id, channel)| {
92 channel.gpadls.iter().map(|(gpadl_id, gpadl_state)| Gpadl {
93 gpadl_id: gpadl_id.0,
94 channel_id: channel_id.0,
95 state: GpadlState::save(gpadl_state),
96 })
97 })
98 .collect(),
99 pending_messages,
100 }
101 }
102
103 pub fn handle_restore(
104 &mut self,
105 saved_state: SavedState,
106 ) -> Result<Option<ConnectResult>, RestoreError> {
107 assert!(!self.running);
108
109 let SavedState {
110 client_state,
111 channels,
112 gpadls,
113 pending_messages,
114 } = saved_state;
115
116 let (version, feature_flags) = match client_state {
117 ClientState::Disconnected => return Ok(None),
118 ClientState::Connected {
119 version,
120 feature_flags,
121 } => (version, feature_flags),
122 };
123
124 let version = super::SUPPORTED_VERSIONS
125 .iter()
126 .find(|v| version == **v as u32)
127 .copied()
128 .ok_or(RestoreError::UnsupportedVersion(version))?;
129
130 let feature_flags = FeatureFlags::from(feature_flags);
131 if !SUPPORTED_FEATURE_FLAGS.contains(feature_flags) {
132 return Err(RestoreError::UnsupportedFeatureFlags(feature_flags.into()));
133 }
134
135 let version = VersionInfo {
136 version,
137 feature_flags,
138 };
139
140 let (offer_send, offer_recv) = mesh::channel();
141 self.state = super::ClientState::Connected {
142 version,
143 offer_send,
144 };
145
146 let mut restored_channels = Vec::new();
147 for saved_channel in channels {
148 let offer_info = self.restore_channel(saved_channel)?;
149 let key = offer_key(&offer_info.offer);
150 tracing::info!(%key, state = %saved_channel.state, "channel restored");
151 restored_channels.push(offer_info);
152 }
153
154 for gpadl in gpadls {
155 let channel_id = ChannelId(gpadl.channel_id);
156 let gpadl_id = GpadlId(gpadl.gpadl_id);
157 let gpadl_state = gpadl.state.restore();
158 let tearing_down = matches!(gpadl_state, super::GpadlState::TearingDown { .. });
159
160 let channel = self
161 .channels
162 .0
163 .get_mut(&channel_id)
164 .ok_or(RestoreError::GpadlForUnknownChannelId(channel_id.0))?;
165
166 if channel.gpadls.insert(gpadl_id, gpadl_state).is_some() {
167 return Err(RestoreError::DuplicateGpadlId(gpadl_id.0));
168 }
169
170 if tearing_down
171 && self
172 .inner
173 .teardown_gpadls
174 .insert(gpadl_id, channel_id)
175 .is_some()
176 {
177 unreachable!("gpadl ID validated above");
178 }
179 }
180
181 for message in pending_messages {
182 self.inner.messages.queued.push_back(
183 OutgoingMessage::from_message(&message.data)
184 .map_err(RestoreError::InvalidPendingMessage)?,
185 );
186 }
187
188 Ok(Some(ConnectResult {
189 version,
190 offers: restored_channels,
191 offer_recv,
192 }))
193 }
194
195 pub fn handle_post_restore(&mut self) {
196 assert!(!self.running);
197
198 for (&channel_id, channel) in &mut self.channels.0 {
200 if let super::ChannelState::Restored = channel.state {
201 tracing::info!(
202 channel_id = channel_id.0,
203 "closing unclaimed restored channel"
204 );
205 self.inner
206 .messages
207 .send(&protocol::CloseChannel { channel_id });
208 channel.state = super::ChannelState::Offered;
209
210 for (&gpadl_id, gpadl_state) in &mut channel.gpadls {
211 match gpadl_state {
214 crate::GpadlState::Offered(_) => unreachable!(),
215 crate::GpadlState::Created => {
216 self.inner.teardown_gpadls.insert(gpadl_id, channel_id);
217 self.inner.messages.send(&protocol::GpadlTeardown {
218 channel_id,
219 gpadl_id,
220 });
221 *gpadl_state = crate::GpadlState::TearingDown { rpcs: Vec::new() };
222 }
223 crate::GpadlState::TearingDown { .. } => {}
224 }
225 }
226 }
227 }
228 }
229
230 fn restore_channel(&mut self, channel: Channel) -> Result<OfferInfo, RestoreError> {
231 self.create_channel_core(channel.offer.into(), channel.state.restore())
232 .map_err(RestoreError::OfferFailed)
233 }
234}
235
/// Saved state of the VMBus client.
///
/// Produced by `ClientTask::handle_save` and consumed by
/// `ClientTask::handle_restore`.
#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub struct SavedState {
    /// Connection state of the client at save time.
    #[mesh(1)]
    pub client_state: ClientState,
    /// Channels that were not revoked at save time.
    #[mesh(2)]
    pub channels: Vec<Channel>,
    /// GPADLs of all channels, each tagged with the ID of its owning channel.
    #[mesh(3)]
    pub gpadls: Vec<Gpadl>,
    /// Outgoing messages to re-queue on restore.
    ///
    /// Contains the messages queued at save time, followed by `RelIdReleased`
    /// messages for revoked channels.
    #[mesh(4)]
    pub pending_messages: Vec<PendingMessage>,
}
248
/// An outgoing message that had not been sent when state was saved.
#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub struct PendingMessage {
    /// Serialized message bytes, as returned by `OutgoingMessage::data` and
    /// parsed again with `OutgoingMessage::from_message`.
    #[mesh(1)]
    pub data: Vec<u8>,
}
255
/// Saved connection state of the client.
///
/// Only stable states can be saved; transitional states cause
/// `handle_save` to panic.
#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub enum ClientState {
    /// Not connected; restoring this state restores nothing else.
    #[mesh(1)]
    Disconnected,
    /// Connected with the given negotiated parameters.
    #[mesh(2)]
    Connected {
        /// Negotiated protocol version, as its raw `u32` value.
        #[mesh(1)]
        version: u32,
        /// Negotiated feature flags, as their raw `u32` bits.
        #[mesh(2)]
        feature_flags: u32,
    },
}
269
/// Saved state of a single (non-revoked) channel.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub struct Channel {
    /// Channel ID this channel was keyed by in the client's channel map.
    ///
    /// NOTE(review): restore does not read this field; the channel is
    /// re-created from `offer` (which carries its own `channel_id`).
    #[mesh(1)]
    pub id: u32,
    /// Saved channel state.
    #[mesh(2)]
    pub state: ChannelState,
    /// The channel's offer.
    #[mesh(3)]
    pub offer: Offer,
}
280
/// Saved state of a channel.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub enum ChannelState {
    /// The channel was offered but not open.
    #[mesh(1)]
    Offered,
    /// The channel was open (or restored and not yet reclaimed).
    ///
    /// Restores to the live `Restored` state.
    #[mesh(2)]
    Opened,
}
289
290impl ChannelState {
291 fn save(state: &super::ChannelState) -> Option<Self> {
292 let s = match state {
293 super::ChannelState::Offered => Self::Offered,
294 super::ChannelState::Opening { .. } => {
295 unreachable!("Cannot save channel in opening state.")
296 }
297 super::ChannelState::Restored | super::ChannelState::Opened { .. } => Self::Opened,
298 super::ChannelState::Revoked => return None,
299 };
300 Some(s)
301 }
302
303 fn restore(self) -> super::ChannelState {
304 match self {
305 ChannelState::Offered => super::ChannelState::Offered,
306 ChannelState::Opened => super::ChannelState::Restored,
307 }
308 }
309}
310
311impl std::fmt::Display for ChannelState {
312 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 match self {
314 ChannelState::Offered => write!(fmt, "Offered"),
315 ChannelState::Opened => write!(fmt, "Opened"),
316 }
317 }
318}
319
/// Saved state of a GPADL.
#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub enum GpadlState {
    /// The GPADL was created.
    #[mesh(1)]
    Created,
    /// The GPADL was being torn down.
    ///
    /// Pending teardown RPCs are not saved.
    #[mesh(2)]
    TearingDown,
}
328
329impl GpadlState {
330 fn save(value: &super::GpadlState) -> Self {
331 match value {
332 super::GpadlState::Offered(..) => unreachable!("Cannot save gpadl in offered state."),
333 super::GpadlState::Created => Self::Created,
334 super::GpadlState::TearingDown { .. } => Self::TearingDown,
335 }
336 }
337
338 fn restore(self) -> super::GpadlState {
339 match self {
340 GpadlState::Created => super::GpadlState::Created,
341 GpadlState::TearingDown => super::GpadlState::TearingDown { rpcs: Vec::new() },
342 }
343 }
344}
345
/// Saved state of a single GPADL.
#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub struct Gpadl {
    /// Raw GPADL ID.
    #[mesh(1)]
    pub gpadl_id: u32,
    /// Raw ID of the channel that owns this GPADL.
    ///
    /// Restore fails if no saved channel has this ID.
    #[mesh(2)]
    pub channel_id: u32,
    /// Saved GPADL state.
    #[mesh(3)]
    pub state: GpadlState,
}
356
/// Saved copy of a channel offer.
///
/// Fields mirror the same-named fields of `protocol::OfferChannel`, except
/// for its reserved bytes, which are not saved.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub struct Offer {
    #[mesh(1)]
    pub interface_id: Guid,
    #[mesh(2)]
    pub instance_id: Guid,
    /// Offer flags as their raw `u16` representation.
    #[mesh(3)]
    pub flags: u16,
    #[mesh(4)]
    pub mmio_megabytes: u16,
    /// Raw 120-byte user-defined area of the offer.
    #[mesh(5)]
    pub user_defined: [u8; 120],
    #[mesh(6)]
    pub subchannel_index: u16,
    #[mesh(7)]
    pub mmio_megabytes_optional: u16,
    /// Raw channel ID (the inner value of `ChannelId`).
    #[mesh(8)]
    pub channel_id: u32,
    #[mesh(9)]
    pub monitor_id: u8,
    #[mesh(10)]
    pub monitor_allocated: u8,
    #[mesh(11)]
    pub is_dedicated: u16,
    #[mesh(12)]
    pub connection_id: u32,
}
385
386impl From<protocol::OfferChannel> for Offer {
387 fn from(offer: protocol::OfferChannel) -> Self {
388 Self {
389 interface_id: offer.interface_id,
390 instance_id: offer.instance_id,
391 flags: offer.flags.into(),
392 mmio_megabytes: offer.mmio_megabytes,
393 user_defined: offer.user_defined.into(),
394 subchannel_index: offer.subchannel_index,
395 mmio_megabytes_optional: offer.mmio_megabytes_optional,
396 channel_id: offer.channel_id.0,
397 monitor_id: offer.monitor_id,
398 monitor_allocated: offer.monitor_allocated,
399 is_dedicated: offer.is_dedicated,
400 connection_id: offer.connection_id,
401 }
402 }
403}
404
405impl From<Offer> for protocol::OfferChannel {
406 fn from(offer: Offer) -> Self {
407 Self {
408 interface_id: offer.interface_id,
409 instance_id: offer.instance_id,
410 flags: offer.flags.into(),
411 rsvd: [0; 4],
412 mmio_megabytes: offer.mmio_megabytes,
413 user_defined: offer.user_defined.into(),
414 subchannel_index: offer.subchannel_index,
415 mmio_megabytes_optional: offer.mmio_megabytes_optional,
416 channel_id: ChannelId(offer.channel_id),
417 monitor_id: offer.monitor_id,
418 monitor_allocated: offer.monitor_allocated,
419 is_dedicated: offer.is_dedicated,
420 connection_id: offer.connection_id,
421 }
422 }
423}
424
425fn offer_key(offer: &protocol::OfferChannel) -> OfferKey {
426 OfferKey {
427 interface_id: offer.interface_id,
428 instance_id: offer.instance_id,
429 subchannel_index: offer.subchannel_index,
430 }
431}