1#![expect(missing_docs)]
11#![forbid(unsafe_code)]
12
13pub mod ring_buffer;
14
15use crate::ring_buffer::MemoryBlockRingBuffer;
16use anyhow::Context;
17use anyhow::Result;
18use futures::StreamExt;
19use futures_concurrency::stream::Merge;
20use guid::Guid;
21use inspect::InspectMut;
22use mesh::rpc::RpcSend;
23use pal_async::driver::SpawnDriver;
24use std::future::Future;
25use std::future::pending;
26use std::pin::pin;
27use std::sync::Arc;
28use task_control::AsyncRun;
29use task_control::Cancelled;
30use task_control::InspectTaskMut;
31use task_control::StopTask;
32use task_control::TaskControl;
33use tracing::Instrument;
34use user_driver::DmaClient;
35use user_driver::memory::MemoryBlock;
36use vmbus_channel::ChannelClosed;
37use vmbus_channel::RawAsyncChannel;
38use vmbus_channel::SignalVmbusChannel;
39use vmbus_channel::bus::GpadlRequest;
40use vmbus_channel::bus::OpenData;
41use vmbus_client::ChannelRequest;
42use vmbus_client::OfferInfo;
43use vmbus_client::OpenOutput;
44use vmbus_client::OpenRequest;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::UserDefinedData;
47use vmbus_relay::InterceptChannelRequest;
48use vmbus_ring::IncomingRing;
49use vmbus_ring::OutgoingRing;
50use vmbus_ring::PAGE_SIZE;
51use vmcore::interrupt::Interrupt;
52use vmcore::notify::Notify;
53use vmcore::notify::PolledNotify;
54use vmcore::save_restore::NoSavedState;
55use vmcore::save_restore::SavedStateBlob;
56use vmcore::save_restore::SavedStateRoot;
57use zerocopy::FromZeros;
58
/// The decision a device makes about a channel offer it has been shown.
///
/// Returned from [`SimpleVmbusClientDevice::offer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OfferResponse {
    /// Leave the offered channel unopened.
    Ignore,
    /// Go ahead and open the offered channel.
    Open,
}
63
/// A vmbus client device that is driven through an intercepted channel.
///
/// The surrounding task calls [`offer`](Self::offer) when a channel is
/// offered, [`open`](Self::open) (or
/// [`SaveRestoreSimpleVmbusClientDevice::restore_open`]) to obtain a runner
/// for the channel, and [`close`](Self::close) once that runner is gone.
pub trait SimpleVmbusClientDevice {
    /// The state produced by
    /// [`SaveRestoreSimpleVmbusClientDevice::save_open`] and consumed by
    /// [`SaveRestoreSimpleVmbusClientDevice::restore_open`].
    type SavedState: SavedStateRoot + Send + Sync;

    /// The per-channel value returned from `open`/`restore_open` and passed
    /// back to `run`, `inspect` and `save_open`.
    type Runner: 'static + Send + Sync;

    /// Responds to an inspect request.
    ///
    /// `runner` is `None` when the caller has no runner to offer
    /// (presumably because no channel is open -- confirm against
    /// `TaskControl`).
    fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>);

    /// Returns the instance ID of the device.
    fn instance_id(&self) -> Guid;

    /// Decides what to do with a channel offer.
    ///
    /// Returning [`OfferResponse::Ignore`] leaves the channel unopened.
    fn offer(&self, offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse;

    /// Opens the channel and returns the runner that will service it.
    ///
    /// `channel_idx` is the subchannel index taken from the offer, and
    /// `channel` is the ring-buffer channel built for it.
    fn open(
        &mut self,
        channel_idx: u16,
        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
    ) -> Result<Self::Runner>;

    /// Closes the channel with subchannel index `channel_idx`.
    ///
    /// Called after the channel's runner has been removed from its task.
    fn close(&mut self, channel_idx: u16);

    /// Returns the save/restore interface of the device, if it has one.
    ///
    /// When this returns `None`, the device is never asked to save an open
    /// channel and is always opened via [`open`](Self::open).
    fn supports_save_restore(
        &mut self,
    ) -> Option<
        &mut dyn SaveRestoreSimpleVmbusClientDevice<
            SavedState = Self::SavedState,
            Runner = Self::Runner,
        >,
    >;
}
103
/// A [`SimpleVmbusClientDevice`] whose open channel is serviced by an async
/// task.
pub trait SimpleVmbusClientDeviceAsync: SimpleVmbusClientDevice + 'static + Send + Sync {
    /// Services the channel owned by `runner`.
    ///
    /// `stop` is the task's stop handle; the returned future resolves to
    /// `Err(Cancelled)` when the run was cancelled through it (presumably via
    /// `StopTask::until_stopped` -- confirm against `task_control`).
    fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        runner: &mut Self::Runner,
    ) -> impl Send + Future<Output = Result<(), Cancelled>>;
}
112
/// Save/restore support for a [`SimpleVmbusClientDevice`].
///
/// A device exposes this by returning `Some` from
/// [`SimpleVmbusClientDevice::supports_save_restore`].
pub trait SaveRestoreSimpleVmbusClientDevice: SimpleVmbusClientDevice {
    /// Captures the state of an open channel from its runner.
    ///
    /// Called while the channel is being stopped, just before the runner is
    /// dropped.
    fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState;

