vmbus_relay/
saved_state.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::ChannelId;
5use crate::ChannelInfo;
6use crate::InterceptChannelRequest;
7use crate::InterruptRelay;
8use crate::RelayChannelRequest;
9use crate::RelayChannelTask;
10use crate::RelayTask;
11use anyhow::Context as _;
12use anyhow::Result;
13use mesh::payload::Protobuf;
14use mesh::rpc::RpcSend;
15use pal_event::Event;
16use std::sync::Arc;
17use std::sync::OnceLock;
18use std::sync::atomic::Ordering;
19use vmbus_channel::bus::ChannelServerRequest;
20use vmbus_channel::bus::OpenResult;
21use vmbus_client as client;
22use vmcore::interrupt::Interrupt;
23use vmcore::notify::Notify;
24use vmcore::save_restore::SavedStateRoot;
25
26impl RelayTask {
27    pub async fn handle_save(&self) -> SavedState {
28        assert!(!self.running);
29
30        let channels = futures::future::join_all(
31            self.channels
32                .iter()
33                .map(|(id, channel)| self.save_channel_state(*id, channel)),
34        )
35        .await
36        .drain(..)
37        .flatten()
38        .collect();
39
40        SavedState {
41            use_interrupt_relay: self.use_interrupt_relay.load(Ordering::SeqCst),
42            channels,
43        }
44    }
45
46    pub async fn handle_restore(&mut self, state: SavedState) -> Result<()> {
47        let SavedState {
48            use_interrupt_relay,
49            channels,
50        } = state;
51
52        self.use_interrupt_relay
53            .store(use_interrupt_relay, Ordering::SeqCst);
54
55        for saved_channel in channels {
56            let Some(channel) = self.channels.get_mut(&ChannelId(saved_channel.channel_id)) else {
57                tracing::info!(
58                    channel_id = saved_channel.channel_id,
59                    "channel not found during restore, probably revoked"
60                );
61                continue;
62            };
63            match channel {
64                ChannelInfo::Relay(info) => {
65                    info.relay_request_send
66                        .call_failable(RelayChannelRequest::Restore, saved_channel)
67                        .await?;
68                }
69                ChannelInfo::Intercept(id) => {
70                    if saved_channel.is_open {
71                        anyhow::bail!("cannot restore intercepted channel {id}");
72                    }
73                }
74            }
75        }
76
77        Ok(())
78    }
79
80    async fn save_channel_state(
81        &self,
82        channel_id: ChannelId,
83        channel: &ChannelInfo,
84    ) -> Option<Channel> {
85        match channel {
86            ChannelInfo::Relay(relay) => {
87                match relay
88                    .relay_request_send
89                    .call(RelayChannelRequest::Save, ())
90                    .await
91                {
92                    Ok(result) => Some(result),
93                    Err(err) => {
94                        tracing::error!(
95                            err = &err as &dyn std::error::Error,
96                            "Failed to save relay channel state"
97                        );
98                        None
99                    }
100                }
101            }
102            ChannelInfo::Intercept(id) => {
103                let intercepted_save_state = if let Some(intercepted_channel) =
104                    self.intercept_channels.get(id)
105                {
106                    let result = intercepted_channel
107                        .call(InterceptChannelRequest::Save, ())
108                        .await;
109                    match result {
110                        Ok(save_state) => mesh_protobuf::encode(save_state),
111                        Err(err) => {
112                            tracing::error!(err = &err as &dyn std::error::Error, %id, "Failed to call device to save state");
113                            Vec::new()
114                        }
115                    }
116                } else {
117                    tracing::error!(%id, "Intercepted device missing during save operation");
118                    Vec::new()
119                };
120                Some(Channel {
121                    channel_id: channel_id.0,
122                    event_flag: None,
123                    intercepted: true,
124                    intercepted_save_state,
125                    is_open: false,
126                })
127            }
128        }
129    }
130}
131
132impl RelayChannelTask {
133    /// Handle creating channel save state.
134    pub(crate) fn handle_save(&self) -> Channel {
135        Channel {
136            channel_id: self.channel.channel_id.0,
137            event_flag: self
138                .channel
139                .interrupt_relay
140                .as_ref()
141                .map(|interrupt| interrupt.event_flag),
142            intercepted: false,
143            intercepted_save_state: Vec::new(),
144            is_open: self.channel.is_open,
145        }
146    }
147
148    pub(crate) async fn handle_restore(&mut self, state: Channel) -> Result<()> {
149        let Channel {
150            channel_id: _,
151            event_flag,
152            intercepted,
153            intercepted_save_state: _,
154            is_open,
155        } = state;
156
157        if intercepted {
158            anyhow::bail!("cannot restore an intercepted channel");
159        }
160
161        // FUTURE: restore vmbus_client before vmbus_server to avoid this
162        // indirection. This requires vmbus_client saving/restoring the
163        // connection ID itself.
164        let restored_interrupt = Arc::new(OnceLock::<Interrupt>::new());
165        let guest_to_host_interrupt = Interrupt::from_fn({
166            let x = restored_interrupt.clone();
167            move || {
168                if let Some(x) = x.get() {
169                    x.deliver();
170                }
171            }
172        });
173
174        let open_result = is_open.then(|| OpenResult {
175            guest_to_host_interrupt,
176        });
177        let result = self
178            .channel
179            .server_request_send
180            .call(ChannelServerRequest::Restore, open_result)
181            .await
182            .context("Failed to send restore request")?
183            .map_err(|err| {
184                anyhow::Error::from(err).context("failed to restore vmbus relay channel")
185            })?;
186
187        if let Some(request) = result.open_request {
188            let use_interrupt_relay = self.channel.use_interrupt_relay.load(Ordering::SeqCst);
189            if use_interrupt_relay && event_flag.is_none() {
190                anyhow::bail!("using an interrupt relay but no event flag was provided");
191            }
192            let (incoming_event, notify) = if use_interrupt_relay {
193                let event = Event::new();
194                let notify = Notify::from_event(event.clone())
195                    .pollable(self.driver.as_ref())
196                    .context("failed to create polled notify")?;
197                Some((event, notify))
198            } else {
199                None
200            }
201            .unzip();
202
203            let opened = self
204                .channel
205                .request_send
206                .call_failable(
207                    client::ChannelRequest::Restore,
208                    client::RestoreRequest {
209                        connection_id: request.open_data.connection_id,
210                        redirected_event_flag: event_flag,
211                        incoming_event,
212                    },
213                )
214                .await
215                .context("client failed to restore channel")?;
216
217            if let Some(notify) = notify {
218                self.channel.interrupt_relay = Some(InterruptRelay {
219                    event_flag: event_flag.unwrap(),
220                    notify,
221                    interrupt: request.interrupt,
222                });
223            }
224
225            restored_interrupt
226                .set(opened.guest_to_host_signal)
227                .ok()
228                .unwrap();
229        }
230
231        Ok(())
232    }
233}
234
/// Saved state of the vmbus relay.
///
/// This is produced by `RelayTask::handle_save` and consumed by
/// `RelayTask::handle_restore`. Field numbers must stay stable so that
/// previously saved state (including the legacy format) still decodes.
#[derive(Clone, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.relay")]
pub struct SavedState {
    // Whether the interrupt relay was in use when the state was saved.
    #[mesh(1)]
    pub(crate) use_interrupt_relay: bool,
    // Fields 2, 3, and 4 are used by the legacy saved state but are ignored here.
    // Per-channel state. Relayed channels whose save request failed are omitted.
    #[mesh(5)]
    pub(crate) channels: Vec<Channel>,
}
244
/// Saved state of a single channel known to the relay, whether relayed or
/// intercepted.
#[derive(Clone, Protobuf)]
#[mesh(package = "vmbus.relay")]
pub(crate) struct Channel {
    // Numeric value of the channel's `ChannelId`.
    #[mesh(1)]
    pub(crate) channel_id: u32,
    // Event flag of the channel's interrupt relay, if one was active when
    // saved. Always `None` for intercepted channels.
    #[mesh(2)]
    pub(crate) event_flag: Option<u16>,
    // True if the channel was intercepted rather than relayed.
    #[mesh(3)]
    pub(crate) intercepted: bool,
    // `mesh_protobuf`-encoded state returned by the intercepting device. This
    // is empty for relayed channels, and also empty when the intercepting
    // device was missing or failed to save.
    #[mesh(4)]
    pub(crate) intercepted_save_state: Vec<u8>,
    // Whether the relayed channel was open. Always false for intercepted
    // channels.
    #[mesh(5)]
    pub(crate) is_open: bool,
}