1use crate::RawAsyncChannel;
15use crate::bus::OfferParams;
16use crate::bus::OpenRequest;
17use crate::bus::ParentBus;
18use crate::channel::ChannelHandle;
19use crate::channel::ChannelOpenError;
20use crate::channel::DeviceResources;
21use crate::channel::RestoreControl;
22use crate::channel::SaveRestoreVmbusDevice;
23use crate::channel::VmbusDevice;
24use crate::channel::offer_channel;
25use crate::gpadl_ring::GpadlRingMem;
26use crate::gpadl_ring::gpadl_channel;
27use async_trait::async_trait;
28use guestmem::GuestMemory;
29use inspect::Inspect;
30use inspect::InspectMut;
31use mesh::payload::Protobuf;
32use task_control::AsyncRun;
33use task_control::Cancelled;
34use task_control::InspectTaskMut;
35use task_control::StopTask;
36use task_control::TaskControl;
37use vmbus_ring::RingMem;
38use vmcore::save_restore::RestoreError;
39use vmcore::save_restore::SaveError;
40use vmcore::save_restore::SavedStateBlob;
41use vmcore::save_restore::SavedStateRoot;
42use vmcore::vm_task::VmTaskDriver;
43use vmcore::vm_task::VmTaskDriverSource;
44
/// A vmbus device that is driven through a single channel by
/// [`SimpleDeviceWrapper`].
///
/// `M` is the ring memory type of the channel handed to [`Self::open`]; it
/// defaults to [`GpadlRingMem`].
#[async_trait]
pub trait SimpleVmbusDevice<M: RingMem = GpadlRingMem>: 'static + Send {
    /// State saved for an open channel. Produced by
    /// [`SaveRestoreSimpleVmbusDevice::save_open`] and consumed by
    /// [`SaveRestoreSimpleVmbusDevice::restore_open`].
    type SavedState: SavedStateRoot + Send;

    /// Per-open-channel state. Returned by [`Self::open`] (or by
    /// `restore_open`) and passed to [`Self::run`] while the channel is open.
    type Runner: 'static + Send;

    /// Returns the offer parameters for the channel.
    ///
    /// The wrapper calls this once, when it is constructed, and caches the
    /// result.
    fn offer(&self) -> OfferParams;

    /// Inspects the device.
    ///
    /// `runner` is the current runner, if one is inserted. Presumably it is
    /// `Some` exactly while the channel is open (TODO confirm against
    /// `TaskControl`).
    fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>);

    /// Opens the channel and returns the runner that will drive it.
    ///
    /// `guest_memory` is the guest memory that the bus's offer resources
    /// associate with the open request. Returning an error fails the open.
    fn open(
        &mut self,
        channel: RawAsyncChannel<M>,
        guest_memory: GuestMemory,
    ) -> Result<Self::Runner, ChannelOpenError>;

    /// Runs the device with `runner` until `stop` is signaled.
    ///
    /// NOTE(review): `Err(Cancelled)` presumably reports that the task was
    /// stopped before completing. Confirm against `task_control`.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        runner: &mut Self::Runner,
    ) -> Result<(), Cancelled>;

    /// Called when the channel is closed, after the wrapper has stopped the
    /// task and removed the runner. The default does nothing.
    async fn close(&mut self) {}

    /// Returns the save/restore interface if the device supports
    /// save/restore, or `None` otherwise.
    ///
    /// The wrapper only exposes save/restore when this returns `Some`. It
    /// relies on it still returning `Some` when it later saves or restores an
    /// open channel.
    fn supports_save_restore(
        &mut self,
    ) -> Option<
        &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
    >;
}
88
/// Save/restore support for a [`SimpleVmbusDevice`], obtained through
/// [`SimpleVmbusDevice::supports_save_restore`].
///
/// NOTE(review): the supertrait is `SimpleVmbusDevice` with its default ring
/// type, not `SimpleVmbusDevice<M>`, even though `restore_open` takes a
/// `RawAsyncChannel<M>`. Confirm this is intended.
pub trait SaveRestoreSimpleVmbusDevice<M: RingMem = GpadlRingMem>: SimpleVmbusDevice {
    /// Saves the state of an open channel.
    ///
    /// The wrapper only calls this while a runner is present, i.e. while the
    /// channel is open.
    fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState;

