vmbus_client/
saved_state.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::ConnectResult;
5use crate::OfferInfo;
6use crate::RestoreError;
7use crate::SUPPORTED_FEATURE_FLAGS;
8use guid::Guid;
9use mesh::payload::Protobuf;
10use vmbus_channel::bus::OfferKey;
11use vmbus_core::OutgoingMessage;
12use vmbus_core::VersionInfo;
13use vmbus_core::protocol;
14use vmbus_core::protocol::ChannelId;
15use vmbus_core::protocol::FeatureFlags;
16use vmbus_core::protocol::GpadlId;
17
18impl super::ClientTask {
19    pub fn handle_save(&mut self) -> SavedState {
20        assert!(!self.running);
21
22        let mut pending_messages = self
23            .inner
24            .messages
25            .queued
26            .iter()
27            .map(|msg| PendingMessage {
28                data: msg.data().to_vec(),
29            })
30            .collect::<Vec<_>>();
31
32        // It's the responsibility of the caller to ensure the client is in a state where it's
33        // possible to save.
34        SavedState {
35            client_state: match self.state {
36                super::ClientState::Disconnected => ClientState::Disconnected,
37                super::ClientState::Connecting { .. } => {
38                    unreachable!("Cannot save in Connecting state.")
39                }
40                super::ClientState::Connected { version, .. } => ClientState::Connected {
41                    version: version.version as u32,
42                    feature_flags: version.feature_flags.into(),
43                },
44                super::ClientState::RequestingOffers { .. } => {
45                    unreachable!("Cannot save in RequestingOffers state.")
46                }
47                super::ClientState::Disconnecting { .. } => {
48                    unreachable!("Cannot save in Disconnecting state.")
49                }
50            },
51            channels: self
52                .channels
53                .0
54                .iter()
55                .filter_map(|(&id, v)| {
56                    let Some(state) = ChannelState::save(&v.state) else {
57                        if let Some(request) = v.pending_request() {
58                            panic!("revoked channel {id} has pending request '{request}' that should be drained", id = id.0);
59                        }
60                        // The channel has been revoked, but the user is not
61                        // done with it. The channel won't be available for use
62                        // when we restore, so don't save it, but do save a
63                        // pending message to the server to release the channel
64                        // ID.
65                        pending_messages.push(PendingMessage {
66                            data: OutgoingMessage::new(&protocol::RelIdReleased { channel_id: id })
67                                .data()
68                                .to_vec(),
69                        });
70                        return None;
71                    };
72                    assert!(
73                        v.modify_response_send.is_none(),
74                        "Cannot save a channel that is being modified."
75                    );
76                    let key = offer_key(&v.offer);
77                    tracing::info!(%key, %v.state, "channel saved");
78                    Some(Channel {
79                        id: id.0,
80                        state,
81                        offer: v.offer.into(),
82                    })
83                })
84                .collect(),
85            gpadls: self
86                .channels
87                .0
88                .iter()
89                .flat_map(|(channel_id, channel)| {
90                    channel.gpadls.iter().map(|(gpadl_id, gpadl_state)| Gpadl {
91                        gpadl_id: gpadl_id.0,
92                        channel_id: channel_id.0,
93                        state: GpadlState::save(gpadl_state),
94                    })
95                })
96                .collect(),
97            pending_messages,
98        }
99    }
100
101    pub fn handle_restore(
102        &mut self,
103        saved_state: SavedState,
104    ) -> Result<Option<ConnectResult>, RestoreError> {
105        assert!(!self.running);
106
107        let SavedState {
108            client_state,
109            channels,
110            gpadls,
111            pending_messages,
112        } = saved_state;
113
114        let (version, feature_flags) = match client_state {
115            ClientState::Disconnected => return Ok(None),
116            ClientState::Connected {
117                version,
118                feature_flags,
119            } => (version, feature_flags),
120        };
121
122        let version = super::SUPPORTED_VERSIONS
123            .iter()
124            .find(|v| version == **v as u32)
125            .copied()
126            .ok_or(RestoreError::UnsupportedVersion(version))?;
127
128        let feature_flags = FeatureFlags::from(feature_flags);
129        if !SUPPORTED_FEATURE_FLAGS.contains(feature_flags) {
130            return Err(RestoreError::UnsupportedFeatureFlags(feature_flags.into()));
131        }
132
133        let version = VersionInfo {
134            version,
135            feature_flags,
136        };
137
138        let (offer_send, offer_recv) = mesh::channel();
139        self.state = super::ClientState::Connected {
140            version,
141            offer_send,
142        };
143
144        let mut restored_channels = Vec::new();
145        for saved_channel in channels {
146            let offer_info = self.restore_channel(saved_channel)?;
147            let key = offer_key(&offer_info.offer);
148            tracing::info!(%key, state = %saved_channel.state, "channel restored");
149            restored_channels.push(offer_info);
150        }
151
152        for gpadl in gpadls {
153            let channel_id = ChannelId(gpadl.channel_id);
154            let gpadl_id = GpadlId(gpadl.gpadl_id);
155            let gpadl_state = gpadl.state.restore();
156            let tearing_down = matches!(gpadl_state, super::GpadlState::TearingDown { .. });
157
158            let channel = self
159                .channels
160                .0
161                .get_mut(&channel_id)
162                .ok_or(RestoreError::GpadlForUnknownChannelId(channel_id.0))?;
163
164            if channel.gpadls.insert(gpadl_id, gpadl_state).is_some() {
165                return Err(RestoreError::DuplicateGpadlId(gpadl_id.0));
166            }
167
168            if tearing_down
169                && self
170                    .inner
171                    .teardown_gpadls
172                    .insert(gpadl_id, channel_id)
173                    .is_some()
174            {
175                unreachable!("gpadl ID validated above");
176            }
177        }
178
179        for message in pending_messages {
180            self.inner.messages.queued.push_back(
181                OutgoingMessage::from_message(&message.data)
182                    .map_err(RestoreError::InvalidPendingMessage)?,
183            );
184        }
185
186        Ok(Some(ConnectResult {
187            version,
188            offers: restored_channels,
189            offer_recv,
190        }))
191    }
192
193    pub fn handle_post_restore(&mut self) {
194        assert!(!self.running);
195
196        // Close restored channels that have not been claimed.
197        for (&channel_id, channel) in &mut self.channels.0 {
198            if let super::ChannelState::Restored = channel.state {
199                tracing::info!(
200                    channel_id = channel_id.0,
201                    "closing unclaimed restored channel"
202                );
203                self.inner
204                    .messages
205                    .send(&protocol::CloseChannel { channel_id });
206                channel.state = super::ChannelState::Offered;
207
208                for (&gpadl_id, gpadl_state) in &mut channel.gpadls {
209                    // FUTURE: wait for GPADL teardown so that everything is in a clean
210                    // state after this.
211                    match gpadl_state {
212                        crate::GpadlState::Offered(_) => unreachable!(),
213                        crate::GpadlState::Created => {
214                            self.inner.teardown_gpadls.insert(gpadl_id, channel_id);
215                            self.inner.messages.send(&protocol::GpadlTeardown {
216                                channel_id,
217                                gpadl_id,
218                            });
219                            *gpadl_state = crate::GpadlState::TearingDown { rpcs: Vec::new() };
220                        }
221                        crate::GpadlState::TearingDown { .. } => {}
222                    }
223                }
224            }
225        }
226    }
227
228    fn restore_channel(&mut self, channel: Channel) -> Result<OfferInfo, RestoreError> {
229        self.create_channel_core(channel.offer.into(), channel.state.restore())
230            .map_err(RestoreError::OfferFailed)
231    }
232}
233
234#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
235#[mesh(package = "vmbus.client")]
236pub struct SavedState {
237    #[mesh(1)]
238    pub client_state: ClientState,
239    #[mesh(2)]
240    pub channels: Vec<Channel>,
241    #[mesh(3)]
242    pub gpadls: Vec<Gpadl>,
243    #[mesh(4)]
244    pub pending_messages: Vec<PendingMessage>,
245}
246
247#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
248#[mesh(package = "vmbus.client")]
249pub struct PendingMessage {
250    #[mesh(1)]
251    pub data: Vec<u8>,
252}
253
254#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
255#[mesh(package = "vmbus.client")]
256pub enum ClientState {
257    #[mesh(1)]
258    Disconnected,
259    #[mesh(2)]
260    Connected {
261        #[mesh(1)]
262        version: u32,
263        #[mesh(2)]
264        feature_flags: u32,
265    },
266}
267
268#[derive(Debug, Copy, Clone, PartialEq, Eq, Protobuf)]
269#[mesh(package = "vmbus.client")]
270pub struct Channel {
271    #[mesh(1)]
272    pub id: u32,
273    #[mesh(2)]
274    pub state: ChannelState,
275    #[mesh(3)]
276    pub offer: Offer,
277}
278
279#[derive(Debug, Copy, Clone, PartialEq, Eq, Protobuf)]
280#[mesh(package = "vmbus.client")]
281pub enum ChannelState {
282    #[mesh(1)]
283    Offered,
284    #[mesh(2)]
285    Opened,
286}
287
288impl ChannelState {
289    fn save(state: &super::ChannelState) -> Option<Self> {
290        let s = match state {
291            super::ChannelState::Offered => Self::Offered,
292            super::ChannelState::Opening { .. } => {
293                unreachable!("Cannot save channel in opening state.")
294            }
295            super::ChannelState::Restored | super::ChannelState::Opened { .. } => Self::Opened,
296            super::ChannelState::Revoked => return None,
297        };
298        Some(s)
299    }
300
301    fn restore(self) -> super::ChannelState {
302        match self {
303            ChannelState::Offered => super::ChannelState::Offered,
304            ChannelState::Opened => super::ChannelState::Restored,
305        }
306    }
307}
308
309impl std::fmt::Display for ChannelState {
310    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        match self {
312            ChannelState::Offered => write!(fmt, "Offered"),
313            ChannelState::Opened => write!(fmt, "Opened"),
314        }
315    }
316}
317
318#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
319#[mesh(package = "vmbus.client")]
320pub enum GpadlState {
321    #[mesh(1)]
322    Created,
323    #[mesh(2)]
324    TearingDown,
325}
326
327impl GpadlState {
328    fn save(value: &super::GpadlState) -> Self {
329        match value {
330            super::GpadlState::Offered(..) => unreachable!("Cannot save gpadl in offered state."),
331            super::GpadlState::Created => Self::Created,
332            super::GpadlState::TearingDown { .. } => Self::TearingDown,
333        }
334    }
335
336    fn restore(self) -> super::GpadlState {
337        match self {
338            GpadlState::Created => super::GpadlState::Created,
339            GpadlState::TearingDown => super::GpadlState::TearingDown { rpcs: Vec::new() },
340        }
341    }
342}
343
344#[derive(Clone, Debug, PartialEq, Eq, Protobuf)]
345#[mesh(package = "vmbus.client")]
346pub struct Gpadl {
347    #[mesh(1)]
348    pub gpadl_id: u32,
349    #[mesh(2)]
350    pub channel_id: u32,
351    #[mesh(3)]
352    pub state: GpadlState,
353}
354
355#[derive(Copy, Clone, Debug, PartialEq, Eq, Protobuf)]
356#[mesh(package = "vmbus.client")]
357pub struct Offer {
358    #[mesh(1)]
359    pub interface_id: Guid,
360    #[mesh(2)]
361    pub instance_id: Guid,
362    #[mesh(3)]
363    pub flags: u16,
364    #[mesh(4)]
365    pub mmio_megabytes: u16,
366    #[mesh(5)]
367    pub user_defined: [u8; 120],
368    #[mesh(6)]
369    pub subchannel_index: u16,
370    #[mesh(7)]
371    pub mmio_megabytes_optional: u16,
372    #[mesh(8)]
373    pub channel_id: u32,
374    #[mesh(9)]
375    pub monitor_id: u8,
376    #[mesh(10)]
377    pub monitor_allocated: u8,
378    #[mesh(11)]
379    pub is_dedicated: u16,
380    #[mesh(12)]
381    pub connection_id: u32,
382}
383
384impl From<protocol::OfferChannel> for Offer {
385    fn from(offer: protocol::OfferChannel) -> Self {
386        Self {
387            interface_id: offer.interface_id,
388            instance_id: offer.instance_id,
389            flags: offer.flags.into(),
390            mmio_megabytes: offer.mmio_megabytes,
391            user_defined: offer.user_defined.into(),
392            subchannel_index: offer.subchannel_index,
393            mmio_megabytes_optional: offer.mmio_megabytes_optional,
394            channel_id: offer.channel_id.0,
395            monitor_id: offer.monitor_id,
396            monitor_allocated: offer.monitor_allocated,
397            is_dedicated: offer.is_dedicated,
398            connection_id: offer.connection_id,
399        }
400    }
401}
402
403impl From<Offer> for protocol::OfferChannel {
404    fn from(offer: Offer) -> Self {
405        Self {
406            interface_id: offer.interface_id,
407            instance_id: offer.instance_id,
408            flags: offer.flags.into(),
409            rsvd: [0; 4],
410            mmio_megabytes: offer.mmio_megabytes,
411            user_defined: offer.user_defined.into(),
412            subchannel_index: offer.subchannel_index,
413            mmio_megabytes_optional: offer.mmio_megabytes_optional,
414            channel_id: ChannelId(offer.channel_id),
415            monitor_id: offer.monitor_id,
416            monitor_allocated: offer.monitor_allocated,
417            is_dedicated: offer.is_dedicated,
418            connection_id: offer.connection_id,
419        }
420    }
421}
422
423fn offer_key(offer: &protocol::OfferChannel) -> OfferKey {
424    OfferKey {
425        interface_id: offer.interface_id,
426        instance_id: offer.instance_id,
427        subchannel_index: offer.subchannel_index,
428    }
429}