vmbus_relay/
saved_state.rs1use crate::ChannelId;
5use crate::ChannelInfo;
6use crate::InterceptChannelRequest;
7use crate::InterruptRelay;
8use crate::RelayChannelRequest;
9use crate::RelayChannelTask;
10use crate::RelayTask;
11use anyhow::Context as _;
12use anyhow::Result;
13use mesh::payload::Protobuf;
14use mesh::rpc::RpcSend;
15use pal_event::Event;
16use std::sync::atomic::Ordering;
17use vmbus_channel::bus::ChannelServerRequest;
18use vmbus_client as client;
19use vmcore::notify::Notify;
20use vmcore::save_restore::SavedStateRoot;
21
22impl RelayTask {
23 pub async fn handle_save(&self) -> SavedState {
24 assert!(!self.running);
25
26 let channels = futures::future::join_all(
27 self.channels
28 .iter()
29 .map(|(id, channel)| self.save_channel_state(*id, channel)),
30 )
31 .await
32 .drain(..)
33 .flatten()
34 .collect();
35
36 SavedState {
37 use_interrupt_relay: self.use_interrupt_relay.load(Ordering::SeqCst),
38 channels,
39 }
40 }
41
42 pub async fn handle_restore(&mut self, state: SavedState) -> Result<()> {
43 let SavedState {
44 use_interrupt_relay,
45 channels,
46 } = state;
47
48 self.use_interrupt_relay
49 .store(use_interrupt_relay, Ordering::SeqCst);
50
51 for saved_channel in channels {
52 let Some(channel) = self.channels.get_mut(&ChannelId(saved_channel.channel_id)) else {
53 tracing::info!(
54 channel_id = saved_channel.channel_id,
55 "channel not found during restore, probably revoked"
56 );
57 continue;
58 };
59 match channel {
60 ChannelInfo::Relay(info) => {
61 info.relay_request_send
62 .call_failable(RelayChannelRequest::Restore, saved_channel)
63 .await?;
64 }
65 ChannelInfo::Intercept(id) => {
66 if saved_channel.is_open {
67 anyhow::bail!("cannot restore intercepted channel {id}");
68 }
69 }
70 }
71 }
72
73 Ok(())
74 }
75
76 async fn save_channel_state(
77 &self,
78 channel_id: ChannelId,
79 channel: &ChannelInfo,
80 ) -> Option<Channel> {
81 match channel {
82 ChannelInfo::Relay(relay) => {
83 match relay
84 .relay_request_send
85 .call(RelayChannelRequest::Save, ())
86 .await
87 {
88 Ok(result) => Some(result),
89 Err(err) => {
90 tracing::error!(
91 err = &err as &dyn std::error::Error,
92 "Failed to save relay channel state"
93 );
94 None
95 }
96 }
97 }
98 ChannelInfo::Intercept(id) => {
99 let intercepted_save_state = if let Some(intercepted_channel) =
100 self.intercept_channels.get(id)
101 {
102 let result = intercepted_channel
103 .call(InterceptChannelRequest::Save, ())
104 .await;
105 match result {
106 Ok(save_state) => mesh_protobuf::encode(save_state),
107 Err(err) => {
108 tracing::error!(err = &err as &dyn std::error::Error, %id, "Failed to call device to save state");
109 Vec::new()
110 }
111 }
112 } else {
113 tracing::error!(%id, "Intercepted device missing during save operation");
114 Vec::new()
115 };
116 Some(Channel {
117 channel_id: channel_id.0,
118 event_flag: None,
119 intercepted: true,
120 intercepted_save_state,
121 is_open: false,
122 })
123 }
124 }
125 }
126}
127
128impl RelayChannelTask {
129 pub(crate) fn handle_save(&self) -> Channel {
131 Channel {
132 channel_id: self.channel.channel_id.0,
133 event_flag: self
134 .channel
135 .interrupt_relay
136 .as_ref()
137 .map(|interrupt| interrupt.event_flag),
138 intercepted: false,
139 intercepted_save_state: Vec::new(),
140 is_open: self.channel.is_open,
141 }
142 }
143
144 pub(crate) async fn handle_restore(&mut self, state: Channel) -> Result<()> {
145 let Channel {
146 channel_id: _,
147 event_flag,
148 intercepted,
149 intercepted_save_state: _,
150 is_open,
151 } = state;
152
153 if intercepted {
154 anyhow::bail!("cannot restore an intercepted channel");
155 }
156
157 let result = self
158 .channel
159 .server_request_send
160 .call(ChannelServerRequest::Restore, is_open)
161 .await
162 .context("Failed to send restore request")?
163 .map_err(|err| {
164 anyhow::Error::from(err).context("failed to restore vmbus relay channel")
165 })?;
166
167 if let Some(request) = result.open_request {
168 let use_interrupt_relay = self.channel.use_interrupt_relay.load(Ordering::SeqCst);
169 if use_interrupt_relay && event_flag.is_none() {
170 anyhow::bail!("using an interrupt relay but no event flag was provided");
171 }
172 let (incoming_event, notify) = if use_interrupt_relay {
173 let event = Event::new();
174 let notify = Notify::from_event(event.clone())
175 .pollable(self.driver.as_ref())
176 .context("failed to create polled notify")?;
177 Some((event, notify))
178 } else {
179 None
180 }
181 .unzip();
182
183 self.channel
184 .request_send
185 .call_failable(
186 client::ChannelRequest::Restore,
187 client::RestoreRequest {
188 connection_id: request.open_data.connection_id,
189 redirected_event_flag: event_flag,
190 incoming_event,
191 },
192 )
193 .await
194 .context("client failed to restore channel")?;
195
196 if let Some(notify) = notify {
197 self.channel.interrupt_relay = Some(InterruptRelay {
198 event_flag: event_flag.unwrap(),
199 notify,
200 interrupt: request.interrupt,
201 });
202 }
203 }
204
205 Ok(())
206 }
207}
208
209#[derive(Clone, Protobuf, SavedStateRoot)]
210#[mesh(package = "vmbus.relay")]
211pub struct SavedState {
212 #[mesh(1)]
213 pub(crate) use_interrupt_relay: bool,
214 #[mesh(5)]
216 pub(crate) channels: Vec<Channel>,
217}
218
219#[derive(Clone, Protobuf)]
220#[mesh(package = "vmbus.relay")]
221pub(crate) struct Channel {
222 #[mesh(1)]
223 pub(crate) channel_id: u32,
224 #[mesh(2)]
225 pub(crate) event_flag: Option<u16>,
226 #[mesh(3)]
227 pub(crate) intercepted: bool,
228 #[mesh(4)]
229 pub(crate) intercepted_save_state: Vec<u8>,
230 #[mesh(5)]
231 pub(crate) is_open: bool,
232}