1#![warn(missing_docs)]
7
8use inspect::Inspect;
9use pal_async::task::Spawn;
10use state_unit::NameInUse;
11use state_unit::SpawnedUnit;
12use state_unit::StateUnit;
13use state_unit::StateUnits;
14use state_unit::UnitBuilder;
15use state_unit::UnitHandle;
16use state_unit::run_async_unit;
17use std::sync::Arc;
18use vm_resource::Resource;
19use vm_resource::ResourceResolver;
20use vm_resource::kind::VmbusDeviceHandleKind;
21use vmbus_channel::channel::ChannelHandle;
22use vmbus_channel::channel::VmbusDevice;
23use vmbus_channel::channel::offer_channel;
24use vmbus_channel::channel::offer_generic_channel;
25use vmbus_channel::resources::ResolveVmbusDeviceHandleParams;
26use vmbus_channel::simple::SimpleDeviceHandle;
27use vmbus_channel::simple::SimpleVmbusDevice;
28use vmbus_channel::simple::offer_simple_device;
29use vmbus_server::VmbusServer;
30use vmbus_server::VmbusServerControl;
31use vmcore::save_restore::RestoreError;
32use vmcore::save_restore::SaveError;
33use vmcore::save_restore::SavedStateBlob;
34use vmcore::vm_task::VmTaskDriverSource;
35
36pub struct VmbusServerHandle {
40 unit: SpawnedUnit<VmbusServerUnit>,
41 control: Arc<VmbusServerControl>,
42}
43
44impl VmbusServerHandle {
45 pub fn new(
47 spawner: &impl Spawn,
48 builder: UnitBuilder<'_>,
49 server: VmbusServer,
50 ) -> Result<Self, NameInUse> {
51 let control = server.control();
52 let unit = builder.spawn(spawner, |recv| {
53 run_async_unit(VmbusServerUnit(server), recv)
54 })?;
55 Ok(Self { unit, control })
56 }
57
58 pub fn control(&self) -> &Arc<VmbusServerControl> {
60 &self.control
61 }
62
63 pub fn unit_handle(&self) -> &UnitHandle {
65 self.unit.handle()
66 }
67
68 pub async fn remove(self) -> VmbusServer {
70 self.unit.remove().await.0
71 }
72}
73
74#[derive(Inspect)]
76#[inspect(transparent)]
77struct VmbusServerUnit(VmbusServer);
78
79impl StateUnit for &'_ VmbusServerUnit {
80 async fn start(&mut self) {
81 self.0.start();
82 }
83
84 async fn stop(&mut self) {
85 self.0.stop().await;
86 }
87
88 async fn reset(&mut self) -> anyhow::Result<()> {
89 self.0.reset().await;
90 Ok(())
91 }
92
93 async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
94 Ok(Some(SavedStateBlob::new(self.0.save().await)))
95 }
96
97 async fn restore(&mut self, buffer: SavedStateBlob) -> Result<(), RestoreError> {
98 self.0
99 .restore(buffer.parse()?)
100 .await
101 .map_err(|err| RestoreError::Other(err.into()))
102 }
103}
104
105#[must_use]
107#[derive(Debug, Inspect)]
108#[inspect(transparent)]
109pub struct ChannelUnit<T: ?Sized>(ChannelHandle<T>);
110
111pub async fn offer_channel_unit<T: 'static + VmbusDevice>(
113 driver: &impl Spawn,
114 state_units: &StateUnits,
115 vmbus: &VmbusServerHandle,
116 channel: T,
117) -> anyhow::Result<SpawnedUnit<ChannelUnit<T>>> {
118 let offer = channel.offer();
119 let name = format!("{}:{}", offer.interface_name, offer.instance_id);
120 let handle = offer_channel(driver, vmbus.control.as_ref(), channel).await?;
121 let unit = state_units
122 .add(name)
123 .depends_on(vmbus.unit.handle())
124 .spawn(driver, |recv| run_async_unit(ChannelUnit(handle), recv))?;
125 Ok(unit)
126}
127
128impl<T: 'static + VmbusDevice> ChannelUnit<T> {
129 pub async fn revoke(self) -> T {
131 self.0.revoke().await.unwrap()
132 }
133}
134
135impl<T: 'static + VmbusDevice + ?Sized> StateUnit for &'_ ChannelUnit<T> {
136 async fn start(&mut self) {
137 self.0.start();
138 }
139
140 async fn stop(&mut self) {
141 self.0.stop().await;
142 }
143
144 async fn reset(&mut self) -> anyhow::Result<()> {
145 self.0.reset().await;
146 Ok(())
147 }
148
149 async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
150 let state = self.0.save().await.map_err(SaveError::Other)?;
151 Ok(state)
152 }
153
154 async fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError> {
155 self.0.restore(state).await.map_err(RestoreError::Other)
156 }
157}
158
159#[must_use]
161#[derive(Debug)]
162pub struct SimpleChannelUnit<T: SimpleVmbusDevice>(SimpleDeviceHandle<T>);
163
164pub async fn offer_simple_device_unit<T: SimpleVmbusDevice>(
166 driver_source: &VmTaskDriverSource,
167 state_units: &StateUnits,
168 vmbus: &VmbusServerHandle,
169 device: T,
170) -> anyhow::Result<SpawnedUnit<SimpleChannelUnit<T>>> {
171 let offer = device.offer();
172 let name = format!("{}:{}", offer.interface_name, offer.instance_id);
173 let handle = offer_simple_device(driver_source, vmbus.control.as_ref(), device).await?;
174 let unit = state_units
175 .add(name)
176 .depends_on(vmbus.unit.handle())
177 .spawn(driver_source.simple(), |recv| {
178 run_async_unit(SimpleChannelUnit(handle), recv)
179 })?;
180 Ok(unit)
181}
182
183impl<T: SimpleVmbusDevice> SimpleChannelUnit<T> {
184 pub async fn revoke(self) -> T {
186 self.0.revoke().await.unwrap()
187 }
188}
189
190impl<T: SimpleVmbusDevice> Inspect for SimpleChannelUnit<T> {
191 fn inspect(&self, req: inspect::Request<'_>) {
192 self.0.inspect(req);
193 }
194}
195
196impl<T: SimpleVmbusDevice> StateUnit for &'_ SimpleChannelUnit<T> {
197 async fn start(&mut self) {
198 self.0.start();
199 }
200
201 async fn stop(&mut self) {
202 self.0.stop().await;
203 }
204
205 async fn reset(&mut self) -> anyhow::Result<()> {
206 self.0.reset().await;
207 Ok(())
208 }
209
210 async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
211 let state = self.0.save().await.map_err(SaveError::Other)?;
212 Ok(state)
213 }
214
215 async fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError> {
216 self.0.restore(state).await.map_err(RestoreError::Other)
217 }
218}
219
220pub async fn offer_vmbus_device_handle_unit(
222 driver_source: &VmTaskDriverSource,
223 state_units: &StateUnits,
224 vmbus: &VmbusServerHandle,
225 resolver: &ResourceResolver,
226 resource: Resource<VmbusDeviceHandleKind>,
227) -> anyhow::Result<SpawnedUnit<ChannelUnit<dyn VmbusDevice>>> {
228 let channel = resolver
229 .resolve(resource, ResolveVmbusDeviceHandleParams { driver_source })
230 .await?;
231 let offer = channel.0.offer();
232 let name = format!("{}:{}", offer.interface_name, offer.instance_id);
233 let handle =
234 offer_generic_channel(&driver_source.simple(), vmbus.control.as_ref(), channel.0).await?;
235 let unit = state_units
236 .add(name)
237 .depends_on(vmbus.unit.handle())
238 .spawn(driver_source.simple(), |recv| {
239 run_async_unit(ChannelUnit(handle), recv)
240 })?;
241 Ok(unit)
242}