    /// Re-creates a runner for `channel` from previously saved `state`.
    ///
    /// Used in place of [`SimpleVmbusClientDevice::open`] when saved state is
    /// pending at the time the channel is opened.
    fn restore_open(
        &mut self,
        state: Self::SavedState,
        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
    ) -> Result<Self::Runner>;
}
132
/// Owns a [`SimpleVmbusClientDeviceAsync`] device together with the task that
/// listens for relay requests on its behalf.
#[derive(InspectMut)]
pub struct SimpleVmbusClientDeviceWrapper<T: SimpleVmbusClientDeviceAsync> {
    /// Instance ID read from the device at construction.
    instance_id: Guid,
    /// Driver used to spawn the listener task.
    #[inspect(skip)]
    spawner: Arc<dyn SpawnDriver>,
    /// The listener task and, once `detach` has run, its state.
    #[inspect(mut)]
    vmbus_listener: TaskControl<SimpleVmbusClientDeviceTask<T>, SimpleVmbusClientDeviceTaskState>,
}
141
142impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceWrapper<T> {
143 pub fn new(
145 driver: impl SpawnDriver + Clone,
146 dma_alloc: Arc<dyn DmaClient>,
147 device: T,
148 ) -> Result<Self> {
149 let spawner = Arc::new(driver.clone());
150 Ok(Self {
151 instance_id: device.instance_id(),
152 vmbus_listener: TaskControl::new(SimpleVmbusClientDeviceTask::new(
153 device,
154 spawner.clone(),
155 dma_alloc,
156 )),
157 spawner,
158 })
159 }
160
161 pub fn instance_id(&self) -> Guid {
162 self.instance_id
163 }
164
165 pub fn detach(
166 mut self,
167 driver: impl SpawnDriver,
168 recv_relay: mesh::Receiver<InterceptChannelRequest>,
169 ) -> Result<()> {
170 let (send_disconnected, recv_disconnected) = mesh::oneshot();
171 self.vmbus_listener.insert(
172 &self.spawner,
173 format!("{}", self.instance_id),
174 SimpleVmbusClientDeviceTaskState {
175 offer: None,
176 recv_relay,
177 send_disconnected: Some(send_disconnected),
178 vtl_pages: None,
179 },
180 );
181 driver
182 .spawn(
183 format!("vmbus_relay_device {}", self.instance_id),
184 async move {
185 self.vmbus_listener.start();
186 let _ = recv_disconnected.await;
187 assert!(!self.vmbus_listener.stop().await);
188 if self.vmbus_listener.state().unwrap().vtl_pages.is_some() {
189 pending::<()>().await;
194 }
195 },
196 )
197 .detach();
198 Ok(())
199 }
200}
201
/// Newtype around the device so that the `task_control` traits can be
/// implemented for it in this crate.
struct RelayDeviceTask<T>(T);
203
204impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<T::Runner> for RelayDeviceTask<T> {
205 async fn run(
206 &mut self,
207 stop: &mut StopTask<'_>,
208 runner: &mut T::Runner,
209 ) -> Result<(), Cancelled> {
210 self.0.run(stop, runner).await
211 }
212}
213
214impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<T::Runner> for RelayDeviceTask<T> {
215 fn inspect_mut(&mut self, req: inspect::Request<'_>, runner: Option<&mut T::Runner>) {
216 self.0.inspect(req, runner)
217 }
218}
219
/// Mutable state of the listener task.
#[derive(InspectMut)]
struct SimpleVmbusClientDeviceTaskState {
    /// The offer currently being serviced, if any. Cleared on revoke.
    offer: Option<OfferInfo>,
    /// Incoming requests from the relay.
    #[inspect(skip)]
    recv_relay: mesh::Receiver<InterceptChannelRequest>,
    /// Fired once when the message loop exits; `None` after it has fired.
    #[inspect(skip)]
    send_disconnected: Option<mesh::OneshotSender<()>>,
    /// Ring buffer memory reserved for the channel. Set when memory is
    /// reserved and cleared only after a successful GPADL teardown.
    #[inspect(hex, with = "|x| x.as_ref().map(|x| inspect::iter_by_index(x.pfns()))")]
    vtl_pages: Option<MemoryBlock>,
}
230
/// The listener task: reacts to relay requests by opening, stopping, saving
/// and restoring the device's channel.
struct SimpleVmbusClientDeviceTask<T: SimpleVmbusClientDeviceAsync> {
    /// The device and, while a channel is open, its runner.
    device: TaskControl<RelayDeviceTask<T>, T::Runner>,
    /// State saved when the channel was last stopped, or handed over by a
    /// restore request; consumed the next time the channel is opened.
    saved_state: Option<T::SavedState>,
    /// Driver used to spawn the device task and to poll the channel event.
    spawner: Arc<dyn SpawnDriver>,
    /// Allocator for the ring buffer memory.
    dma_alloc: Arc<dyn DmaClient>,
}
237
238impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<SimpleVmbusClientDeviceTaskState>
239 for SimpleVmbusClientDeviceTask<T>
240{
241 async fn run(
242 &mut self,
243 stop: &mut StopTask<'_>,
244 state: &mut SimpleVmbusClientDeviceTaskState,
245 ) -> Result<(), Cancelled> {
246 stop.until_stopped(self.process_messages(state)).await?;
247 state
248 .send_disconnected
249 .take()
250 .expect("task should not be restarted")
251 .send(());
252 Ok(())
253 }
254}
255
256impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<SimpleVmbusClientDeviceTaskState>
257 for SimpleVmbusClientDeviceTask<T>
258{
259 fn inspect_mut(
260 &mut self,
261 req: inspect::Request<'_>,
262 state: Option<&mut SimpleVmbusClientDeviceTaskState>,
263 ) {
264 req.respond()
265 .merge(state)
266 .field_mut("device", &mut self.device)
267 .field("dma_alloc", &self.dma_alloc);
268 }
269}
270
271impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceTask<T> {
272 pub fn new(device: T, spawner: Arc<dyn SpawnDriver>, dma_alloc: Arc<dyn DmaClient>) -> Self {
273 Self {
274 device: TaskControl::new(RelayDeviceTask(device)),
275 saved_state: None,
276 spawner,
277 dma_alloc,
278 }
279 }
280
281 fn insert_runner(&mut self, state: &SimpleVmbusClientDeviceTaskState, runner: T::Runner) {
282 let offer = state.offer.as_ref().unwrap().offer;
283 self.device.insert(
284 &self.spawner,
285 format!("{}-{}", offer.interface_id, offer.instance_id),
286 runner,
287 );
288 }
289
290 async fn handle_offer(
292 &mut self,
293 offer: OfferInfo,
294 state: &mut SimpleVmbusClientDeviceTaskState,
295 ) -> Result<()> {
296 tracing::info!(?offer, "matching channel offered");
297
298 if offer.offer.is_dedicated != 1 {
299 tracing::warn!(offer = ?offer.offer, "All offers should be dedicated with Win8+ host")
300 }
301
302 if matches!(
303 self.device.task_mut().0.offer(&offer.offer),
304 OfferResponse::Ignore
305 ) {
306 return Ok(());
307 }
308
309 let interrupt_event = pal_event::Event::new();
310 let (memory, ring_gpadl_id) = self
311 .reserve_memory(state, &offer.request_send, 4)
312 .await
313 .context("reserve memory")?;
314 let guest_to_host_interrupt = offer.guest_to_host_interrupt.clone();
315 state.offer = Some(offer);
316 let offer = state.offer.as_ref().unwrap();
317 self.open_channel(&offer.request_send, ring_gpadl_id, &interrupt_event)
318 .await
319 .context("open channel")?;
320 let channel = self
321 .create_vmbus_channel(&memory, &interrupt_event, guest_to_host_interrupt)
322 .context("create vmbus queue")?;
323
324 let save_restore = self.device.task_mut().0.supports_save_restore();
325 let saved_state = self.saved_state.take();
326 let device_runner = if let Some(save_restore) = save_restore
327 && let Some(saved_state) = saved_state
328 {
329 save_restore
330 .restore_open(saved_state, channel)
331 .context("device restore_open callback")?
332 } else {
333 self.device
334 .task_mut()
335 .0
336 .open(offer.offer.subchannel_index, channel)
337 .context("device open callback")?
338 };
339 self.insert_runner(state, device_runner);
340 self.device.start();
341 Ok(())
342 }
343
344 async fn handle_start(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
346 if self.device.is_running() {
347 return;
348 }
349
350 let offer = state.offer.take();
351 if offer.is_none() {
352 return;
353 }
354
355 if let Err(err) = self.handle_offer(offer.unwrap(), state).await {
357 tracing::error!(
358 err = err.as_ref() as &dyn std::error::Error,
359 "Failed to reconnect vmbus channel"
360 );
361 }
362 }
363
364 async fn cleanup_device_resources(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
365 let Some(offer) = state.offer.as_mut() else {
366 return;
367 };
368
369 if let Some(vtl_pages) = &state.vtl_pages {
370 match offer
371 .request_send
372 .call(
373 ChannelRequest::TeardownGpadl,
374 GpadlId(vtl_pages.pfns()[1] as u32),
375 )
376 .await
377 {
378 Ok(()) => {
379 state.vtl_pages = None;
380 }
381 Err(err) => {
382 tracing::error!(
386 error = &err as &dyn std::error::Error,
387 "Failed to teardown gpadl -- leaking memory."
388 );
389 }
390 }
391 }
392 }
393
394 async fn handle_stop(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
396 if !self.device.stop().await {
397 return;
398 }
399
400 {
410 let offer = state.offer.as_ref().expect("device opened");
411 offer
412 .request_send
413 .call(ChannelRequest::Close, ())
414 .await
415 .ok();
416 }
417 self.cleanup_device_resources(state).await;
421 let runner = self.device.remove();
422 let device = self.device.task_mut();
423 if let Some(save_restore) = device.0.supports_save_restore() {
424 self.saved_state = Some(save_restore.save_open(&runner));
425 }
426 drop(runner);
427 let offer = state.offer.as_ref().expect("device opened");
428 device.0.close(offer.offer.subchannel_index);
429 }
430
431 async fn reserve_memory(
434 &mut self,
435 state: &mut SimpleVmbusClientDeviceTaskState,
436 request_send: &mesh::Sender<ChannelRequest>,
437 page_count: usize,
438 ) -> Result<(MemoryBlock, GpadlId)> {
439 assert!(page_count >= 4);
442
443 let mem = self
444 .dma_alloc
445 .allocate_dma_buffer(page_count * PAGE_SIZE)
446 .context("allocating memory for vmbus rings")?;
447 state.vtl_pages = Some(mem.clone());
448 let buf: Vec<_> = [mem.len() as u64]
449 .iter()
450 .chain(mem.pfns())
451 .copied()
452 .collect();
453
454 let gpadl_id = GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32);
455 request_send
456 .call_failable(
457 ChannelRequest::Gpadl,
458 GpadlRequest {
459 id: gpadl_id,
460 count: 1,
461 buf,
462 },
463 )
464 .await
465 .context("registering gpadl")?;
466 Ok((mem, gpadl_id))
467 }
468
469 async fn open_channel(
471 &self,
472 request_send: &mesh::Sender<ChannelRequest>,
473 ring_gpadl_id: GpadlId,
474 event: &pal_event::Event,
475 ) -> Result<OpenOutput> {
476 let open_request = OpenRequest {
477 open_data: OpenData {
478 target_vp: Some(0),
479 ring_offset: 2,
480 ring_gpadl_id,
481 event_flag: !0,
482 connection_id: !0,
483 user_data: UserDefinedData::new_zeroed(),
484 },
485 incoming_event: Some(event.clone()),
486 use_vtl2_connection_id: true,
487 };
488
489 request_send
490 .call_failable(ChannelRequest::Open, open_request)
491 .instrument(tracing::info_span!(
492 "opening vmbus channel for intercepted device"
493 ))
494 .await
495 .context("open vmbus channel")
496 }
497
498 fn create_vmbus_channel(
500 &self,
501 mem: &MemoryBlock,
502 host_to_guest_event: &pal_event::Event,
503 guest_to_host_interrupt: Interrupt,
504 ) -> Result<RawAsyncChannel<MemoryBlockRingBuffer>> {
505 let (out_ring_mem, in_ring_mem) = (
506 mem.subblock(0, 2 * PAGE_SIZE),
507 mem.subblock(2 * PAGE_SIZE, 2 * PAGE_SIZE),
508 );
509 let (in_ring, out_ring) = (
510 IncomingRing::new(in_ring_mem.into()).unwrap(),
511 OutgoingRing::new(out_ring_mem.into()).unwrap(),
512 );
513
514 let signal = MemoryBlockChannelSignal {
515 event: Notify::from_event(host_to_guest_event.clone())
516 .pollable(self.spawner.as_ref())
517 .unwrap(),
518 interrupt: guest_to_host_interrupt,
519 };
520 Ok(RawAsyncChannel {
521 in_ring,
522 out_ring,
523 signal: Box::new(signal),
524 })
525 }
526
527 async fn handle_revoke(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
529 let Some(offer) = state.offer.as_ref() else {
530 return;
531 };
532 tracing::info!("device revoked");
533 if self.device.stop().await {
534 drop(self.device.remove());
535 self.device.task_mut().0.close(offer.offer.subchannel_index);
536 }
537 self.cleanup_device_resources(state).await;
538 drop(state.offer.take());
539 }
540
541 fn handle_save(&mut self) -> SavedStateBlob {
542 let saved_state = self.saved_state.take();
543 if let Some(saved_state) = saved_state {
544 let blob = SavedStateBlob::new(saved_state);
545 self.handle_restore(&blob);
546 blob
547 } else {
548 SavedStateBlob::new(NoSavedState)
549 }
550 }
551
552 fn handle_restore(&mut self, saved_state_blob: &SavedStateBlob) {
553 self.saved_state = match saved_state_blob.parse() {
554 Ok(saved_state) => Some(saved_state),
555 Err(err) => {
556 tracing::error!(
557 err = &err as &dyn std::error::Error,
558 "Protobuf conversion error saving state"
559 );
560 None
561 }
562 };
563 }
564
565 pub async fn process_messages(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
568 loop {
569 #[expect(clippy::large_enum_variant)]
570 enum Event {
571 Request(InterceptChannelRequest),
572 Revoke,
573 }
574 let r = if let Some(offer) = &mut state.offer {
575 (
576 (&mut state.recv_relay).map(Event::Request),
577 futures::stream::once(&mut offer.revoke_recv).map(|_| Event::Revoke),
578 )
579 .merge()
580 .next()
581 .await
582 } else {
583 let mut recv_relay = pin!(&mut state.recv_relay);
584 recv_relay.next().await.map(Event::Request)
585 };
586 let Some(r) = r else {
587 break;
588 };
589 match r {
590 Event::Revoke => {
591 self.handle_revoke(state).await;
592 }
593 Event::Request(InterceptChannelRequest::Offer(offer)) => {
594 if !self.device.is_running() {
597 if let Err(err) = self.handle_offer(offer, state).await {
598 tracing::error!(
599 error = err.as_ref() as &dyn std::error::Error,
600 "failed offer handling"
601 );
602 }
603 }
604 }
605 Event::Request(InterceptChannelRequest::Start) => {
606 self.handle_start(state).await;
607 }
608 Event::Request(InterceptChannelRequest::Stop(rpc)) => {
609 rpc.handle(async |()| self.handle_stop(state).await).await;
610 }
611 Event::Request(InterceptChannelRequest::Save(rpc)) => {
612 rpc.handle_sync(|()| self.handle_save());
613 }
614 Event::Request(InterceptChannelRequest::Restore(saved_state)) => {
615 self.handle_restore(&saved_state);
616 }
617 }
618 }
619 }
620}
621
/// Signalling for a channel whose rings live in a `MemoryBlock`.
struct MemoryBlockChannelSignal {
    /// Polled to learn that the remote side has signalled us.
    event: PolledNotify,
    /// Delivered to signal the remote side.
    interrupt: Interrupt,
}
626
627impl SignalVmbusChannel for MemoryBlockChannelSignal {
628 fn signal_remote(&self) {
629 self.interrupt.deliver();
630 }
631
632 fn poll_for_signal(
633 &self,
634 cx: &mut std::task::Context<'_>,
635 ) -> std::task::Poll<Result<(), ChannelClosed>> {
636 self.event.poll_wait(cx).map(Ok)
637 }
638}