    /// Restores an open channel from `state` using a freshly built `channel`,
    /// and returns the runner for it.
    ///
    /// The wrapper only calls this when the saved state recorded the channel
    /// as open.
    fn restore_open(
        &mut self,
        state: Self::SavedState,
        channel: RawAsyncChannel<M>,
    ) -> Result<Self::Runner, ChannelOpenError>;
}
109
/// Saved state of a [`SimpleDeviceWrapper`].
#[derive(Debug, Protobuf, SavedStateRoot)]
#[mesh(package = "vmbus")]
struct SimpleSavedState {
    /// The device's `save_open` output, wrapped in a blob, if the channel was
    /// open when saved; `None` if the channel was closed.
    #[mesh(1)]
    channel: Option<SavedStateBlob>,
}
120
/// Adapts a [`SimpleVmbusDevice`] to the [`VmbusDevice`] interface.
pub struct SimpleDeviceWrapper<T: SimpleVmbusDevice> {
    /// Driver used for the runner task and for building channels.
    driver: VmTaskDriver,
    /// Offer parameters, read from the device once at construction.
    offer: OfferParams,
    /// Resources installed by the bus via `VmbusDevice::install`.
    resources: DeviceResources,
    /// The device plus, while the channel is open, its runner.
    device: TaskControl<DeviceTask<T>, T::Runner>,
    /// Whether `VmbusDevice::start` has been called without a matching `stop`.
    running: bool,
}
129
/// Newtype around the device so that [`TaskControl`] can drive it through
/// the `AsyncRun` and `InspectTaskMut` impls that follow.
struct DeviceTask<T>(T);
131
132impl<T: SimpleVmbusDevice> AsyncRun<T::Runner> for DeviceTask<T> {
133 async fn run(
134 &mut self,
135 stop: &mut StopTask<'_>,
136 runner: &mut T::Runner,
137 ) -> Result<(), Cancelled> {
138 self.0.run(stop, runner).await
139 }
140}
141
142impl<T: SimpleVmbusDevice> InspectTaskMut<T::Runner> for DeviceTask<T> {
143 fn inspect_mut(&mut self, req: inspect::Request<'_>, runner: Option<&mut T::Runner>) {
144 self.0.inspect(req, runner)
145 }
146}
147
148impl<T: SimpleVmbusDevice> SimpleDeviceWrapper<T> {
149 pub fn new(driver: VmTaskDriver, device: T) -> Self {
151 let offer = device.offer();
152 Self {
153 running: false,
154 driver,
155 offer,
156 resources: Default::default(),
157 device: TaskControl::new(DeviceTask(device)),
158 }
159 }
160
161 pub fn into_inner(self) -> T {
163 let (task, _) = self.device.into_inner();
164 task.0
165 }
166
167 fn insert_runner(&mut self, runner: T::Runner) {
168 self.device.insert(
169 &self.driver,
170 format!("{}-{}", self.offer.interface_name, self.offer.instance_id),
171 runner,
172 );
173 }
174
175 fn save(&mut self) -> SimpleSavedState {
176 assert!(!self.running);
177 let device = if let (state, Some(runner)) = self.device.get_mut() {
178 let sr = state.0.supports_save_restore().unwrap();
179 Some(SavedStateBlob::new(sr.save_open(runner)))
180 } else {
181 None
182 };
183 SimpleSavedState { channel: device }
184 }
185
186 fn restore(
187 &mut self,
188 open_request: Option<&OpenRequest>,
189 state: SimpleSavedState,
190 ) -> anyhow::Result<()> {
191 assert!(!self.running);
192 if let Some(device) = state.channel {
193 let device = device.parse()?;
194 let open_request = open_request.expect("open state mismatch");
195 let channel = self.build_channel(open_request)?;
196 let sr = self.device.task_mut().0.supports_save_restore().unwrap();
197 let task = sr.restore_open(device, channel)?;
198 self.insert_runner(task);
199 }
200 Ok(())
201 }
202
203 fn build_channel(
204 &mut self,
205 open_request: &OpenRequest,
206 ) -> anyhow::Result<RawAsyncChannel<GpadlRingMem>> {
207 self.driver
210 .retarget_vp(open_request.open_data.target_vp.unwrap_or_default());
211 let channel = gpadl_channel(&self.driver, &self.resources, open_request, 0)?;
212 Ok(channel)
213 }
214}
215
216#[async_trait]
217impl<T: SimpleVmbusDevice> VmbusDevice for SimpleDeviceWrapper<T> {
218 fn offer(&self) -> OfferParams {
219 self.offer.clone()
220 }
221
222 fn install(&mut self, resources: DeviceResources) {
223 self.resources = resources;
224 }
225
226 async fn open(
227 &mut self,
228 _channel_idx: u16,
229 open_request: &OpenRequest,
230 ) -> Result<(), anyhow::Error> {
231 assert!(self.running);
232 let channel = self.build_channel(open_request)?;
233 let gm = self
234 .resources
235 .offer_resources
236 .guest_memory(open_request)
237 .clone();
238 let runner = self.device.task_mut().0.open(channel, gm)?;
239
240 self.insert_runner(runner);
241 self.device.start();
242 Ok(())
243 }
244
245 async fn close(&mut self, _channel_idx: u16) {
246 self.device.stop().await;
247 self.device.remove();
248 self.device.task_mut().0.close().await;
249 }
250
251 async fn retarget_vp(&mut self, _channel_idx: u16, target_vp: u32) {
252 self.driver.retarget_vp(target_vp);
253 }
254
255 fn start(&mut self) {
256 assert!(!self.running);
257 self.device.start();
258 self.running = true;
259 }
260
261 async fn stop(&mut self) {
262 assert!(self.running);
263 self.device.stop().await;
264 self.running = false;
265 }
266
267 fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
268 assert!(!self.running);
269 let _ = self.device.task_mut().0.supports_save_restore()?;
270 Some(self)
271 }
272}
273
274#[async_trait]
275impl<T: SimpleVmbusDevice> SaveRestoreVmbusDevice for SimpleDeviceWrapper<T> {
276 async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
277 Ok(SavedStateBlob::new(self.save()))
278 }
279
280 async fn restore(
281 &mut self,
282 mut control: RestoreControl<'_>,
283 state: SavedStateBlob,
284 ) -> Result<(), RestoreError> {
285 let state: SimpleSavedState = state.parse()?;
286 let is_open = state.channel.is_some();
287 let open_request = control.restore(&[is_open]).await?;
288 self.restore(open_request[0].as_ref(), state)
289 .map_err(RestoreError::Other)?;
290 Ok(())
291 }
292}
293
294impl<T: SimpleVmbusDevice> InspectMut for SimpleDeviceWrapper<T> {
295 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
296 req.respond()
297 .field("driver", &self.driver)
298 .merge(&mut self.device);
299 }
300}
301
302pub async fn offer_simple_device<T: 'static + SimpleVmbusDevice>(
305 driver_source: &VmTaskDriverSource,
306 bus: &(impl ParentBus + ?Sized),
307 device: T,
308) -> anyhow::Result<SimpleDeviceHandle<T>> {
309 let driver = driver_source.builder().target_vp(0).build("simple-vmbus");
310 let channel = SimpleDeviceWrapper::new(driver, device);
311 Ok(SimpleDeviceHandle(
312 offer_channel(&driver_source.simple(), bus, channel).await?,
313 ))
314}
315
/// Handle to a simple device offered with [`offer_simple_device`].
///
/// NOTE(review): this type is `#[must_use]`, which suggests that dropping the
/// handle revokes the channel. Confirm against `ChannelHandle`.
#[must_use]
#[derive(Debug, Inspect)]
#[inspect(transparent)]
pub struct SimpleDeviceHandle<T: SimpleVmbusDevice>(ChannelHandle<SimpleDeviceWrapper<T>>);
321
322impl<T: SimpleVmbusDevice> SimpleDeviceHandle<T> {
323 pub async fn revoke(self) -> Option<T> {
325 self.0.revoke().await.map(|x| x.into_inner())
326 }
327
328 pub fn start(&self) {
330 self.0.start()
331 }
332
333 pub async fn stop(&self) {
335 self.0.stop().await
336 }
337
338 pub async fn reset(&self) {
340 self.0.reset().await
341 }
342
343 pub async fn save(&self) -> anyhow::Result<Option<SavedStateBlob>> {
345 self.0.save().await
346 }
347
348 pub async fn restore(&self, state: SavedStateBlob) -> anyhow::Result<()> {
350 self.0.restore(state).await
351 }
352}