1#![expect(missing_docs)]
11#![forbid(unsafe_code)]
12
13pub mod ring_buffer;
14
15use crate::ring_buffer::MemoryBlockRingBuffer;
16use anyhow::Context;
17use anyhow::Result;
18use futures::StreamExt;
19use futures_concurrency::stream::Merge;
20use guid::Guid;
21use inspect::InspectMut;
22use mesh::rpc::RpcSend;
23use pal_async::driver::SpawnDriver;
24use std::future::Future;
25use std::future::pending;
26use std::pin::pin;
27use std::sync::Arc;
28use task_control::AsyncRun;
29use task_control::Cancelled;
30use task_control::InspectTaskMut;
31use task_control::StopTask;
32use task_control::TaskControl;
33use tracing::Instrument;
34use user_driver::DmaClient;
35use user_driver::memory::MemoryBlock;
36use vmbus_channel::ChannelClosed;
37use vmbus_channel::RawAsyncChannel;
38use vmbus_channel::SignalVmbusChannel;
39use vmbus_channel::bus::GpadlRequest;
40use vmbus_channel::bus::OpenData;
41use vmbus_client::ChannelRequest;
42use vmbus_client::OfferInfo;
43use vmbus_client::OpenOutput;
44use vmbus_client::OpenRequest;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::UserDefinedData;
47use vmbus_relay::InterceptChannelRequest;
48use vmbus_ring::IncomingRing;
49use vmbus_ring::OutgoingRing;
50use vmbus_ring::PAGE_SIZE;
51use vmcore::interrupt::Interrupt;
52use vmcore::notify::Notify;
53use vmcore::notify::PolledNotify;
54use vmcore::save_restore::NoSavedState;
55use vmcore::save_restore::SavedStateBlob;
56use vmcore::save_restore::SavedStateRoot;
57use zerocopy::FromZeros;
58
/// A device's decision about a channel offer received from the host.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OfferResponse {
    /// Do not open the offered channel.
    Ignore,
    /// Open the offered channel and hand it to the device.
    Open,
}
63
64pub trait SimpleVmbusClientDevice {
65 type SavedState: SavedStateRoot + Send + Sync;
67
68 type Runner: 'static + Send + Sync;
70
71 fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>);
73
74 fn instance_id(&self) -> Guid;
76
77 fn offer(&self, offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse;
79
80 fn open(
84 &mut self,
85 channel_idx: u16,
86 channel: RawAsyncChannel<MemoryBlockRingBuffer>,
87 ) -> Result<Self::Runner>;
88
89 fn close(&mut self, channel_idx: u16);
92
93 fn supports_save_restore(
95 &mut self,
96 ) -> Option<
97 &mut dyn SaveRestoreSimpleVmbusClientDevice<
98 SavedState = Self::SavedState,
99 Runner = Self::Runner,
100 >,
101 >;
102}
103
104pub trait SimpleVmbusClientDeviceAsync: SimpleVmbusClientDevice + 'static + Send + Sync {
105 fn run(
107 &mut self,
108 stop: &mut StopTask<'_>,
109 runner: &mut Self::Runner,
110 ) -> impl Send + Future<Output = Result<(), Cancelled>>;
111}
112
113pub trait SaveRestoreSimpleVmbusClientDevice: SimpleVmbusClientDevice {
118 fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState;
122
123 fn restore_open(
127 &mut self,
128 state: Self::SavedState,
129 channel: RawAsyncChannel<MemoryBlockRingBuffer>,
130 ) -> Result<Self::Runner>;
131}
132
133#[derive(InspectMut)]
134pub struct SimpleVmbusClientDeviceWrapper<T: SimpleVmbusClientDeviceAsync> {
135 instance_id: Guid,
136 #[inspect(skip)]
137 spawner: Arc<dyn SpawnDriver>,
138 #[inspect(mut)]
139 vmbus_listener: TaskControl<SimpleVmbusClientDeviceTask<T>, SimpleVmbusClientDeviceTaskState>,
140}
141
142impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceWrapper<T> {
143 pub fn new(
145 driver: impl SpawnDriver + Clone,
146 dma_alloc: Arc<dyn DmaClient>,
147 device: T,
148 ) -> Result<Self> {
149 let spawner = Arc::new(driver.clone());
150 Ok(Self {
151 instance_id: device.instance_id(),
152 vmbus_listener: TaskControl::new(SimpleVmbusClientDeviceTask::new(
153 device,
154 spawner.clone(),
155 dma_alloc,
156 )),
157 spawner,
158 })
159 }
160
161 pub fn instance_id(&self) -> Guid {
162 self.instance_id
163 }
164
165 pub fn detach(
166 mut self,
167 driver: impl SpawnDriver,
168 recv_relay: mesh::Receiver<InterceptChannelRequest>,
169 ) -> Result<mesh::OneshotSender<()>> {
170 self.vmbus_listener.insert(
171 &self.spawner,
172 format!("{}", self.instance_id),
173 SimpleVmbusClientDeviceTaskState {
174 offer: None,
175 recv_relay,
176 vtl_pages: None,
177 },
178 );
179 let (driver_send, driver_recv) = mesh::oneshot();
180 driver
181 .spawn(
182 format!("vmbus_relay_device {}", self.instance_id),
183 async move {
184 self.vmbus_listener.start();
185 let _ = driver_recv.await;
186 self.vmbus_listener.stop().await;
187 },
188 )
189 .detach();
190 Ok(driver_send)
191 }
192}
193
/// Adapter that runs a [`SimpleVmbusClientDeviceAsync`] device under a
/// [`TaskControl`], with the device's runner as the task state.
struct RelayDeviceTask<T>(T);
195
196impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<T::Runner> for RelayDeviceTask<T> {
197 async fn run(
198 &mut self,
199 stop: &mut StopTask<'_>,
200 runner: &mut T::Runner,
201 ) -> Result<(), Cancelled> {
202 self.0.run(stop, runner).await
203 }
204}
205
206impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<T::Runner> for RelayDeviceTask<T> {
207 fn inspect_mut(&mut self, req: inspect::Request<'_>, runner: Option<&mut T::Runner>) {
208 self.0.inspect(req, runner)
209 }
210}
211
212#[derive(InspectMut)]
213struct SimpleVmbusClientDeviceTaskState {
214 offer: Option<OfferInfo>,
215 #[inspect(skip)]
216 recv_relay: mesh::Receiver<InterceptChannelRequest>,
217 #[inspect(hex, with = "|x| x.as_ref().map(|x| inspect::iter_by_index(x.pfns()))")]
218 vtl_pages: Option<MemoryBlock>,
219}
220
221struct SimpleVmbusClientDeviceTask<T: SimpleVmbusClientDeviceAsync> {
222 device: TaskControl<RelayDeviceTask<T>, T::Runner>,
223 saved_state: Option<T::SavedState>,
224 spawner: Arc<dyn SpawnDriver>,
225 dma_alloc: Arc<dyn DmaClient>,
226}
227
228impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<SimpleVmbusClientDeviceTaskState>
229 for SimpleVmbusClientDeviceTask<T>
230{
231 async fn run(
232 &mut self,
233 stop: &mut StopTask<'_>,
234 state: &mut SimpleVmbusClientDeviceTaskState,
235 ) -> Result<(), Cancelled> {
236 stop.until_stopped(self.process_messages(state)).await
237 }
238}
239
240impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<SimpleVmbusClientDeviceTaskState>
241 for SimpleVmbusClientDeviceTask<T>
242{
243 fn inspect_mut(
244 &mut self,
245 req: inspect::Request<'_>,
246 state: Option<&mut SimpleVmbusClientDeviceTaskState>,
247 ) {
248 req.respond()
249 .merge(state)
250 .field_mut("device", &mut self.device)
251 .field("dma_alloc", &self.dma_alloc);
252 }
253}
254
255impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceTask<T> {
256 pub fn new(device: T, spawner: Arc<dyn SpawnDriver>, dma_alloc: Arc<dyn DmaClient>) -> Self {
257 Self {
258 device: TaskControl::new(RelayDeviceTask(device)),
259 saved_state: None,
260 spawner,
261 dma_alloc,
262 }
263 }
264
265 fn insert_runner(&mut self, state: &SimpleVmbusClientDeviceTaskState, runner: T::Runner) {
266 let offer = state.offer.as_ref().unwrap().offer;
267 self.device.insert(
268 &self.spawner,
269 format!("{}-{}", offer.interface_id, offer.instance_id),
270 runner,
271 );
272 }
273
274 async fn handle_offer(
276 &mut self,
277 offer: OfferInfo,
278 state: &mut SimpleVmbusClientDeviceTaskState,
279 ) -> Result<()> {
280 tracing::info!(?offer, "matching channel offered");
281
282 if offer.offer.is_dedicated != 1 {
283 tracing::warn!(offer = ?offer.offer, "All offers should be dedicated with Win8+ host")
284 }
285
286 if matches!(
287 self.device.task_mut().0.offer(&offer.offer),
288 OfferResponse::Ignore
289 ) {
290 return Ok(());
291 }
292
293 let interrupt_event = pal_event::Event::new();
294 let (memory, ring_gpadl_id) = self
295 .reserve_memory(state, &offer.request_send, 4)
296 .await
297 .context("reserve memory")?;
298 let guest_to_host_interrupt = offer.guest_to_host_interrupt.clone();
299 state.offer = Some(offer);
300 let offer = state.offer.as_ref().unwrap();
301 self.open_channel(&offer.request_send, ring_gpadl_id, &interrupt_event)
302 .await
303 .context("open channel")?;
304 let channel = self
305 .create_vmbus_channel(&memory, &interrupt_event, guest_to_host_interrupt)
306 .context("create vmbus queue")?;
307
308 let save_restore = self.device.task_mut().0.supports_save_restore();
309 let saved_state = self.saved_state.take();
310 let device_runner = if save_restore.is_some() && saved_state.is_some() {
311 save_restore
312 .unwrap()
313 .restore_open(saved_state.unwrap(), channel)
314 .context("device restore_open callback")?
315 } else {
316 self.device
317 .task_mut()
318 .0
319 .open(offer.offer.subchannel_index, channel)
320 .context("device open callback")?
321 };
322 self.insert_runner(state, device_runner);
323 self.device.start();
324 Ok(())
325 }
326
327 async fn handle_start(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
329 if self.device.is_running() {
330 return;
331 }
332
333 let offer = state.offer.take();
334 if offer.is_none() {
335 return;
336 }
337
338 if let Err(err) = self.handle_offer(offer.unwrap(), state).await {
340 tracing::error!(
341 err = err.as_ref() as &dyn std::error::Error,
342 "Failed to reconnect vmbus channel"
343 );
344 }
345 }
346
347 async fn cleanup_device_resources(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
348 let Some(offer) = state.offer.as_mut() else {
349 return;
350 };
351
352 if state.vtl_pages.is_some() {
353 if let Err(err) = offer
354 .request_send
355 .call(
356 ChannelRequest::TeardownGpadl,
357 GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32),
358 )
359 .await
360 {
361 tracing::error!(
362 error = &err as &dyn std::error::Error,
363 "failed to teardown gpadl"
364 );
365 }
366
367 state.vtl_pages = None;
368 }
369 }
370
371 async fn handle_stop(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
373 if !self.device.stop().await {
374 return;
375 }
376
377 {
387 let offer = state.offer.as_ref().expect("device opened");
388 offer
389 .request_send
390 .call(ChannelRequest::Close, ())
391 .await
392 .ok();
393 }
394 self.cleanup_device_resources(state).await;
398 let runner = self.device.remove();
399 let device = self.device.task_mut();
400 if let Some(save_restore) = device.0.supports_save_restore() {
401 self.saved_state = Some(save_restore.save_open(&runner));
402 }
403 drop(runner);
404 let offer = state.offer.as_ref().expect("device opened");
405 device.0.close(offer.offer.subchannel_index);
406 }
407
408 async fn reserve_memory(
411 &mut self,
412 state: &mut SimpleVmbusClientDeviceTaskState,
413 request_send: &mesh::Sender<ChannelRequest>,
414 page_count: usize,
415 ) -> Result<(MemoryBlock, GpadlId)> {
416 assert!(page_count >= 4);
419
420 let mem = self
421 .dma_alloc
422 .allocate_dma_buffer(page_count * PAGE_SIZE)
423 .context("allocating memory for vmbus rings")?;
424 state.vtl_pages = Some(mem.clone());
425 let buf: Vec<_> = [mem.len() as u64]
426 .iter()
427 .chain(mem.pfns())
428 .copied()
429 .collect();
430
431 let gpadl_id = GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32);
432 request_send
433 .call_failable(
434 ChannelRequest::Gpadl,
435 GpadlRequest {
436 id: gpadl_id,
437 count: 1,
438 buf,
439 },
440 )
441 .await
442 .context("registering gpadl")?;
443 Ok((mem, gpadl_id))
444 }
445
446 async fn open_channel(
448 &self,
449 request_send: &mesh::Sender<ChannelRequest>,
450 ring_gpadl_id: GpadlId,
451 event: &pal_event::Event,
452 ) -> Result<OpenOutput> {
453 let open_request = OpenRequest {
454 open_data: OpenData {
455 target_vp: 0,
456 ring_offset: 2,
457 ring_gpadl_id,
458 event_flag: !0,
459 connection_id: !0,
460 user_data: UserDefinedData::new_zeroed(),
461 },
462 incoming_event: Some(event.clone()),
463 use_vtl2_connection_id: true,
464 };
465
466 request_send
467 .call_failable(ChannelRequest::Open, open_request)
468 .instrument(tracing::info_span!(
469 "opening vmbus channel for intercepted device"
470 ))
471 .await
472 .context("open vmbus channel")
473 }
474
475 fn create_vmbus_channel(
477 &self,
478 mem: &MemoryBlock,
479 host_to_guest_event: &pal_event::Event,
480 guest_to_host_interrupt: Interrupt,
481 ) -> Result<RawAsyncChannel<MemoryBlockRingBuffer>> {
482 let (out_ring_mem, in_ring_mem) = (
483 mem.subblock(0, 2 * PAGE_SIZE),
484 mem.subblock(2 * PAGE_SIZE, 2 * PAGE_SIZE),
485 );
486 let (in_ring, out_ring) = (
487 IncomingRing::new(in_ring_mem.into()).unwrap(),
488 OutgoingRing::new(out_ring_mem.into()).unwrap(),
489 );
490
491 let signal = MemoryBlockChannelSignal {
492 event: Notify::from_event(host_to_guest_event.clone())
493 .pollable(self.spawner.as_ref())
494 .unwrap(),
495 interrupt: guest_to_host_interrupt,
496 };
497 Ok(RawAsyncChannel {
498 in_ring,
499 out_ring,
500 signal: Box::new(signal),
501 })
502 }
503
504 async fn handle_revoke(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
506 let Some(offer) = state.offer.take() else {
507 return;
508 };
509 tracing::info!("device revoked");
510 if self.device.stop().await {
511 drop(self.device.remove());
512 self.device.task_mut().0.close(offer.offer.subchannel_index);
513 }
514 self.cleanup_device_resources(state).await;
515 }
516
517 fn handle_save(&mut self) -> SavedStateBlob {
518 let saved_state = self.saved_state.take();
519 if let Some(saved_state) = saved_state {
520 let blob = SavedStateBlob::new(saved_state);
521 self.handle_restore(&blob);
522 blob
523 } else {
524 SavedStateBlob::new(NoSavedState)
525 }
526 }
527
528 fn handle_restore(&mut self, saved_state_blob: &SavedStateBlob) {
529 self.saved_state = match saved_state_blob.parse() {
530 Ok(saved_state) => Some(saved_state),
531 Err(err) => {
532 tracing::error!(
533 err = &err as &dyn std::error::Error,
534 "Protobuf conversion error saving state"
535 );
536 None
537 }
538 };
539 }
540
541 pub async fn process_messages(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
544 loop {
545 #[expect(clippy::large_enum_variant)]
546 enum Event {
547 Request(InterceptChannelRequest),
548 Revoke(()),
549 }
550 let revoke = pin!(async {
551 if let Some(offer) = &mut state.offer {
552 (&mut offer.revoke_recv).await.ok();
553 } else {
554 pending().await
555 }
556 });
557 let Some(r) = (
558 (&mut state.recv_relay).map(Event::Request),
559 futures::stream::once(revoke).map(Event::Revoke),
560 )
561 .merge()
562 .next()
563 .await
564 else {
565 break;
566 };
567 match r {
568 Event::Revoke(()) => {
569 self.handle_revoke(state).await;
570 }
571 Event::Request(InterceptChannelRequest::Offer(offer)) => {
572 if !self.device.is_running() {
575 if let Err(err) = self.handle_offer(offer, state).await {
576 tracing::error!(
577 error = err.as_ref() as &dyn std::error::Error,
578 "failed offer handling"
579 );
580 }
581 }
582 }
583 Event::Request(InterceptChannelRequest::Start) => {
584 self.handle_start(state).await;
585 }
586 Event::Request(InterceptChannelRequest::Stop(rpc)) => {
587 rpc.handle(async |()| self.handle_stop(state).await).await;
588 }
589 Event::Request(InterceptChannelRequest::Save(rpc)) => {
590 rpc.handle_sync(|()| self.handle_save());
591 }
592 Event::Request(InterceptChannelRequest::Restore(saved_state)) => {
593 self.handle_restore(&saved_state);
594 }
595 }
596 }
597 }
598}
599
600struct MemoryBlockChannelSignal {
601 event: PolledNotify,
602 interrupt: Interrupt,
603}
604
605impl SignalVmbusChannel for MemoryBlockChannelSignal {
606 fn signal_remote(&self) {
607 self.interrupt.deliver();
608 }
609
610 fn poll_for_signal(
611 &self,
612 cx: &mut std::task::Context<'_>,
613 ) -> std::task::Poll<Result<(), ChannelClosed>> {
614 self.event.poll_wait(cx).map(Ok)
615 }
616}