1#![cfg(target_os = "linux")]
11#![expect(missing_docs)]
12#![forbid(unsafe_code)]
13
14pub mod ring_buffer;
15
16use crate::ring_buffer::MemoryBlockRingBuffer;
17use anyhow::Context;
18use anyhow::Result;
19use futures::StreamExt;
20use futures_concurrency::stream::Merge;
21use guid::Guid;
22use inspect::InspectMut;
23use mesh::rpc::RpcSend;
24use pal_async::driver::SpawnDriver;
25use std::future::Future;
26use std::future::pending;
27use std::pin::pin;
28use std::sync::Arc;
29use task_control::AsyncRun;
30use task_control::Cancelled;
31use task_control::InspectTaskMut;
32use task_control::StopTask;
33use task_control::TaskControl;
34use tracing::Instrument;
35use user_driver::DmaClient;
36use user_driver::memory::MemoryBlock;
37use vmbus_channel::ChannelClosed;
38use vmbus_channel::RawAsyncChannel;
39use vmbus_channel::SignalVmbusChannel;
40use vmbus_channel::bus::GpadlRequest;
41use vmbus_channel::bus::OpenData;
42use vmbus_client::ChannelRequest;
43use vmbus_client::OfferInfo;
44use vmbus_client::OpenOutput;
45use vmbus_client::OpenRequest;
46use vmbus_core::protocol::GpadlId;
47use vmbus_core::protocol::UserDefinedData;
48use vmbus_relay::InterceptChannelRequest;
49use vmbus_ring::IncomingRing;
50use vmbus_ring::OutgoingRing;
51use vmbus_ring::PAGE_SIZE;
52use vmcore::interrupt::Interrupt;
53use vmcore::notify::Notify;
54use vmcore::notify::PolledNotify;
55use vmcore::save_restore::NoSavedState;
56use vmcore::save_restore::SavedStateBlob;
57use vmcore::save_restore::SavedStateRoot;
58use zerocopy::FromZeros;
59
/// A device's answer to a channel offer, returned from
/// [`SimpleVmbusClientDevice::offer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OfferResponse {
    /// Do not open the offered channel.
    Ignore,
    /// Open the offered channel.
    Open,
}
64
65pub trait SimpleVmbusClientDevice {
66 type SavedState: SavedStateRoot + Send + Sync;
68
69 type Runner: 'static + Send + Sync;
71
72 fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>);
74
75 fn instance_id(&self) -> Guid;
77
78 fn offer(&self, offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse;
80
81 fn open(
85 &mut self,
86 channel_idx: u16,
87 channel: RawAsyncChannel<MemoryBlockRingBuffer>,
88 ) -> Result<Self::Runner>;
89
90 fn close(&mut self, channel_idx: u16);
93
94 fn supports_save_restore(
96 &mut self,
97 ) -> Option<
98 &mut dyn SaveRestoreSimpleVmbusClientDevice<
99 SavedState = Self::SavedState,
100 Runner = Self::Runner,
101 >,
102 >;
103}
104
105pub trait SimpleVmbusClientDeviceAsync: SimpleVmbusClientDevice + 'static + Send + Sync {
106 fn run(
108 &mut self,
109 stop: &mut StopTask<'_>,
110 runner: &mut Self::Runner,
111 ) -> impl Send + Future<Output = Result<(), Cancelled>>;
112}
113
114pub trait SaveRestoreSimpleVmbusClientDevice: SimpleVmbusClientDevice {
119 fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState;
123
124 fn restore_open(
128 &mut self,
129 state: Self::SavedState,
130 channel: RawAsyncChannel<MemoryBlockRingBuffer>,
131 ) -> Result<Self::Runner>;
132}
133
134#[derive(InspectMut)]
135pub struct SimpleVmbusClientDeviceWrapper<T: SimpleVmbusClientDeviceAsync> {
136 instance_id: Guid,
137 #[inspect(skip)]
138 spawner: Arc<dyn SpawnDriver>,
139 #[inspect(mut)]
140 vmbus_listener: TaskControl<SimpleVmbusClientDeviceTask<T>, SimpleVmbusClientDeviceTaskState>,
141}
142
143impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceWrapper<T> {
144 pub fn new(
146 driver: impl SpawnDriver + Clone,
147 dma_alloc: Arc<dyn DmaClient>,
148 device: T,
149 ) -> Result<Self> {
150 let spawner = Arc::new(driver.clone());
151 Ok(Self {
152 instance_id: device.instance_id(),
153 vmbus_listener: TaskControl::new(SimpleVmbusClientDeviceTask::new(
154 device,
155 spawner.clone(),
156 dma_alloc,
157 )),
158 spawner,
159 })
160 }
161
162 pub fn instance_id(&self) -> Guid {
163 self.instance_id
164 }
165
166 pub fn detach(
167 mut self,
168 driver: impl SpawnDriver,
169 recv_relay: mesh::Receiver<InterceptChannelRequest>,
170 ) -> Result<mesh::OneshotSender<()>> {
171 self.vmbus_listener.insert(
172 &self.spawner,
173 format!("{}", self.instance_id),
174 SimpleVmbusClientDeviceTaskState {
175 offer: None,
176 recv_relay,
177 vtl_pages: None,
178 },
179 );
180 let (driver_send, driver_recv) = mesh::oneshot();
181 driver
182 .spawn(
183 format!("vmbus_relay_device {}", self.instance_id),
184 async move {
185 self.vmbus_listener.start();
186 let _ = driver_recv.await;
187 self.vmbus_listener.stop().await;
188 },
189 )
190 .detach();
191 Ok(driver_send)
192 }
193}
194
195struct RelayDeviceTask<T>(T);
196
197impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<T::Runner> for RelayDeviceTask<T> {
198 async fn run(
199 &mut self,
200 stop: &mut StopTask<'_>,
201 runner: &mut T::Runner,
202 ) -> Result<(), Cancelled> {
203 self.0.run(stop, runner).await
204 }
205}
206
207impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<T::Runner> for RelayDeviceTask<T> {
208 fn inspect_mut(&mut self, req: inspect::Request<'_>, runner: Option<&mut T::Runner>) {
209 self.0.inspect(req, runner)
210 }
211}
212
/// Mutable state of the listener task ([`SimpleVmbusClientDeviceTask`]).
#[derive(InspectMut)]
struct SimpleVmbusClientDeviceTaskState {
    /// The offer currently being handled, if any. Stored by `handle_offer`
    /// and taken by `handle_start` / `handle_revoke`.
    offer: Option<OfferInfo>,
    /// Source of intercept requests from the relay.
    #[inspect(skip)]
    recv_relay: mesh::Receiver<InterceptChannelRequest>,
    /// Memory allocated for the channel rings by `reserve_memory`; cleared
    /// again by `cleanup_device_resources`. Inspected as its list of PFNs.
    #[inspect(hex, with = "|x| x.as_ref().map(|x| inspect::iter_by_index(x.pfns()))")]
    vtl_pages: Option<MemoryBlock>,
}
221
/// Listener task: reacts to intercept requests by opening, closing, saving
/// and restoring the wrapped device's channel.
struct SimpleVmbusClientDeviceTask<T: SimpleVmbusClientDeviceAsync> {
    /// The device and, while a channel is open, its runner.
    device: TaskControl<RelayDeviceTask<T>, T::Runner>,
    /// Saved state pending for the next channel open; filled by
    /// `handle_stop` / `handle_restore` and consumed by `handle_offer`.
    saved_state: Option<T::SavedState>,
    /// Driver used to spawn the device task and to poll the channel event.
    spawner: Arc<dyn SpawnDriver>,
    /// Allocator for the ring buffer memory.
    dma_alloc: Arc<dyn DmaClient>,
}
228
229impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<SimpleVmbusClientDeviceTaskState>
230 for SimpleVmbusClientDeviceTask<T>
231{
232 async fn run(
233 &mut self,
234 stop: &mut StopTask<'_>,
235 state: &mut SimpleVmbusClientDeviceTaskState,
236 ) -> Result<(), Cancelled> {
237 stop.until_stopped(self.process_messages(state)).await
238 }
239}
240
241impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<SimpleVmbusClientDeviceTaskState>
242 for SimpleVmbusClientDeviceTask<T>
243{
244 fn inspect_mut(
245 &mut self,
246 req: inspect::Request<'_>,
247 state: Option<&mut SimpleVmbusClientDeviceTaskState>,
248 ) {
249 req.respond()
250 .merge(state)
251 .field_mut("device", &mut self.device)
252 .field("dma_alloc", &self.dma_alloc);
253 }
254}
255
256impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceTask<T> {
257 pub fn new(device: T, spawner: Arc<dyn SpawnDriver>, dma_alloc: Arc<dyn DmaClient>) -> Self {
258 Self {
259 device: TaskControl::new(RelayDeviceTask(device)),
260 saved_state: None,
261 spawner,
262 dma_alloc,
263 }
264 }
265
266 fn insert_runner(&mut self, state: &SimpleVmbusClientDeviceTaskState, runner: T::Runner) {
267 let offer = state.offer.as_ref().unwrap().offer;
268 self.device.insert(
269 &self.spawner,
270 format!("{}-{}", offer.interface_id, offer.instance_id),
271 runner,
272 );
273 }
274
275 async fn handle_offer(
277 &mut self,
278 offer: OfferInfo,
279 state: &mut SimpleVmbusClientDeviceTaskState,
280 ) -> Result<()> {
281 tracing::info!(?offer, "matching channel offered");
282
283 if offer.offer.is_dedicated != 1 {
284 tracing::warn!(offer = ?offer.offer, "All offers should be dedicated with Win8+ host")
285 }
286
287 if matches!(
288 self.device.task_mut().0.offer(&offer.offer),
289 OfferResponse::Ignore
290 ) {
291 return Ok(());
292 }
293
294 let interrupt_event = pal_event::Event::new();
295 let (memory, ring_gpadl_id) = self
296 .reserve_memory(state, &offer.request_send, 4)
297 .await
298 .context("reserve memory")?;
299 state.offer = Some(offer);
300 let offer = state.offer.as_ref().unwrap();
301 let opened = self
302 .open_channel(&offer.request_send, ring_gpadl_id, &interrupt_event)
303 .await
304 .context("open channel")?;
305 let channel = self
306 .create_vmbus_channel(&memory, &interrupt_event, opened.guest_to_host_signal)
307 .context("create vmbus queue")?;
308
309 let save_restore = self.device.task_mut().0.supports_save_restore();
310 let saved_state = self.saved_state.take();
311 let device_runner = if save_restore.is_some() && saved_state.is_some() {
312 save_restore
313 .unwrap()
314 .restore_open(saved_state.unwrap(), channel)
315 .context("device restore_open callback")?
316 } else {
317 self.device
318 .task_mut()
319 .0
320 .open(offer.offer.subchannel_index, channel)
321 .context("device open callback")?
322 };
323 self.insert_runner(state, device_runner);
324 self.device.start();
325 Ok(())
326 }
327
328 async fn handle_start(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
330 if self.device.is_running() {
331 return;
332 }
333
334 let offer = state.offer.take();
335 if offer.is_none() {
336 return;
337 }
338
339 if let Err(err) = self.handle_offer(offer.unwrap(), state).await {
341 tracing::error!(
342 err = err.as_ref() as &dyn std::error::Error,
343 "Failed to reconnect vmbus channel"
344 );
345 }
346 }
347
348 async fn cleanup_device_resources(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
349 let Some(offer) = state.offer.as_mut() else {
350 return;
351 };
352
353 if state.vtl_pages.is_some() {
354 if let Err(err) = offer
355 .request_send
356 .call(
357 ChannelRequest::TeardownGpadl,
358 GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32),
359 )
360 .await
361 {
362 tracing::error!(
363 error = &err as &dyn std::error::Error,
364 "failed to teardown gpadl"
365 );
366 }
367
368 state.vtl_pages = None;
369 }
370 }
371
372 async fn handle_stop(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
374 if !self.device.stop().await {
375 return;
376 }
377
378 {
388 let offer = state.offer.as_ref().expect("device opened");
389 offer
390 .request_send
391 .call(ChannelRequest::Close, ())
392 .await
393 .ok();
394 }
395 self.cleanup_device_resources(state).await;
399 let runner = self.device.remove();
400 let device = self.device.task_mut();
401 if let Some(save_restore) = device.0.supports_save_restore() {
402 self.saved_state = Some(save_restore.save_open(&runner));
403 }
404 drop(runner);
405 let offer = state.offer.as_ref().expect("device opened");
406 device.0.close(offer.offer.subchannel_index);
407 }
408
409 async fn reserve_memory(
412 &mut self,
413 state: &mut SimpleVmbusClientDeviceTaskState,
414 request_send: &mesh::Sender<ChannelRequest>,
415 page_count: usize,
416 ) -> Result<(MemoryBlock, GpadlId)> {
417 assert!(page_count >= 4);
420
421 let mem = self
422 .dma_alloc
423 .allocate_dma_buffer(page_count * PAGE_SIZE)
424 .context("allocating memory for vmbus rings")?;
425 state.vtl_pages = Some(mem.clone());
426 let buf: Vec<_> = [mem.len() as u64]
427 .iter()
428 .chain(mem.pfns())
429 .copied()
430 .collect();
431
432 let gpadl_id = GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32);
433 request_send
434 .call_failable(
435 ChannelRequest::Gpadl,
436 GpadlRequest {
437 id: gpadl_id,
438 count: 1,
439 buf,
440 },
441 )
442 .await
443 .context("registering gpadl")?;
444 Ok((mem, gpadl_id))
445 }
446
447 async fn open_channel(
449 &self,
450 request_send: &mesh::Sender<ChannelRequest>,
451 ring_gpadl_id: GpadlId,
452 event: &pal_event::Event,
453 ) -> Result<OpenOutput> {
454 let open_request = OpenRequest {
455 open_data: OpenData {
456 target_vp: 0,
457 ring_offset: 2,
458 ring_gpadl_id,
459 event_flag: !0,
460 connection_id: !0,
461 user_data: UserDefinedData::new_zeroed(),
462 },
463 incoming_event: Some(event.clone()),
464 use_vtl2_connection_id: true,
465 };
466
467 request_send
468 .call_failable(ChannelRequest::Open, open_request)
469 .instrument(tracing::info_span!(
470 "opening vmbus channel for intercepted device"
471 ))
472 .await
473 .context("open vmbus channel")
474 }
475
476 fn create_vmbus_channel(
478 &self,
479 mem: &MemoryBlock,
480 host_to_guest_event: &pal_event::Event,
481 guest_to_host_interrupt: Interrupt,
482 ) -> Result<RawAsyncChannel<MemoryBlockRingBuffer>> {
483 let (out_ring_mem, in_ring_mem) = (
484 mem.subblock(0, 2 * PAGE_SIZE),
485 mem.subblock(2 * PAGE_SIZE, 2 * PAGE_SIZE),
486 );
487 let (in_ring, out_ring) = (
488 IncomingRing::new(in_ring_mem.into()).unwrap(),
489 OutgoingRing::new(out_ring_mem.into()).unwrap(),
490 );
491
492 let signal = MemoryBlockChannelSignal {
493 event: Notify::from_event(host_to_guest_event.clone())
494 .pollable(self.spawner.as_ref())
495 .unwrap(),
496 interrupt: guest_to_host_interrupt,
497 };
498 Ok(RawAsyncChannel {
499 in_ring,
500 out_ring,
501 signal: Box::new(signal),
502 })
503 }
504
505 async fn handle_revoke(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
507 let Some(offer) = state.offer.take() else {
508 return;
509 };
510 tracing::info!("device revoked");
511 if self.device.stop().await {
512 drop(self.device.remove());
513 self.device.task_mut().0.close(offer.offer.subchannel_index);
514 }
515 self.cleanup_device_resources(state).await;
516 }
517
518 fn handle_save(&mut self) -> SavedStateBlob {
519 let saved_state = self.saved_state.take();
520 if saved_state.is_some() {
521 let blob = SavedStateBlob::new(saved_state.unwrap());
522 self.handle_restore(&blob);
523 blob
524 } else {
525 SavedStateBlob::new(NoSavedState)
526 }
527 }
528
529 fn handle_restore(&mut self, saved_state_blob: &SavedStateBlob) {
530 self.saved_state = match saved_state_blob.parse() {
531 Ok(saved_state) => Some(saved_state),
532 Err(err) => {
533 tracing::error!(
534 err = &err as &dyn std::error::Error,
535 "Protobuf conversion error saving state"
536 );
537 None
538 }
539 };
540 }
541
    /// Main loop of the listener task: waits for either an intercept request
    /// from the relay or revocation of the current offer, and dispatches it.
    ///
    /// The loop exits when the merged event stream ends.
    pub async fn process_messages(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
        loop {
            #[expect(clippy::large_enum_variant)]
            enum Event {
                Request(InterceptChannelRequest),
                Revoke(()),
            }
            // Resolves when the stored offer's revoke receiver completes (its
            // result is ignored); never resolves while no offer is stored.
            let revoke = pin!(async {
                if let Some(offer) = &mut state.offer {
                    (&mut offer.revoke_recv).await.ok();
                } else {
                    pending().await
                }
            });
            // Take the next event from whichever source is ready first.
            let Some(r) = (
                (&mut state.recv_relay).map(Event::Request),
                futures::stream::once(revoke).map(Event::Revoke),
            )
                .merge()
                .next()
                .await
            else {
                break;
            };
            match r {
                Event::Revoke(()) => {
                    self.handle_revoke(state).await;
                }
                Event::Request(InterceptChannelRequest::Offer(offer)) => {
                    // Offers that arrive while the device is running are
                    // dropped.
                    if !self.device.is_running() {
                        if let Err(err) = self.handle_offer(offer, state).await {
                            tracing::error!(
                                error = err.as_ref() as &dyn std::error::Error,
                                "failed offer handling"
                            );
                        }
                    }
                }
                Event::Request(InterceptChannelRequest::Start) => {
                    self.handle_start(state).await;
                }
                Event::Request(InterceptChannelRequest::Stop(rpc)) => {
                    rpc.handle(async |()| self.handle_stop(state).await).await;
                }
                Event::Request(InterceptChannelRequest::Save(rpc)) => {
                    rpc.handle_sync(|()| self.handle_save());
                }
                Event::Request(InterceptChannelRequest::Restore(saved_state)) => {
                    self.handle_restore(&saved_state);
                }
            }
        }
    }
599}
600
601struct MemoryBlockChannelSignal {
602 event: PolledNotify,
603 interrupt: Interrupt,
604}
605
606impl SignalVmbusChannel for MemoryBlockChannelSignal {
607 fn signal_remote(&self) {
608 self.interrupt.deliver();
609 }
610
611 fn poll_for_signal(
612 &self,
613 cx: &mut std::task::Context<'_>,
614 ) -> std::task::Poll<Result<(), ChannelClosed>> {
615 self.event.poll_wait(cx).map(Ok)
616 }
617}