vmbus_relay/
saved_state.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::ChannelId;
5use crate::ChannelInfo;
6use crate::InterceptChannelRequest;
7use crate::InterruptRelay;
8use crate::RelayChannelRequest;
9use crate::RelayChannelTask;
10use crate::RelayTask;
11use anyhow::Context as _;
12use anyhow::Result;
13use mesh::payload::Protobuf;
14use mesh::rpc::RpcSend;
15use pal_event::Event;
16use std::sync::atomic::Ordering;
17use vmbus_channel::bus::ChannelServerRequest;
18use vmbus_client as client;
19use vmcore::notify::Notify;
20use vmcore::save_restore::SavedStateRoot;
21
22impl RelayTask {
23    pub async fn handle_save(&self) -> SavedState {
24        assert!(!self.running);
25
26        let channels = futures::future::join_all(
27            self.channels
28                .iter()
29                .map(|(id, channel)| self.save_channel_state(*id, channel)),
30        )
31        .await
32        .drain(..)
33        .flatten()
34        .collect();
35
36        SavedState {
37            use_interrupt_relay: self.use_interrupt_relay.load(Ordering::SeqCst),
38            channels,
39        }
40    }
41
42    pub async fn handle_restore(&mut self, state: SavedState) -> Result<()> {
43        let SavedState {
44            use_interrupt_relay,
45            channels,
46        } = state;
47
48        self.use_interrupt_relay
49            .store(use_interrupt_relay, Ordering::SeqCst);
50
51        for saved_channel in channels {
52            let Some(channel) = self.channels.get_mut(&ChannelId(saved_channel.channel_id)) else {
53                tracing::info!(
54                    channel_id = saved_channel.channel_id,
55                    "channel not found during restore, probably revoked"
56                );
57                continue;
58            };
59            match channel {
60                ChannelInfo::Relay(info) => {
61                    info.relay_request_send
62                        .call_failable(RelayChannelRequest::Restore, saved_channel)
63                        .await?;
64                }
65                ChannelInfo::Intercept(id) => {
66                    if saved_channel.is_open {
67                        anyhow::bail!("cannot restore intercepted channel {id}");
68                    }
69                }
70            }
71        }
72
73        Ok(())
74    }
75
76    async fn save_channel_state(
77        &self,
78        channel_id: ChannelId,
79        channel: &ChannelInfo,
80    ) -> Option<Channel> {
81        match channel {
82            ChannelInfo::Relay(relay) => {
83                match relay
84                    .relay_request_send
85                    .call(RelayChannelRequest::Save, ())
86                    .await
87                {
88                    Ok(result) => Some(result),
89                    Err(err) => {
90                        tracing::error!(
91                            err = &err as &dyn std::error::Error,
92                            "Failed to save relay channel state"
93                        );
94                        None
95                    }
96                }
97            }
98            ChannelInfo::Intercept(id) => {
99                let intercepted_save_state = if let Some(intercepted_channel) =
100                    self.intercept_channels.get(id)
101                {
102                    let result = intercepted_channel
103                        .call(InterceptChannelRequest::Save, ())
104                        .await;
105                    match result {
106                        Ok(save_state) => mesh_protobuf::encode(save_state),
107                        Err(err) => {
108                            tracing::error!(err = &err as &dyn std::error::Error, %id, "Failed to call device to save state");
109                            Vec::new()
110                        }
111                    }
112                } else {
113                    tracing::error!(%id, "Intercepted device missing during save operation");
114                    Vec::new()
115                };
116                Some(Channel {
117                    channel_id: channel_id.0,
118                    event_flag: None,
119                    intercepted: true,
120                    intercepted_save_state,
121                    is_open: false,
122                })
123            }
124        }
125    }
126}
127
128impl RelayChannelTask {
129    /// Handle creating channel save state.
130    pub(crate) fn handle_save(&self) -> Channel {
131        Channel {
132            channel_id: self.channel.channel_id.0,
133            event_flag: self
134                .channel
135                .interrupt_relay
136                .as_ref()
137                .map(|interrupt| interrupt.event_flag),
138            intercepted: false,
139            intercepted_save_state: Vec::new(),
140            is_open: self.channel.is_open,
141        }
142    }
143
144    pub(crate) async fn handle_restore(&mut self, state: Channel) -> Result<()> {
145        let Channel {
146            channel_id: _,
147            event_flag,
148            intercepted,
149            intercepted_save_state: _,
150            is_open,
151        } = state;
152
153        if intercepted {
154            anyhow::bail!("cannot restore an intercepted channel");
155        }
156
157        let result = self
158            .channel
159            .server_request_send
160            .call(ChannelServerRequest::Restore, is_open)
161            .await
162            .context("Failed to send restore request")?
163            .map_err(|err| {
164                anyhow::Error::from(err).context("failed to restore vmbus relay channel")
165            })?;
166
167        if let Some(request) = result.open_request {
168            let use_interrupt_relay = self.channel.use_interrupt_relay.load(Ordering::SeqCst);
169            if use_interrupt_relay && event_flag.is_none() {
170                anyhow::bail!("using an interrupt relay but no event flag was provided");
171            }
172            let (incoming_event, notify) = if use_interrupt_relay {
173                let event = Event::new();
174                let notify = Notify::from_event(event.clone())
175                    .pollable(self.driver.as_ref())
176                    .context("failed to create polled notify")?;
177                Some((event, notify))
178            } else {
179                None
180            }
181            .unzip();
182
183            self.channel
184                .request_send
185                .call_failable(
186                    client::ChannelRequest::Restore,
187                    client::RestoreRequest {
188                        connection_id: request.open_data.connection_id,
189                        redirected_event_flag: event_flag,
190                        incoming_event,
191                    },
192                )
193                .await
194                .context("client failed to restore channel")?;
195
196            if let Some(notify) = notify {
197                self.channel.interrupt_relay = Some(InterruptRelay {
198                    event_flag: event_flag.unwrap(),
199                    notify,
200                    interrupt: request.interrupt,
201                });
202            }
203        }
204
205        Ok(())
206    }
207}
208
209#[derive(Clone, Protobuf, SavedStateRoot)]
210#[mesh(package = "vmbus.relay")]
211pub struct SavedState {
212    #[mesh(1)]
213    pub(crate) use_interrupt_relay: bool,
214    // Fields 2, 3, and 4 are used by the legacy saved state but are ignored here.
215    #[mesh(5)]
216    pub(crate) channels: Vec<Channel>,
217}
218
219#[derive(Clone, Protobuf)]
220#[mesh(package = "vmbus.relay")]
221pub(crate) struct Channel {
222    #[mesh(1)]
223    pub(crate) channel_id: u32,
224    #[mesh(2)]
225    pub(crate) event_flag: Option<u16>,
226    #[mesh(3)]
227    pub(crate) intercepted: bool,
228    #[mesh(4)]
229    pub(crate) intercepted_save_state: Vec<u8>,
230    #[mesh(5)]
231    pub(crate) is_open: bool,
232}