1#![expect(missing_docs)]
11#![forbid(unsafe_code)]
12
13pub mod ring_buffer;
14
15use crate::ring_buffer::MemoryBlockRingBuffer;
16use anyhow::Context;
17use anyhow::Result;
18use futures::StreamExt;
19use futures_concurrency::stream::Merge;
20use guid::Guid;
21use inspect::InspectMut;
22use mesh::rpc::RpcSend;
23use pal_async::driver::SpawnDriver;
24use std::future::Future;
25use std::future::pending;
26use std::pin::pin;
27use std::sync::Arc;
28use task_control::AsyncRun;
29use task_control::Cancelled;
30use task_control::InspectTaskMut;
31use task_control::StopTask;
32use task_control::TaskControl;
33use tracing::Instrument;
34use user_driver::DmaClient;
35use user_driver::memory::MemoryBlock;
36use vmbus_channel::ChannelClosed;
37use vmbus_channel::RawAsyncChannel;
38use vmbus_channel::SignalVmbusChannel;
39use vmbus_channel::bus::GpadlRequest;
40use vmbus_channel::bus::OpenData;
41use vmbus_client::ChannelRequest;
42use vmbus_client::OfferInfo;
43use vmbus_client::OpenOutput;
44use vmbus_client::OpenRequest;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::UserDefinedData;
47use vmbus_relay::InterceptChannelRequest;
48use vmbus_ring::IncomingRing;
49use vmbus_ring::OutgoingRing;
50use vmbus_ring::PAGE_SIZE;
51use vmcore::interrupt::Interrupt;
52use vmcore::notify::Notify;
53use vmcore::notify::PolledNotify;
54use vmcore::save_restore::NoSavedState;
55use vmcore::save_restore::SavedStateBlob;
56use vmcore::save_restore::SavedStateRoot;
57use zerocopy::FromZeros;
58
/// A device's answer to a channel offer, returned from
/// [`SimpleVmbusClientDevice::offer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OfferResponse {
    /// Leave the offered channel alone; nothing is allocated or opened.
    Ignore,
    /// Open the offered channel and start the device on it.
    Open,
}
63
/// A device layered on top of a vmbus channel that is intercepted by the
/// vmbus relay and driven by [`SimpleVmbusClientDeviceWrapper`].
pub trait SimpleVmbusClientDevice {
    /// State produced by [`SaveRestoreSimpleVmbusClientDevice::save_open`] and
    /// later consumed by [`SaveRestoreSimpleVmbusClientDevice::restore_open`].
    type SavedState: SavedStateRoot + Send + Sync;

    /// Per-open state returned by `open`/`restore_open`. While the channel is
    /// open it is handed to `run`, `inspect`, and `save_open`.
    type Runner: 'static + Send + Sync;

    /// Inspects the device. `runner` is the current per-open state, if one is
    /// present.
    fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>);

    /// Returns the instance ID of the device; the wrapper uses it to identify
    /// and name its tasks.
    fn instance_id(&self) -> Guid;

    /// Decides whether to open an offered channel. Returning
    /// [`OfferResponse::Ignore`] makes the offer handler return without
    /// allocating ring memory or opening the channel.
    fn offer(&self, offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse;

    /// Starts the device on a freshly opened channel.
    ///
    /// `channel_idx` is the offer's subchannel index. This is used instead of
    /// `restore_open` when there is no saved state to restore or the device
    /// does not support save/restore. An error aborts offer handling.
    fn open(
        &mut self,
        channel_idx: u16,
        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
    ) -> Result<Self::Runner>;

    /// Closes the device for subchannel `channel_idx`. This is called after the
    /// runner returned by `open`/`restore_open` has been dropped.
    fn close(&mut self, channel_idx: u16);

    /// Returns this device's save/restore implementation, or `None` if it does
    /// not support save/restore.
    fn supports_save_restore(
        &mut self,
    ) -> Option<
        &mut dyn SaveRestoreSimpleVmbusClientDevice<
            SavedState = Self::SavedState,
            Runner = Self::Runner,
        >,
    >;
}
103
/// The asynchronous processing loop of a [`SimpleVmbusClientDevice`].
pub trait SimpleVmbusClientDeviceAsync: SimpleVmbusClientDevice + 'static + Send + Sync {
    /// Runs the device against `runner` as the body of a `TaskControl` task.
    ///
    /// Implementations are expected to return `Err(Cancelled)` when `stop`
    /// fires, e.g. by awaiting through `StopTask::until_stopped` as this
    /// crate's own tasks do.
    fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        runner: &mut Self::Runner,
    ) -> impl Send + Future<Output = Result<(), Cancelled>>;
}
112
/// Optional save/restore support for a [`SimpleVmbusClientDevice`], exposed
/// through [`SimpleVmbusClientDevice::supports_save_restore`].
pub trait SaveRestoreSimpleVmbusClientDevice: SimpleVmbusClientDevice {
    /// Saves the state of an open channel.
    ///
    /// This is called while stopping the device. At that point the channel
    /// close request and GPADL teardown have already been issued, and `runner`
    /// has not yet been dropped.
    fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState;

