1use crate::ConnectResult;
5use crate::OfferInfo;
6use crate::RestoreError;
7use crate::SUPPORTED_FEATURE_FLAGS;
8use guid::Guid;
9use mesh::payload::Protobuf;
10use vmbus_channel::bus::OfferKey;
11use vmbus_core::OutgoingMessage;
12use vmbus_core::VersionInfo;
13use vmbus_core::protocol;
14use vmbus_core::protocol::ChannelId;
15use vmbus_core::protocol::FeatureFlags;
16use vmbus_core::protocol::GpadlId;
17
18impl super::ClientTask {
19 pub fn handle_save(&mut self) -> SavedState {
20 assert!(!self.running);
21
22 let mut pending_messages = self
23 .inner
24 .messages
25 .queued
26 .iter()
27 .map(|msg| PendingMessage {
28 data: msg.data().to_vec(),
29 })
30 .collect::<Vec<_>>();
31
32 SavedState {
35 client_state: match self.state {
36 super::ClientState::Disconnected => ClientState::Disconnected,
37 super::ClientState::Connecting { .. } => {
38 unreachable!("Cannot save in Connecting state.")
39 }
40 super::ClientState::Connected { version, .. } => ClientState::Connected {
41 version: version.version as u32,
42 feature_flags: version.feature_flags.into(),
43 },
44 super::ClientState::RequestingOffers { .. } => {
45 unreachable!("Cannot save in RequestingOffers state.")
46 }
47 super::ClientState::Disconnecting { .. } => {
48 unreachable!("Cannot save in Disconnecting state.")
49 }
50 },
51 channels: self
52 .channels
53 .0
54 .iter()
55 .filter_map(|(&id, v)| {
56 let Some(state) = ChannelState::save(&v.state) else {
57 if let Some(request) = v.pending_request() {
58 panic!("revoked channel {id} has pending request '{request}' that should be drained", id = id.0);
59 }
60 pending_messages.push(PendingMessage {
66 data: OutgoingMessage::new(&protocol::RelIdReleased { channel_id: id })
67 .data()
68 .to_vec(),
69 });
70 return None;
71 };
72 assert!(
73 v.modify_response_send.is_none(),
74 "Cannot save a channel that is being modified."
75 );
76 let key = offer_key(&v.offer);
77 tracing::info!(%key, %v.state, "channel saved");
78 Some(Channel {
79 id: id.0,
80 state,
81 offer: v.offer.into(),
82 })
83 })
84 .collect(),
85 gpadls: self
86 .channels
87 .0
88 .iter()
89 .flat_map(|(channel_id, channel)| {
90 channel.gpadls.iter().map(|(gpadl_id, gpadl_state)| Gpadl {
91 gpadl_id: gpadl_id.0,
92 channel_id: channel_id.0,
93 state: GpadlState::save(gpadl_state),
94 })
95 })
96 .collect(),
97 pending_messages,
98 }
99 }
100
101 pub fn handle_restore(
102 &mut self,
103 saved_state: SavedState,
104 ) -> Result<Option<ConnectResult>, RestoreError> {
105 assert!(!self.running);
106
107 let SavedState {
108 client_state,
109 channels,
110 gpadls,
111 pending_messages,
112 } = saved_state;
113
114 let (version, feature_flags) = match client_state {
115 ClientState::Disconnected => return Ok(None),
116 ClientState::Connected {
117 version,
118 feature_flags,
119 } => (version, feature_flags),
120 };
121
122 let version = super::SUPPORTED_VERSIONS
123 .iter()
124 .find(|v| version == **v as u32)
125 .copied()
126 .ok_or(RestoreError::UnsupportedVersion(version))?;
127
128 let feature_flags = FeatureFlags::from(feature_flags);
129 if !SUPPORTED_FEATURE_FLAGS.contains(feature_flags) {
130 return Err(RestoreError::UnsupportedFeatureFlags(feature_flags.into()));
131 }
132
133 let version = VersionInfo {
134 version,
135 feature_flags,
136 };
137
138 let (offer_send, offer_recv) = mesh::channel();
139 self.state = super::ClientState::Connected {
140 version,
141 offer_send,
142 };
143
144 let mut restored_channels = Vec::new();
145 for saved_channel in channels {
146 let offer_info = self.restore_channel(saved_channel)?;
147 let key = offer_key(&offer_info.offer);
148 tracing::info!(%key, state = %saved_channel.state, "channel restored");
149 restored_channels.push(offer_info);
150 }
151
152 for gpadl in gpadls {
153 let channel_id = ChannelId(gpadl.channel_id);
154 let gpadl_id = GpadlId(gpadl.gpadl_id);
155 let gpadl_state = gpadl.state.restore();
156 let tearing_down = matches!(gpadl_state, super::GpadlState::TearingDown { .. });
157
158 let channel = self
159 .channels
160 .0
161 .get_mut(&channel_id)
162 .ok_or(RestoreError::GpadlForUnknownChannelId(channel_id.0))?;
163
164 if channel.gpadls.insert(gpadl_id, gpadl_state).is_some() {
165 return Err(RestoreError::DuplicateGpadlId(gpadl_id.0));
166 }
167
168 if tearing_down
169 && self
170 .inner
171 .teardown_gpadls
172 .insert(gpadl_id, channel_id)
173 .is_some()
174 {
175 unreachable!("gpadl ID validated above");
176 }
177 }
178
179 for message in pending_messages {
180 self.inner.messages.queued.push_back(
181 OutgoingMessage::from_message(&message.data)
182 .map_err(RestoreError::InvalidPendingMessage)?,
183 );
184 }
185
186 Ok(Some(ConnectResult {
187 version,
188 offers: restored_channels,
189 offer_recv,
190 }))
191 }
192
193 pub fn handle_post_restore(&mut self) {
194 assert!(!self.running);
195
196 for (&channel_id, channel) in &mut self.channels.0 {
198 if let super::ChannelState::Restored = channel.state {
199 tracing::info!(
200 channel_id = channel_id.0,
201 "closing unclaimed restored channel"
202 );
203 self.inner
204 .messages
205 .send(&protocol::CloseChannel { channel_id });
206 channel.state = super::ChannelState::Offered;
207
208 for (&gpadl_id, gpadl_state) in &mut channel.gpadls {
209 match gpadl_state {
212 crate::GpadlState::Offered(_) => unreachable!(),
213 crate::GpadlState::Created => {
214 self.inner.teardown_gpadls.insert(gpadl_id, channel_id);
215 self.inner.messages.send(&protocol::GpadlTeardown {
216 channel_id,
217 gpadl_id,
218 });
219 *gpadl_state = crate::GpadlState::TearingDown { rpcs: Vec::new() };
220 }
221 crate::GpadlState::TearingDown { .. } => {}
222 }
223 }
224 }
225 }
226 }
227
228 fn restore_channel(&mut self, channel: Channel) -> Result<OfferInfo, RestoreError> {
229 self.create_channel_core(channel.offer.into(), channel.state.restore())
230 .map_err(RestoreError::OfferFailed)
231 }
232}
233
/// Saved state of the vmbus client.
///
/// Produced by `ClientTask::handle_save` and consumed by
/// `ClientTask::handle_restore`. The `#[mesh(N)]` numbers identify fields in
/// the encoded form. Presumably they must not be renumbered or reused, to keep
/// saved state compatible across versions; confirm before changing.
#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub struct SavedState {
    /// Connection state of the client at save time.
    #[mesh(1)]
    pub client_state: ClientState,
    /// Channels that were not revoked at save time.
    #[mesh(2)]
    pub channels: Vec<Channel>,
    /// GPADLs of all saved channels, each tagged with its owning channel ID.
    #[mesh(3)]
    pub gpadls: Vec<Gpadl>,
    /// Outgoing messages that were queued but not yet sent. Also includes any
    /// channel-ID release messages generated for revoked channels during save.
    #[mesh(4)]
    pub pending_messages: Vec<PendingMessage>,
}
246
/// A queued outgoing message, stored as its raw encoded bytes.
///
/// On restore the bytes are parsed back with `OutgoingMessage::from_message`.
#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub struct PendingMessage {
    /// Raw message bytes, as returned by `OutgoingMessage::data`.
    #[mesh(1)]
    pub data: Vec<u8>,
}
253
/// Saved connection state of the client.
///
/// Only stable states can be saved. Transitional states cause
/// `ClientTask::handle_save` to panic.
#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub enum ClientState {
    /// The client was not connected; restoring this state restores nothing.
    #[mesh(1)]
    Disconnected,
    /// The client was connected with the given negotiated parameters.
    #[mesh(2)]
    Connected {
        /// Negotiated protocol version. On restore it must match one of the
        /// supported versions.
        #[mesh(1)]
        version: u32,
        /// Raw bits of the negotiated `FeatureFlags`.
        #[mesh(2)]
        feature_flags: u32,
    },
}
267
/// A saved channel: its ID, state and the offer it was created from.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub struct Channel {
    /// Raw `ChannelId` value.
    #[mesh(1)]
    pub id: u32,
    /// Saved channel state.
    #[mesh(2)]
    pub state: ChannelState,
    /// The channel's offer, converted from `protocol::OfferChannel`.
    #[mesh(3)]
    pub offer: Offer,
}
278
/// Saved state of a channel.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub enum ChannelState {
    /// The channel was offered but not open.
    #[mesh(1)]
    Offered,
    /// The channel was open, or was restored and not yet reclaimed, at save
    /// time. It is restored into the internal `Restored` state.
    #[mesh(2)]
    Opened,
}
287
288impl ChannelState {
289 fn save(state: &super::ChannelState) -> Option<Self> {
290 let s = match state {
291 super::ChannelState::Offered => Self::Offered,
292 super::ChannelState::Opening { .. } => {
293 unreachable!("Cannot save channel in opening state.")
294 }
295 super::ChannelState::Restored | super::ChannelState::Opened { .. } => Self::Opened,
296 super::ChannelState::Revoked => return None,
297 };
298 Some(s)
299 }
300
301 fn restore(self) -> super::ChannelState {
302 match self {
303 ChannelState::Offered => super::ChannelState::Offered,
304 ChannelState::Opened => super::ChannelState::Restored,
305 }
306 }
307}
308
309impl std::fmt::Display for ChannelState {
310 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311 match self {
312 ChannelState::Offered => write!(fmt, "Offered"),
313 ChannelState::Opened => write!(fmt, "Opened"),
314 }
315 }
316}
317
/// Saved state of a GPADL.
#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub enum GpadlState {
    /// The GPADL was fully created.
    #[mesh(1)]
    Created,
    /// Teardown of the GPADL had been requested. On restore it is tracked in
    /// the client's pending-teardown set again, with no waiting RPCs.
    #[mesh(2)]
    TearingDown,
}
326
327impl GpadlState {
328 fn save(value: &super::GpadlState) -> Self {
329 match value {
330 super::GpadlState::Offered(..) => unreachable!("Cannot save gpadl in offered state."),
331 super::GpadlState::Created => Self::Created,
332 super::GpadlState::TearingDown { .. } => Self::TearingDown,
333 }
334 }
335
336 fn restore(self) -> super::GpadlState {
337 match self {
338 GpadlState::Created => super::GpadlState::Created,
339 GpadlState::TearingDown => super::GpadlState::TearingDown { rpcs: Vec::new() },
340 }
341 }
342}
343
/// A saved GPADL, tagged with the channel that owns it.
#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub struct Gpadl {
    /// Raw `GpadlId` value.
    #[mesh(1)]
    pub gpadl_id: u32,
    /// Raw `ChannelId` of the owning channel. Restore fails if no saved channel
    /// has this ID.
    #[mesh(2)]
    pub channel_id: u32,
    /// Saved GPADL state.
    #[mesh(3)]
    pub state: GpadlState,
}
354
/// Saved copy of a `protocol::OfferChannel`.
///
/// Mirrors every field of the protocol struct except the reserved `rsvd`
/// bytes, which are zeroed when converting back.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Protobuf)]
#[mesh(package = "vmbus.client")]
pub struct Offer {
    #[mesh(1)]
    pub interface_id: Guid,
    #[mesh(2)]
    pub instance_id: Guid,
    /// Raw offer flags bits.
    #[mesh(3)]
    pub flags: u16,
    #[mesh(4)]
    pub mmio_megabytes: u16,
    /// Raw user-defined data bytes.
    #[mesh(5)]
    pub user_defined: [u8; 120],
    #[mesh(6)]
    pub subchannel_index: u16,
    #[mesh(7)]
    pub mmio_megabytes_optional: u16,
    /// Raw `ChannelId` value.
    #[mesh(8)]
    pub channel_id: u32,
    #[mesh(9)]
    pub monitor_id: u8,
    #[mesh(10)]
    pub monitor_allocated: u8,
    #[mesh(11)]
    pub is_dedicated: u16,
    #[mesh(12)]
    pub connection_id: u32,
}
383
384impl From<protocol::OfferChannel> for Offer {
385 fn from(offer: protocol::OfferChannel) -> Self {
386 Self {
387 interface_id: offer.interface_id,
388 instance_id: offer.instance_id,
389 flags: offer.flags.into(),
390 mmio_megabytes: offer.mmio_megabytes,
391 user_defined: offer.user_defined.into(),
392 subchannel_index: offer.subchannel_index,
393 mmio_megabytes_optional: offer.mmio_megabytes_optional,
394 channel_id: offer.channel_id.0,
395 monitor_id: offer.monitor_id,
396 monitor_allocated: offer.monitor_allocated,
397 is_dedicated: offer.is_dedicated,
398 connection_id: offer.connection_id,
399 }
400 }
401}
402
403impl From<Offer> for protocol::OfferChannel {
404 fn from(offer: Offer) -> Self {
405 Self {
406 interface_id: offer.interface_id,
407 instance_id: offer.instance_id,
408 flags: offer.flags.into(),
409 rsvd: [0; 4],
410 mmio_megabytes: offer.mmio_megabytes,
411 user_defined: offer.user_defined.into(),
412 subchannel_index: offer.subchannel_index,
413 mmio_megabytes_optional: offer.mmio_megabytes_optional,
414 channel_id: ChannelId(offer.channel_id),
415 monitor_id: offer.monitor_id,
416 monitor_allocated: offer.monitor_allocated,
417 is_dedicated: offer.is_dedicated,
418 connection_id: offer.connection_id,
419 }
420 }
421}
422
423fn offer_key(offer: &protocol::OfferChannel) -> OfferKey {
424 OfferKey {
425 interface_id: offer.interface_id,
426 instance_id: offer.instance_id,
427 subchannel_index: offer.subchannel_index,
428 }
429}