vmbus_client/
saved_state.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Saved state support for the vmbus client.
5
6use crate::ConnectResult;
7use crate::OfferInfo;
8use crate::RestoreError;
9use crate::SUPPORTED_FEATURE_FLAGS;
10use guid::Guid;
11use mesh::payload::Protobuf;
12use vmbus_channel::bus::OfferKey;
13use vmbus_core::OutgoingMessage;
14use vmbus_core::VersionInfo;
15use vmbus_core::protocol;
16use vmbus_core::protocol::ChannelId;
17use vmbus_core::protocol::FeatureFlags;
18use vmbus_core::protocol::GpadlId;
19
20impl super::ClientTask {
21    pub fn handle_save(&mut self) -> SavedState {
22        assert!(!self.running);
23
24        let mut pending_messages = self
25            .inner
26            .messages
27            .queued
28            .iter()
29            .map(|msg| PendingMessage {
30                data: msg.data().to_vec(),
31            })
32            .collect::<Vec<_>>();
33
34        // It's the responsibility of the caller to ensure the client is in a state where it's
35        // possible to save.
36        SavedState {
37            client_state: match self.state {
38                super::ClientState::Disconnected => ClientState::Disconnected,
39                super::ClientState::Connecting { .. } => {
40                    unreachable!("Cannot save in Connecting state.")
41                }
42                super::ClientState::Connected { version, .. } => ClientState::Connected {
43                    version: version.version as u32,
44                    feature_flags: version.feature_flags.into(),
45                },
46                super::ClientState::RequestingOffers { .. } => {
47                    unreachable!("Cannot save in RequestingOffers state.")
48                }
49                super::ClientState::Disconnecting { .. } => {
50                    unreachable!("Cannot save in Disconnecting state.")
51                }
52            },
53            channels: self
54                .channels
55                .0
56                .iter()
57                .filter_map(|(&id, v)| {
58                    let Some(state) = ChannelState::save(&v.state) else {
59                        if let Some(request) = v.pending_request() {
60                            panic!("revoked channel {id} has pending request '{request}' that should be drained", id = id.0);
61                        }
62                        // The channel has been revoked, but the user is not
63                        // done with it. The channel won't be available for use
64                        // when we restore, so don't save it, but do save a
65                        // pending message to the server to release the channel
66                        // ID.
67                        pending_messages.push(PendingMessage {
68                            data: OutgoingMessage::new(&protocol::RelIdReleased { channel_id: id })
69                                .data()
70                                .to_vec(),
71                        });
72                        return None;
73                    };
74                    assert!(
75                        v.modify_response_send.is_none(),
76                        "Cannot save a channel that is being modified."
77                    );
78                    let key = offer_key(&v.offer);
79                    tracing::info!(%key, %v.state, "channel saved");
80                    Some(Channel {
81                        id: id.0,
82                        state,
83                        offer: v.offer.into(),
84                    })
85                })
86                .collect(),
87            gpadls: self
88                .channels
89                .0
90                .iter()
91                .flat_map(|(channel_id, channel)| {
92                    channel.gpadls.iter().map(|(gpadl_id, gpadl_state)| Gpadl {
93                        gpadl_id: gpadl_id.0,
94                        channel_id: channel_id.0,
95                        state: GpadlState::save(gpadl_state),
96                    })
97                })
98                .collect(),
99            pending_messages,
100        }
101    }
102
103    pub fn handle_restore(
104        &mut self,
105        saved_state: SavedState,
106    ) -> Result<Option<ConnectResult>, RestoreError> {
107        assert!(!self.running);
108
109        let SavedState {
110            client_state,
111            channels,
112            gpadls,
113            pending_messages,
114        } = saved_state;
115
116        let (version, feature_flags) = match client_state {
117            ClientState::Disconnected => return Ok(None),
118            ClientState::Connected {
119                version,
120                feature_flags,
121            } => (version, feature_flags),
122        };
123
124        let version = super::SUPPORTED_VERSIONS
125            .iter()
126            .find(|v| version == **v as u32)
127            .copied()
128            .ok_or(RestoreError::UnsupportedVersion(version))?;
129
130        let feature_flags = FeatureFlags::from(feature_flags);
131        if !SUPPORTED_FEATURE_FLAGS.contains(feature_flags) {
132            return Err(RestoreError::UnsupportedFeatureFlags(feature_flags.into()));
133        }
134
135        let version = VersionInfo {
136            version,
137            feature_flags,
138        };
139
140        let (offer_send, offer_recv) = mesh::channel();
141        self.state = super::ClientState::Connected {
142            version,
143            offer_send,
144        };
145
146        let mut restored_channels = Vec::new();
147        for saved_channel in channels {
148            let offer_info = self.restore_channel(saved_channel)?;
149            let key = offer_key(&offer_info.offer);
150            tracing::info!(%key, state = %saved_channel.state, "channel restored");
151            restored_channels.push(offer_info);
152        }
153
154        for gpadl in gpadls {
155            let channel_id = ChannelId(gpadl.channel_id);
156            let gpadl_id = GpadlId(gpadl.gpadl_id);
157            let gpadl_state = gpadl.state.restore();
158            let tearing_down = matches!(gpadl_state, super::GpadlState::TearingDown { .. });
159
160            let channel = self
161                .channels
162                .0
163                .get_mut(&channel_id)
164                .ok_or(RestoreError::GpadlForUnknownChannelId(channel_id.0))?;
165
166            if channel.gpadls.insert(gpadl_id, gpadl_state).is_some() {
167                return Err(RestoreError::DuplicateGpadlId(gpadl_id.0));
168            }
169
170            if tearing_down
171                && self
172                    .inner
173                    .teardown_gpadls
174                    .insert(gpadl_id, channel_id)
175                    .is_some()
176            {
177                unreachable!("gpadl ID validated above");
178            }
179        }
180
181        for message in pending_messages {
182            self.inner.messages.queued.push_back(
183                OutgoingMessage::from_message(&message.data)
184                    .map_err(RestoreError::InvalidPendingMessage)?,
185            );
186        }
187
188        Ok(Some(ConnectResult {
189            version,
190            offers: restored_channels,
191            offer_recv,
192        }))
193    }
194
195    pub fn handle_post_restore(&mut self) {
196        assert!(!self.running);
197
198        // Close restored channels that have not been claimed.
199        for (&channel_id, channel) in &mut self.channels.0 {
200            if let super::ChannelState::Restored = channel.state {
201                tracing::info!(
202                    channel_id = channel_id.0,
203                    "closing unclaimed restored channel"
204                );
205                self.inner
206                    .messages
207                    .send(&protocol::CloseChannel { channel_id });
208                channel.state = super::ChannelState::Offered;
209
210                for (&gpadl_id, gpadl_state) in &mut channel.gpadls {
211                    // FUTURE: wait for GPADL teardown so that everything is in a clean
212                    // state after this.
213                    match gpadl_state {
214                        crate::GpadlState::Offered(_) => unreachable!(),
215                        crate::GpadlState::Created => {
216                            self.inner.teardown_gpadls.insert(gpadl_id, channel_id);
217                            self.inner.messages.send(&protocol::GpadlTeardown {
218                                channel_id,
219                                gpadl_id,
220                            });
221                            *gpadl_state = crate::GpadlState::TearingDown { rpcs: Vec::new() };
222                        }
223                        crate::GpadlState::TearingDown { .. } => {}
224                    }
225                }
226            }
227        }
228    }
229
230    fn restore_channel(&mut self, channel: Channel) -> Result<OfferInfo, RestoreError> {
231        self.create_channel_core(channel.offer.into(), channel.state.restore())
232            .map_err(RestoreError::OfferFailed)
233    }
234}
235
236#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
237#[mesh(package = "vmbus.client")]
238pub struct SavedState {
239    #[mesh(1)]
240    pub client_state: ClientState,
241    #[mesh(2)]
242    pub channels: Vec<Channel>,
243    #[mesh(3)]
244    pub gpadls: Vec<Gpadl>,
245    #[mesh(4)]
246    pub pending_messages: Vec<PendingMessage>,
247}
248
249#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
250#[mesh(package = "vmbus.client")]
251pub struct PendingMessage {
252    #[mesh(1)]
253    pub data: Vec<u8>,
254}
255
256#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
257#[mesh(package = "vmbus.client")]
258pub enum ClientState {
259    #[mesh(1)]
260    Disconnected,
261    #[mesh(2)]
262    Connected {
263        #[mesh(1)]
264        version: u32,
265        #[mesh(2)]
266        feature_flags: u32,
267    },
268}
269
270#[derive(Debug, Copy, Clone, PartialEq, Eq, Protobuf)]
271#[mesh(package = "vmbus.client")]
272pub struct Channel {
273    #[mesh(1)]
274    pub id: u32,
275    #[mesh(2)]
276    pub state: ChannelState,
277    #[mesh(3)]
278    pub offer: Offer,
279}
280
281#[derive(Debug, Copy, Clone, PartialEq, Eq, Protobuf)]
282#[mesh(package = "vmbus.client")]
283pub enum ChannelState {
284    #[mesh(1)]
285    Offered,
286    #[mesh(2)]
287    Opened,
288}
289
290impl ChannelState {
291    fn save(state: &super::ChannelState) -> Option<Self> {
292        let s = match state {
293            super::ChannelState::Offered => Self::Offered,
294            super::ChannelState::Opening { .. } => {
295                unreachable!("Cannot save channel in opening state.")
296            }
297            super::ChannelState::Restored | super::ChannelState::Opened { .. } => Self::Opened,
298            super::ChannelState::Revoked => return None,
299        };
300        Some(s)
301    }
302
303    fn restore(self) -> super::ChannelState {
304        match self {
305            ChannelState::Offered => super::ChannelState::Offered,
306            ChannelState::Opened => super::ChannelState::Restored,
307        }
308    }
309}
310
311impl std::fmt::Display for ChannelState {
312    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        match self {
314            ChannelState::Offered => write!(fmt, "Offered"),
315            ChannelState::Opened => write!(fmt, "Opened"),
316        }
317    }
318}
319
320#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
321#[mesh(package = "vmbus.client")]
322pub enum GpadlState {
323    #[mesh(1)]
324    Created,
325    #[mesh(2)]
326    TearingDown,
327}
328
329impl GpadlState {
330    fn save(value: &super::GpadlState) -> Self {
331        match value {
332            super::GpadlState::Offered(..) => unreachable!("Cannot save gpadl in offered state."),
333            super::GpadlState::Created => Self::Created,
334            super::GpadlState::TearingDown { .. } => Self::TearingDown,
335        }
336    }
337
338    fn restore(self) -> super::GpadlState {
339        match self {
340            GpadlState::Created => super::GpadlState::Created,
341            GpadlState::TearingDown => super::GpadlState::TearingDown { rpcs: Vec::new() },
342        }
343    }
344}
345
346#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
347#[mesh(package = "vmbus.client")]
348pub struct Gpadl {
349    #[mesh(1)]
350    pub gpadl_id: u32,
351    #[mesh(2)]
352    pub channel_id: u32,
353    #[mesh(3)]
354    pub state: GpadlState,
355}
356
357#[derive(Copy, Clone, Debug, PartialEq, Eq, Protobuf)]
358#[mesh(package = "vmbus.client")]
359pub struct Offer {
360    #[mesh(1)]
361    pub interface_id: Guid,
362    #[mesh(2)]
363    pub instance_id: Guid,
364    #[mesh(3)]
365    pub flags: u16,
366    #[mesh(4)]
367    pub mmio_megabytes: u16,
368    #[mesh(5)]
369    pub user_defined: [u8; 120],
370    #[mesh(6)]
371    pub subchannel_index: u16,
372    #[mesh(7)]
373    pub mmio_megabytes_optional: u16,
374    #[mesh(8)]
375    pub channel_id: u32,
376    #[mesh(9)]
377    pub monitor_id: u8,
378    #[mesh(10)]
379    pub monitor_allocated: u8,
380    #[mesh(11)]
381    pub is_dedicated: u16,
382    #[mesh(12)]
383    pub connection_id: u32,
384}
385
386impl From<protocol::OfferChannel> for Offer {
387    fn from(offer: protocol::OfferChannel) -> Self {
388        Self {
389            interface_id: offer.interface_id,
390            instance_id: offer.instance_id,
391            flags: offer.flags.into(),
392            mmio_megabytes: offer.mmio_megabytes,
393            user_defined: offer.user_defined.into(),
394            subchannel_index: offer.subchannel_index,
395            mmio_megabytes_optional: offer.mmio_megabytes_optional,
396            channel_id: offer.channel_id.0,
397            monitor_id: offer.monitor_id,
398            monitor_allocated: offer.monitor_allocated,
399            is_dedicated: offer.is_dedicated,
400            connection_id: offer.connection_id,
401        }
402    }
403}
404
405impl From<Offer> for protocol::OfferChannel {
406    fn from(offer: Offer) -> Self {
407        Self {
408            interface_id: offer.interface_id,
409            instance_id: offer.instance_id,
410            flags: offer.flags.into(),
411            rsvd: [0; 4],
412            mmio_megabytes: offer.mmio_megabytes,
413            user_defined: offer.user_defined.into(),
414            subchannel_index: offer.subchannel_index,
415            mmio_megabytes_optional: offer.mmio_megabytes_optional,
416            channel_id: ChannelId(offer.channel_id),
417            monitor_id: offer.monitor_id,
418            monitor_allocated: offer.monitor_allocated,
419            is_dedicated: offer.is_dedicated,
420            connection_id: offer.connection_id,
421        }
422    }
423}
424
425fn offer_key(offer: &protocol::OfferChannel) -> OfferKey {
426    OfferKey {
427        interface_id: offer.interface_id,
428        instance_id: offer.instance_id,
429        subchannel_index: offer.subchannel_index,
430    }
431}