vmbus_relay/
saved_state.rs1use crate::ChannelId;
5use crate::ChannelInfo;
6use crate::InterceptChannelRequest;
7use crate::InterruptRelay;
8use crate::RelayChannelRequest;
9use crate::RelayChannelTask;
10use crate::RelayTask;
11use anyhow::Context as _;
12use anyhow::Result;
13use mesh::payload::Protobuf;
14use mesh::rpc::RpcSend;
15use pal_event::Event;
16use std::sync::Arc;
17use std::sync::OnceLock;
18use std::sync::atomic::Ordering;
19use vmbus_channel::bus::ChannelServerRequest;
20use vmbus_channel::bus::OpenResult;
21use vmbus_client as client;
22use vmcore::interrupt::Interrupt;
23use vmcore::notify::Notify;
24use vmcore::save_restore::SavedStateRoot;
25
26impl RelayTask {
27 pub async fn handle_save(&self) -> SavedState {
28 assert!(!self.running);
29
30 let channels = futures::future::join_all(
31 self.channels
32 .iter()
33 .map(|(id, channel)| self.save_channel_state(*id, channel)),
34 )
35 .await
36 .drain(..)
37 .flatten()
38 .collect();
39
40 SavedState {
41 use_interrupt_relay: self.use_interrupt_relay.load(Ordering::SeqCst),
42 channels,
43 }
44 }
45
46 pub async fn handle_restore(&mut self, state: SavedState) -> Result<()> {
47 let SavedState {
48 use_interrupt_relay,
49 channels,
50 } = state;
51
52 self.use_interrupt_relay
53 .store(use_interrupt_relay, Ordering::SeqCst);
54
55 for saved_channel in channels {
56 let Some(channel) = self.channels.get_mut(&ChannelId(saved_channel.channel_id)) else {
57 tracing::info!(
58 channel_id = saved_channel.channel_id,
59 "channel not found during restore, probably revoked"
60 );
61 continue;
62 };
63 match channel {
64 ChannelInfo::Relay(info) => {
65 info.relay_request_send
66 .call_failable(RelayChannelRequest::Restore, saved_channel)
67 .await?;
68 }
69 ChannelInfo::Intercept(id) => {
70 if saved_channel.is_open {
71 anyhow::bail!("cannot restore intercepted channel {id}");
72 }
73 }
74 }
75 }
76
77 Ok(())
78 }
79
80 async fn save_channel_state(
81 &self,
82 channel_id: ChannelId,
83 channel: &ChannelInfo,
84 ) -> Option<Channel> {
85 match channel {
86 ChannelInfo::Relay(relay) => {
87 match relay
88 .relay_request_send
89 .call(RelayChannelRequest::Save, ())
90 .await
91 {
92 Ok(result) => Some(result),
93 Err(err) => {
94 tracing::error!(
95 err = &err as &dyn std::error::Error,
96 "Failed to save relay channel state"
97 );
98 None
99 }
100 }
101 }
102 ChannelInfo::Intercept(id) => {
103 let intercepted_save_state = if let Some(intercepted_channel) =
104 self.intercept_channels.get(id)
105 {
106 let result = intercepted_channel
107 .call(InterceptChannelRequest::Save, ())
108 .await;
109 match result {
110 Ok(save_state) => mesh_protobuf::encode(save_state),
111 Err(err) => {
112 tracing::error!(err = &err as &dyn std::error::Error, %id, "Failed to call device to save state");
113 Vec::new()
114 }
115 }
116 } else {
117 tracing::error!(%id, "Intercepted device missing during save operation");
118 Vec::new()
119 };
120 Some(Channel {
121 channel_id: channel_id.0,
122 event_flag: None,
123 intercepted: true,
124 intercepted_save_state,
125 is_open: false,
126 })
127 }
128 }
129 }
130}
131
132impl RelayChannelTask {
133 pub(crate) fn handle_save(&self) -> Channel {
135 Channel {
136 channel_id: self.channel.channel_id.0,
137 event_flag: self
138 .channel
139 .interrupt_relay
140 .as_ref()
141 .map(|interrupt| interrupt.event_flag),
142 intercepted: false,
143 intercepted_save_state: Vec::new(),
144 is_open: self.channel.is_open,
145 }
146 }
147
148 pub(crate) async fn handle_restore(&mut self, state: Channel) -> Result<()> {
149 let Channel {
150 channel_id: _,
151 event_flag,
152 intercepted,
153 intercepted_save_state: _,
154 is_open,
155 } = state;
156
157 if intercepted {
158 anyhow::bail!("cannot restore an intercepted channel");
159 }
160
161 let restored_interrupt = Arc::new(OnceLock::<Interrupt>::new());
165 let guest_to_host_interrupt = Interrupt::from_fn({
166 let x = restored_interrupt.clone();
167 move || {
168 if let Some(x) = x.get() {
169 x.deliver();
170 }
171 }
172 });
173
174 let open_result = is_open.then(|| OpenResult {
175 guest_to_host_interrupt,
176 });
177 let result = self
178 .channel
179 .server_request_send
180 .call(ChannelServerRequest::Restore, open_result)
181 .await
182 .context("Failed to send restore request")?
183 .map_err(|err| {
184 anyhow::Error::from(err).context("failed to restore vmbus relay channel")
185 })?;
186
187 if let Some(request) = result.open_request {
188 let use_interrupt_relay = self.channel.use_interrupt_relay.load(Ordering::SeqCst);
189 if use_interrupt_relay && event_flag.is_none() {
190 anyhow::bail!("using an interrupt relay but no event flag was provided");
191 }
192 let (incoming_event, notify) = if use_interrupt_relay {
193 let event = Event::new();
194 let notify = Notify::from_event(event.clone())
195 .pollable(self.driver.as_ref())
196 .context("failed to create polled notify")?;
197 Some((event, notify))
198 } else {
199 None
200 }
201 .unzip();
202
203 let opened = self
204 .channel
205 .request_send
206 .call_failable(
207 client::ChannelRequest::Restore,
208 client::RestoreRequest {
209 connection_id: request.open_data.connection_id,
210 redirected_event_flag: event_flag,
211 incoming_event,
212 },
213 )
214 .await
215 .context("client failed to restore channel")?;
216
217 if let Some(notify) = notify {
218 self.channel.interrupt_relay = Some(InterruptRelay {
219 event_flag: event_flag.unwrap(),
220 notify,
221 interrupt: request.interrupt,
222 });
223 }
224
225 restored_interrupt
226 .set(opened.guest_to_host_signal)
227 .ok()
228 .unwrap();
229 }
230
231 Ok(())
232 }
233}
234
/// Saved state for the vmbus relay.
///
/// Produced by `RelayTask::handle_save` and consumed by
/// `RelayTask::handle_restore`.
#[derive(Clone, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus.relay")]
pub struct SavedState {
    /// Whether interrupts were being relayed. This is restored before any
    /// channel is restored.
    #[mesh(1)]
    pub(crate) use_interrupt_relay: bool,
    // NOTE(review): field numbers 2-4 are not used by this struct; presumably
    // they belonged to an older format. Confirm before reusing them.
    /// Per-channel saved state. Relay channels whose save request failed are
    /// omitted.
    #[mesh(5)]
    pub(crate) channels: Vec<Channel>,
}
244
/// Saved state for a single channel, either relayed or intercepted.
#[derive(Clone, Protobuf)]
#[mesh(package = "vmbus.relay")]
pub(crate) struct Channel {
    /// Numeric channel ID (the inner value of `ChannelId`).
    #[mesh(1)]
    pub(crate) channel_id: u32,
    /// Event flag of the channel's interrupt relay, if one was active at save
    /// time. Always `None` for intercepted channels.
    #[mesh(2)]
    pub(crate) event_flag: Option<u16>,
    /// `true` when the channel was saved through the intercept path rather
    /// than by a relay channel task.
    #[mesh(3)]
    pub(crate) intercepted: bool,
    /// `mesh_protobuf`-encoded state returned by the intercepting device.
    /// Empty for relay channels, or when the device was missing or failed to
    /// save.
    #[mesh(4)]
    pub(crate) intercepted_save_state: Vec<u8>,
    /// Whether the channel was open. Always `false` for intercepted channels.
    #[mesh(5)]
    pub(crate) is_open: bool,
}