1use crate::RawAsyncChannel;
15use crate::bus::OfferParams;
16use crate::bus::OpenRequest;
17use crate::bus::ParentBus;
18use crate::channel::ChannelHandle;
19use crate::channel::ChannelOpenError;
20use crate::channel::DeviceResources;
21use crate::channel::RestoreControl;
22use crate::channel::SaveRestoreVmbusDevice;
23use crate::channel::VmbusDevice;
24use crate::channel::offer_channel;
25use crate::gpadl_ring::GpadlRingMem;
26use crate::gpadl_ring::gpadl_channel;
27use async_trait::async_trait;
28use guestmem::GuestMemory;
29use inspect::Inspect;
30use inspect::InspectMut;
31use mesh::payload::Protobuf;
32use task_control::AsyncRun;
33use task_control::Cancelled;
34use task_control::InspectTaskMut;
35use task_control::StopTask;
36use task_control::TaskControl;
37use vmbus_ring::RingMem;
38use vmcore::save_restore::RestoreError;
39use vmcore::save_restore::SaveError;
40use vmcore::save_restore::SavedStateBlob;
41use vmcore::save_restore::SavedStateRoot;
42use vmcore::vm_task::VmTaskDriver;
43use vmcore::vm_task::VmTaskDriverSource;
44
45#[async_trait]
47pub trait SimpleVmbusDevice<M: RingMem = GpadlRingMem>: 'static + Send {
48 type SavedState: SavedStateRoot + Send;
50
51 type Runner: 'static + Send;
53
54 fn offer(&self) -> OfferParams;
56
57 fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>);
59
60 fn open(
64 &mut self,
65 channel: RawAsyncChannel<M>,
66 guest_memory: GuestMemory,
67 ) -> Result<Self::Runner, ChannelOpenError>;
68
69 async fn run(
71 &mut self,
72 stop: &mut StopTask<'_>,
73 runner: &mut Self::Runner,
74 ) -> Result<(), Cancelled>;
75
76 async fn close(&mut self) {}
78
79 fn supports_save_restore(
83 &mut self,
84 ) -> Option<
85 &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
86 >;
87}
88
89pub trait SaveRestoreSimpleVmbusDevice<M: RingMem = GpadlRingMem>: SimpleVmbusDevice {
93 fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState;
98
99 fn restore_open(
104 &mut self,
105 state: Self::SavedState,
106 channel: RawAsyncChannel<M>,
107 ) -> Result<Self::Runner, ChannelOpenError>;
108}
109
110#[derive(Debug, Protobuf, SavedStateRoot)]
112#[mesh(package = "vmbus")]
113struct SimpleSavedState {
114 #[mesh(1)]
118 channel: Option<SavedStateBlob>,
119}
120
121pub struct SimpleDeviceWrapper<T: SimpleVmbusDevice> {
123 driver: VmTaskDriver,
124 offer: OfferParams,
125 resources: DeviceResources,
126 device: TaskControl<DeviceTask<T>, T::Runner>,
127 running: bool,
128}
129
/// Newtype around the device so the task-control traits
/// ([`AsyncRun`], [`InspectTaskMut`]) can be implemented for it.
struct DeviceTask<T>(T);
131
132impl<T: SimpleVmbusDevice> AsyncRun<T::Runner> for DeviceTask<T> {
133 async fn run(
134 &mut self,
135 stop: &mut StopTask<'_>,
136 runner: &mut T::Runner,
137 ) -> Result<(), Cancelled> {
138 self.0.run(stop, runner).await
139 }
140}
141
142impl<T: SimpleVmbusDevice> InspectTaskMut<T::Runner> for DeviceTask<T> {
143 fn inspect_mut(&mut self, req: inspect::Request<'_>, runner: Option<&mut T::Runner>) {
144 self.0.inspect(req, runner)
145 }
146}
147
148impl<T: SimpleVmbusDevice> SimpleDeviceWrapper<T> {
149 pub fn new(driver: VmTaskDriver, device: T) -> Self {
151 let offer = device.offer();
152 Self {
153 running: false,
154 driver,
155 offer,
156 resources: Default::default(),
157 device: TaskControl::new(DeviceTask(device)),
158 }
159 }
160
161 pub fn into_inner(self) -> T {
163 let (task, _) = self.device.into_inner();
164 task.0
165 }
166
167 fn insert_runner(&mut self, runner: T::Runner) {
168 self.device.insert(
169 &self.driver,
170 format!("{}-{}", self.offer.interface_name, self.offer.instance_id),
171 runner,
172 );
173 }
174
175 fn save(&mut self) -> SimpleSavedState {
176 assert!(!self.running);
177 let device = if let (state, Some(runner)) = self.device.get_mut() {
178 let sr = state.0.supports_save_restore().unwrap();
179 Some(SavedStateBlob::new(sr.save_open(runner)))
180 } else {
181 None
182 };
183 SimpleSavedState { channel: device }
184 }
185
186 fn restore(
187 &mut self,
188 open_request: Option<&OpenRequest>,
189 state: SimpleSavedState,
190 ) -> anyhow::Result<()> {
191 assert!(!self.running);
192 if let Some(device) = state.channel {
193 let device = device.parse()?;
194 let open_request = open_request.expect("open state mismatch");
195 let channel = self.build_channel(open_request)?;
196 let sr = self.device.task_mut().0.supports_save_restore().unwrap();
197 let task = sr.restore_open(device, channel)?;
198 self.insert_runner(task);
199 }
200 Ok(())
201 }
202
203 fn build_channel(
204 &mut self,
205 open_request: &OpenRequest,
206 ) -> anyhow::Result<RawAsyncChannel<GpadlRingMem>> {
207 self.driver.retarget_vp(open_request.open_data.target_vp);
208 let channel = gpadl_channel(&self.driver, &self.resources, open_request, 0)?;
209 Ok(channel)
210 }
211}
212
213#[async_trait]
214impl<T: SimpleVmbusDevice> VmbusDevice for SimpleDeviceWrapper<T> {
215 fn offer(&self) -> OfferParams {
216 self.offer.clone()
217 }
218
219 fn install(&mut self, resources: DeviceResources) {
220 self.resources = resources;
221 }
222
223 async fn open(
224 &mut self,
225 _channel_idx: u16,
226 open_request: &OpenRequest,
227 ) -> Result<(), anyhow::Error> {
228 assert!(self.running);
229 let channel = self.build_channel(open_request)?;
230 let gm = self
231 .resources
232 .offer_resources
233 .guest_memory(open_request)
234 .clone();
235 let runner = self.device.task_mut().0.open(channel, gm)?;
236
237 self.insert_runner(runner);
238 self.device.start();
239 Ok(())
240 }
241
242 async fn close(&mut self, _channel_idx: u16) {
243 self.device.stop().await;
244 self.device.remove();
245 self.device.task_mut().0.close().await;
246 }
247
248 async fn retarget_vp(&mut self, _channel_idx: u16, target_vp: u32) {
249 self.driver.retarget_vp(target_vp);
250 }
251
252 fn start(&mut self) {
253 assert!(!self.running);
254 self.device.start();
255 self.running = true;
256 }
257
258 async fn stop(&mut self) {
259 assert!(self.running);
260 self.device.stop().await;
261 self.running = false;
262 }
263
264 fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
265 assert!(!self.running);
266 let _ = self.device.task_mut().0.supports_save_restore()?;
267 Some(self)
268 }
269}
270
271#[async_trait]
272impl<T: SimpleVmbusDevice> SaveRestoreVmbusDevice for SimpleDeviceWrapper<T> {
273 async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
274 Ok(SavedStateBlob::new(self.save()))
275 }
276
277 async fn restore(
278 &mut self,
279 mut control: RestoreControl<'_>,
280 state: SavedStateBlob,
281 ) -> Result<(), RestoreError> {
282 let state: SimpleSavedState = state.parse()?;
283 let is_open = state.channel.is_some();
284 let open_request = control.restore(&[is_open]).await?;
285 self.restore(open_request[0].as_ref(), state)
286 .map_err(RestoreError::Other)?;
287 Ok(())
288 }
289}
290
291impl<T: SimpleVmbusDevice> InspectMut for SimpleDeviceWrapper<T> {
292 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
293 req.respond()
294 .field("driver", &self.driver)
295 .merge(&mut self.device);
296 }
297}
298
299pub async fn offer_simple_device<T: 'static + SimpleVmbusDevice>(
302 driver_source: &VmTaskDriverSource,
303 bus: &(impl ParentBus + ?Sized),
304 device: T,
305) -> anyhow::Result<SimpleDeviceHandle<T>> {
306 let driver = driver_source.builder().target_vp(0).build("simple-vmbus");
307 let channel = SimpleDeviceWrapper::new(driver, device);
308 Ok(SimpleDeviceHandle(
309 offer_channel(&driver_source.simple(), bus, channel).await?,
310 ))
311}
312
313#[must_use]
315#[derive(Debug, Inspect)]
316#[inspect(transparent)]
317pub struct SimpleDeviceHandle<T: SimpleVmbusDevice>(ChannelHandle<SimpleDeviceWrapper<T>>);
318
319impl<T: SimpleVmbusDevice> SimpleDeviceHandle<T> {
320 pub async fn revoke(self) -> Option<T> {
322 self.0.revoke().await.map(|x| x.into_inner())
323 }
324
325 pub fn start(&self) {
327 self.0.start()
328 }
329
330 pub async fn stop(&self) {
332 self.0.stop().await
333 }
334
335 pub async fn reset(&self) {
337 self.0.reset().await
338 }
339
340 pub async fn save(&self) -> anyhow::Result<Option<SavedStateBlob>> {
342 self.0.save().await
343 }
344
345 pub async fn restore(&self, state: SavedStateBlob) -> anyhow::Result<()> {
347 self.0.restore(state).await
348 }
349}