    /// Restarts the device from `state` on a newly opened channel. It is used
    /// in place of [`SimpleVmbusClientDevice::open`] when saved state is
    /// available.
    fn restore_open(
        &mut self,
        state: Self::SavedState,
        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
    ) -> Result<Self::Runner>;
}
132
/// Hosts a [`SimpleVmbusClientDeviceAsync`] device on a vmbus channel that is
/// intercepted by the vmbus relay.
#[derive(InspectMut)]
pub struct SimpleVmbusClientDeviceWrapper<T: SimpleVmbusClientDeviceAsync> {
    /// Instance ID reported by the device at construction.
    instance_id: Guid,
    /// Spawner for the listener task; also handed to the listener for the
    /// device task.
    #[inspect(skip)]
    spawner: Arc<dyn SpawnDriver>,
    /// Task that processes relay requests (offers, start/stop, save/restore)
    /// and channel revocations.
    #[inspect(mut)]
    vmbus_listener: TaskControl<SimpleVmbusClientDeviceTask<T>, SimpleVmbusClientDeviceTaskState>,
}
141
142impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceWrapper<T> {
143 pub fn new(
145 driver: impl SpawnDriver + Clone,
146 dma_alloc: Arc<dyn DmaClient>,
147 device: T,
148 ) -> Result<Self> {
149 let spawner = Arc::new(driver.clone());
150 Ok(Self {
151 instance_id: device.instance_id(),
152 vmbus_listener: TaskControl::new(SimpleVmbusClientDeviceTask::new(
153 device,
154 spawner.clone(),
155 dma_alloc,
156 )),
157 spawner,
158 })
159 }
160
161 pub fn instance_id(&self) -> Guid {
162 self.instance_id
163 }
164
165 pub fn detach(
166 mut self,
167 driver: impl SpawnDriver,
168 recv_relay: mesh::Receiver<InterceptChannelRequest>,
169 ) -> Result<()> {
170 let (send_disconnected, recv_disconnected) = mesh::oneshot();
171 self.vmbus_listener.insert(
172 &self.spawner,
173 format!("{}", self.instance_id),
174 SimpleVmbusClientDeviceTaskState {
175 offer: None,
176 recv_relay,
177 send_disconnected: Some(send_disconnected),
178 vtl_pages: None,
179 },
180 );
181 driver
182 .spawn(
183 format!("vmbus_relay_device {}", self.instance_id),
184 async move {
185 self.vmbus_listener.start();
186 let _ = recv_disconnected.await;
187 assert!(!self.vmbus_listener.stop().await);
188 if self.vmbus_listener.state().unwrap().vtl_pages.is_some() {
189 pending::<()>().await;
194 }
195 },
196 )
197 .detach();
198 Ok(())
199 }
200}
201
/// Adapts a [`SimpleVmbusClientDeviceAsync`] device to the `task_control`
/// traits, so that it can run inside a `TaskControl` with its runner as the
/// task state.
struct RelayDeviceTask<T>(T);
203
204impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<T::Runner> for RelayDeviceTask<T> {
205 async fn run(
206 &mut self,
207 stop: &mut StopTask<'_>,
208 runner: &mut T::Runner,
209 ) -> Result<(), Cancelled> {
210 self.0.run(stop, runner).await
211 }
212}
213
214impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<T::Runner> for RelayDeviceTask<T> {
215 fn inspect_mut(&mut self, req: inspect::Request<'_>, runner: Option<&mut T::Runner>) {
216 self.0.inspect(req, runner)
217 }
218}
219
/// Mutable state of the listener task, owned by its `TaskControl`.
#[derive(InspectMut)]
struct SimpleVmbusClientDeviceTaskState {
    /// The current channel offer. It is stored once ring memory is reserved,
    /// taken when restarting, and cleared on revoke.
    offer: Option<OfferInfo>,
    /// Requests from the vmbus relay; the listener finishes when this closes.
    #[inspect(skip)]
    recv_relay: mesh::Receiver<InterceptChannelRequest>,
    /// Signaled (and consumed) once the listener finishes processing requests.
    #[inspect(skip)]
    send_disconnected: Option<mesh::OneshotSender<()>>,
    /// Ring-buffer memory. It is set when allocated, and reset to `None` only
    /// after its GPADL teardown succeeds, so `Some` means the host may still
    /// reference it. Inspected as its PFNs in hex.
    #[inspect(hex, with = "|x| x.as_ref().map(|x| inspect::iter_by_index(x.pfns()))")]
    vtl_pages: Option<MemoryBlock>,
}
230
/// Listener task: handles relay requests and manages the device's channel
/// lifetime.
struct SimpleVmbusClientDeviceTask<T: SimpleVmbusClientDeviceAsync> {
    /// The device's own task; its state is the device runner while open.
    device: TaskControl<RelayDeviceTask<T>, T::Runner>,
    /// State saved at stop (or received via restore), consumed by the next
    /// open.
    saved_state: Option<T::SavedState>,
    /// Spawner for the device task and for polling the host-to-guest event.
    spawner: Arc<dyn SpawnDriver>,
    /// Allocator for ring-buffer memory.
    dma_alloc: Arc<dyn DmaClient>,
}
237
238impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<SimpleVmbusClientDeviceTaskState>
239 for SimpleVmbusClientDeviceTask<T>
240{
241 async fn run(
242 &mut self,
243 stop: &mut StopTask<'_>,
244 state: &mut SimpleVmbusClientDeviceTaskState,
245 ) -> Result<(), Cancelled> {
246 stop.until_stopped(self.process_messages(state)).await?;
247 state
248 .send_disconnected
249 .take()
250 .expect("task should not be restarted")
251 .send(());
252 Ok(())
253 }
254}
255
256impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<SimpleVmbusClientDeviceTaskState>
257 for SimpleVmbusClientDeviceTask<T>
258{
259 fn inspect_mut(
260 &mut self,
261 req: inspect::Request<'_>,
262 state: Option<&mut SimpleVmbusClientDeviceTaskState>,
263 ) {
264 req.respond()
265 .merge(state)
266 .field_mut("device", &mut self.device)
267 .field("dma_alloc", &self.dma_alloc);
268 }
269}
270
271impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceTask<T> {
272 pub fn new(device: T, spawner: Arc<dyn SpawnDriver>, dma_alloc: Arc<dyn DmaClient>) -> Self {
273 Self {
274 device: TaskControl::new(RelayDeviceTask(device)),
275 saved_state: None,
276 spawner,
277 dma_alloc,
278 }
279 }
280
281 fn insert_runner(&mut self, state: &SimpleVmbusClientDeviceTaskState, runner: T::Runner) {
282 let offer = state.offer.as_ref().unwrap().offer;
283 self.device.insert(
284 &self.spawner,
285 format!("{}-{}", offer.interface_id, offer.instance_id),
286 runner,
287 );
288 }
289
290 async fn handle_offer(
292 &mut self,
293 offer: OfferInfo,
294 state: &mut SimpleVmbusClientDeviceTaskState,
295 ) -> Result<()> {
296 tracing::info!(?offer, "matching channel offered");
297
298 if offer.offer.is_dedicated != 1 {
299 tracing::warn!(offer = ?offer.offer, "All offers should be dedicated with Win8+ host")
300 }
301
302 if matches!(
303 self.device.task_mut().0.offer(&offer.offer),
304 OfferResponse::Ignore
305 ) {
306 return Ok(());
307 }
308
309 let interrupt_event = pal_event::Event::new();
310 let (memory, ring_gpadl_id) = self
311 .reserve_memory(state, &offer.request_send, 4)
312 .await
313 .context("reserve memory")?;
314 let guest_to_host_interrupt = offer.guest_to_host_interrupt.clone();
315 state.offer = Some(offer);
316 let offer = state.offer.as_ref().unwrap();
317 self.open_channel(&offer.request_send, ring_gpadl_id, &interrupt_event)
318 .await
319 .context("open channel")?;
320 let channel = self
321 .create_vmbus_channel(&memory, &interrupt_event, guest_to_host_interrupt)
322 .context("create vmbus queue")?;
323
324 let save_restore = self.device.task_mut().0.supports_save_restore();
325 let saved_state = self.saved_state.take();
326 let device_runner = if save_restore.is_some() && saved_state.is_some() {
327 save_restore
328 .unwrap()
329 .restore_open(saved_state.unwrap(), channel)
330 .context("device restore_open callback")?
331 } else {
332 self.device
333 .task_mut()
334 .0
335 .open(offer.offer.subchannel_index, channel)
336 .context("device open callback")?
337 };
338 self.insert_runner(state, device_runner);
339 self.device.start();
340 Ok(())
341 }
342
343 async fn handle_start(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
345 if self.device.is_running() {
346 return;
347 }
348
349 let offer = state.offer.take();
350 if offer.is_none() {
351 return;
352 }
353
354 if let Err(err) = self.handle_offer(offer.unwrap(), state).await {
356 tracing::error!(
357 err = err.as_ref() as &dyn std::error::Error,
358 "Failed to reconnect vmbus channel"
359 );
360 }
361 }
362
363 async fn cleanup_device_resources(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
364 let Some(offer) = state.offer.as_mut() else {
365 return;
366 };
367
368 if state.vtl_pages.is_some() {
369 match offer
370 .request_send
371 .call(
372 ChannelRequest::TeardownGpadl,
373 GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32),
374 )
375 .await
376 {
377 Ok(()) => {
378 state.vtl_pages = None;
379 }
380 Err(err) => {
381 tracing::error!(
385 error = &err as &dyn std::error::Error,
386 "Failed to teardown gpadl -- leaking memory."
387 );
388 }
389 }
390 }
391 }
392
393 async fn handle_stop(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
395 if !self.device.stop().await {
396 return;
397 }
398
399 {
409 let offer = state.offer.as_ref().expect("device opened");
410 offer
411 .request_send
412 .call(ChannelRequest::Close, ())
413 .await
414 .ok();
415 }
416 self.cleanup_device_resources(state).await;
420 let runner = self.device.remove();
421 let device = self.device.task_mut();
422 if let Some(save_restore) = device.0.supports_save_restore() {
423 self.saved_state = Some(save_restore.save_open(&runner));
424 }
425 drop(runner);
426 let offer = state.offer.as_ref().expect("device opened");
427 device.0.close(offer.offer.subchannel_index);
428 }
429
430 async fn reserve_memory(
433 &mut self,
434 state: &mut SimpleVmbusClientDeviceTaskState,
435 request_send: &mesh::Sender<ChannelRequest>,
436 page_count: usize,
437 ) -> Result<(MemoryBlock, GpadlId)> {
438 assert!(page_count >= 4);
441
442 let mem = self
443 .dma_alloc
444 .allocate_dma_buffer(page_count * PAGE_SIZE)
445 .context("allocating memory for vmbus rings")?;
446 state.vtl_pages = Some(mem.clone());
447 let buf: Vec<_> = [mem.len() as u64]
448 .iter()
449 .chain(mem.pfns())
450 .copied()
451 .collect();
452
453 let gpadl_id = GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32);
454 request_send
455 .call_failable(
456 ChannelRequest::Gpadl,
457 GpadlRequest {
458 id: gpadl_id,
459 count: 1,
460 buf,
461 },
462 )
463 .await
464 .context("registering gpadl")?;
465 Ok((mem, gpadl_id))
466 }
467
468 async fn open_channel(
470 &self,
471 request_send: &mesh::Sender<ChannelRequest>,
472 ring_gpadl_id: GpadlId,
473 event: &pal_event::Event,
474 ) -> Result<OpenOutput> {
475 let open_request = OpenRequest {
476 open_data: OpenData {
477 target_vp: 0,
478 ring_offset: 2,
479 ring_gpadl_id,
480 event_flag: !0,
481 connection_id: !0,
482 user_data: UserDefinedData::new_zeroed(),
483 },
484 incoming_event: Some(event.clone()),
485 use_vtl2_connection_id: true,
486 };
487
488 request_send
489 .call_failable(ChannelRequest::Open, open_request)
490 .instrument(tracing::info_span!(
491 "opening vmbus channel for intercepted device"
492 ))
493 .await
494 .context("open vmbus channel")
495 }
496
497 fn create_vmbus_channel(
499 &self,
500 mem: &MemoryBlock,
501 host_to_guest_event: &pal_event::Event,
502 guest_to_host_interrupt: Interrupt,
503 ) -> Result<RawAsyncChannel<MemoryBlockRingBuffer>> {
504 let (out_ring_mem, in_ring_mem) = (
505 mem.subblock(0, 2 * PAGE_SIZE),
506 mem.subblock(2 * PAGE_SIZE, 2 * PAGE_SIZE),
507 );
508 let (in_ring, out_ring) = (
509 IncomingRing::new(in_ring_mem.into()).unwrap(),
510 OutgoingRing::new(out_ring_mem.into()).unwrap(),
511 );
512
513 let signal = MemoryBlockChannelSignal {
514 event: Notify::from_event(host_to_guest_event.clone())
515 .pollable(self.spawner.as_ref())
516 .unwrap(),
517 interrupt: guest_to_host_interrupt,
518 };
519 Ok(RawAsyncChannel {
520 in_ring,
521 out_ring,
522 signal: Box::new(signal),
523 })
524 }
525
526 async fn handle_revoke(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
528 let Some(offer) = state.offer.as_ref() else {
529 return;
530 };
531 tracing::info!("device revoked");
532 if self.device.stop().await {
533 drop(self.device.remove());
534 self.device.task_mut().0.close(offer.offer.subchannel_index);
535 }
536 self.cleanup_device_resources(state).await;
537 drop(state.offer.take());
538 }
539
540 fn handle_save(&mut self) -> SavedStateBlob {
541 let saved_state = self.saved_state.take();
542 if let Some(saved_state) = saved_state {
543 let blob = SavedStateBlob::new(saved_state);
544 self.handle_restore(&blob);
545 blob
546 } else {
547 SavedStateBlob::new(NoSavedState)
548 }
549 }
550
551 fn handle_restore(&mut self, saved_state_blob: &SavedStateBlob) {
552 self.saved_state = match saved_state_blob.parse() {
553 Ok(saved_state) => Some(saved_state),
554 Err(err) => {
555 tracing::error!(
556 err = &err as &dyn std::error::Error,
557 "Protobuf conversion error saving state"
558 );
559 None
560 }
561 };
562 }
563
564 pub async fn process_messages(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
567 loop {
568 #[expect(clippy::large_enum_variant)]
569 enum Event {
570 Request(InterceptChannelRequest),
571 Revoke,
572 }
573 let r = if let Some(offer) = &mut state.offer {
574 (
575 (&mut state.recv_relay).map(Event::Request),
576 futures::stream::once(&mut offer.revoke_recv).map(|_| Event::Revoke),
577 )
578 .merge()
579 .next()
580 .await
581 } else {
582 let mut recv_relay = pin!(&mut state.recv_relay);
583 recv_relay.next().await.map(Event::Request)
584 };
585 let Some(r) = r else {
586 break;
587 };
588 match r {
589 Event::Revoke => {
590 self.handle_revoke(state).await;
591 }
592 Event::Request(InterceptChannelRequest::Offer(offer)) => {
593 if !self.device.is_running() {
596 if let Err(err) = self.handle_offer(offer, state).await {
597 tracing::error!(
598 error = err.as_ref() as &dyn std::error::Error,
599 "failed offer handling"
600 );
601 }
602 }
603 }
604 Event::Request(InterceptChannelRequest::Start) => {
605 self.handle_start(state).await;
606 }
607 Event::Request(InterceptChannelRequest::Stop(rpc)) => {
608 rpc.handle(async |()| self.handle_stop(state).await).await;
609 }
610 Event::Request(InterceptChannelRequest::Save(rpc)) => {
611 rpc.handle_sync(|()| self.handle_save());
612 }
613 Event::Request(InterceptChannelRequest::Restore(saved_state)) => {
614 self.handle_restore(&saved_state);
615 }
616 }
617 }
618 }
619}
620
/// Signaling for channels created by this crate. It waits on a polled
/// host-to-guest event and notifies the host by delivering an interrupt.
struct MemoryBlockChannelSignal {
    /// Host-to-guest notification: the event passed as the channel's
    /// `incoming_event` when it was opened.
    event: PolledNotify,
    /// Guest-to-host interrupt taken from the channel offer.
    interrupt: Interrupt,
}
625
626impl SignalVmbusChannel for MemoryBlockChannelSignal {
627 fn signal_remote(&self) {
628 self.interrupt.deliver();
629 }
630
631 fn poll_for_signal(
632 &self,
633 cx: &mut std::task::Context<'_>,
634 ) -> std::task::Poll<Result<(), ChannelClosed>> {
635 self.event.poll_wait(cx).map(Ok)
636 }
637}