vmbus_relay_intercept_device/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
//! This crate contains logic that lets VTL2 intercept a vmbus device that the
//! host offers for the VTL0 guest. It requires the vmbus relay to be active:
//! the relay filters the device out of the list of offers provided to the
//! VTL0 guest and forwards any vmbus notifications for that device to the
//! [`SimpleVmbusClientDeviceWrapper`] instance.
9
10#![expect(missing_docs)]
11#![forbid(unsafe_code)]
12
13pub mod ring_buffer;
14
15use crate::ring_buffer::MemoryBlockRingBuffer;
16use anyhow::Context;
17use anyhow::Result;
18use futures::StreamExt;
19use futures_concurrency::stream::Merge;
20use guid::Guid;
21use inspect::InspectMut;
22use mesh::rpc::RpcSend;
23use pal_async::driver::SpawnDriver;
24use std::future::Future;
25use std::future::pending;
26use std::pin::pin;
27use std::sync::Arc;
28use task_control::AsyncRun;
29use task_control::Cancelled;
30use task_control::InspectTaskMut;
31use task_control::StopTask;
32use task_control::TaskControl;
33use tracing::Instrument;
34use user_driver::DmaClient;
35use user_driver::memory::MemoryBlock;
36use vmbus_channel::ChannelClosed;
37use vmbus_channel::RawAsyncChannel;
38use vmbus_channel::SignalVmbusChannel;
39use vmbus_channel::bus::GpadlRequest;
40use vmbus_channel::bus::OpenData;
41use vmbus_client::ChannelRequest;
42use vmbus_client::OfferInfo;
43use vmbus_client::OpenOutput;
44use vmbus_client::OpenRequest;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::UserDefinedData;
47use vmbus_relay::InterceptChannelRequest;
48use vmbus_ring::IncomingRing;
49use vmbus_ring::OutgoingRing;
50use vmbus_ring::PAGE_SIZE;
51use vmcore::interrupt::Interrupt;
52use vmcore::notify::Notify;
53use vmcore::notify::PolledNotify;
54use vmcore::save_restore::NoSavedState;
55use vmcore::save_restore::SavedStateBlob;
56use vmcore::save_restore::SavedStateRoot;
57use zerocopy::FromZeros;
58
/// How a device responds to a channel offer from the host.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OfferResponse {
    /// Leave the offered channel unopened.
    Ignore,
    /// Open the offered channel and run the device on it.
    Open,
}
63
/// A device, implemented in VTL2, that runs on a vmbus channel intercepted
/// from the VTL0 guest.
///
/// When the host offers the matching channel, [`Self::offer`] decides whether
/// to open it. Once the channel is open, [`Self::open`] (or
/// [`SaveRestoreSimpleVmbusClientDevice::restore_open`], when saved state is
/// available) produces the runner for it. [`Self::close`] is called after
/// that runner has been dropped.
pub trait SimpleVmbusClientDevice {
    /// The saved state type.
    type SavedState: SavedStateRoot + Send + Sync;

    /// The type used to run an open channel.
    ///
    /// One runner exists per open channel. It is dropped when the channel is
    /// stopped or revoked.
    type Runner: 'static + Send + Sync;

    /// Inspects the device.
    ///
    /// `runner` is the runner of the open channel, if there is one.
    fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>);

    /// Returns the instance ID of the device to intercept.
    fn instance_id(&self) -> Guid;

    /// Respond to a new channel offer for a device matching instance_id().
    ///
    /// Returning [`OfferResponse::Ignore`] leaves the channel unopened.
    /// Returning [`OfferResponse::Open`] opens it.
    fn offer(&self, offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse;

    /// Open successful for the channel number `channel_idx`.
    ///
    /// `channel_idx` is the subchannel index from the channel offer. The
    /// returned runner is passed to
    /// [`SimpleVmbusClientDeviceAsync::run`]; an error fails the open.
    ///
    /// When the channel is closed, the runner will be dropped.
    fn open(
        &mut self,
        channel_idx: u16,
        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
    ) -> Result<Self::Runner>;

    /// Closes the channel number `channel_idx` after the runner has been
    /// dropped.
    fn close(&mut self, channel_idx: u16);

    /// Returns a trait used to save/restore the device.
    ///
    /// Return `None` if the device does not support save/restore; the
    /// channel is then always opened with [`Self::open`].
    fn supports_save_restore(
        &mut self,
    ) -> Option<
        &mut dyn SaveRestoreSimpleVmbusClientDevice<
            SavedState = Self::SavedState,
            Runner = Self::Runner,
        >,
    >;
}
103
/// The async part of a [`SimpleVmbusClientDevice`]: running the device on an
/// open channel.
pub trait SimpleVmbusClientDeviceAsync: SimpleVmbusClientDevice + 'static + Send + Sync {
    /// Runs an open channel until `stop` is signaled.
    ///
    /// `runner` is the value produced for this channel by
    /// [`SimpleVmbusClientDevice::open`] or
    /// [`SaveRestoreSimpleVmbusClientDevice::restore_open`].
    fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        runner: &mut Self::Runner,
    ) -> impl Send + Future<Output = Result<(), Cancelled>>;
}
112
/// Trait implemented by simple vmbus client devices that support save/restore.
///
/// If you implement this, make sure to return `Some(self)` from
/// [`SimpleVmbusClientDevice::supports_save_restore`].
pub trait SaveRestoreSimpleVmbusClientDevice: SimpleVmbusClientDevice {
    /// Saves the channel.
    ///
    /// Will only be called if the channel is open. It is called while the
    /// channel is being stopped, before `runner` is dropped and before
    /// [`SimpleVmbusClientDevice::close`] is called.
    fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState;

    /// Restores the channel.
    ///
    /// Will only be called if the channel was saved open. It is called
    /// instead of [`SimpleVmbusClientDevice::open`] the next time the channel
    /// is opened while saved state is available. The returned runner is used
    /// exactly like one returned from `open`.
    fn restore_open(
        &mut self,
        state: Self::SavedState,
        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
    ) -> Result<Self::Runner>;
}
132
/// Wraps a [`SimpleVmbusClientDeviceAsync`] device and runs it on the vmbus
/// channel that the vmbus relay intercepts for it.
///
/// Create it with [`Self::new`], then call [`Self::detach`] with the relay's
/// request receiver to start it.
#[derive(InspectMut)]
pub struct SimpleVmbusClientDeviceWrapper<T: SimpleVmbusClientDeviceAsync> {
    /// The wrapped device's instance ID, captured at construction.
    instance_id: Guid,
    /// Driver used to spawn the listener task; also shared with that task.
    #[inspect(skip)]
    spawner: Arc<dyn SpawnDriver>,
    /// The task that processes relay requests for the device.
    #[inspect(mut)]
    vmbus_listener: TaskControl<SimpleVmbusClientDeviceTask<T>, SimpleVmbusClientDeviceTaskState>,
}
141
142impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceWrapper<T> {
143    /// Create a new instance.
144    pub fn new(
145        driver: impl SpawnDriver + Clone,
146        dma_alloc: Arc<dyn DmaClient>,
147        device: T,
148    ) -> Result<Self> {
149        let spawner = Arc::new(driver.clone());
150        Ok(Self {
151            instance_id: device.instance_id(),
152            vmbus_listener: TaskControl::new(SimpleVmbusClientDeviceTask::new(
153                device,
154                spawner.clone(),
155                dma_alloc,
156            )),
157            spawner,
158        })
159    }
160
161    pub fn instance_id(&self) -> Guid {
162        self.instance_id
163    }
164
165    pub fn detach(
166        mut self,
167        driver: impl SpawnDriver,
168        recv_relay: mesh::Receiver<InterceptChannelRequest>,
169    ) -> Result<()> {
170        let (send_disconnected, recv_disconnected) = mesh::oneshot();
171        self.vmbus_listener.insert(
172            &self.spawner,
173            format!("{}", self.instance_id),
174            SimpleVmbusClientDeviceTaskState {
175                offer: None,
176                recv_relay,
177                send_disconnected: Some(send_disconnected),
178                vtl_pages: None,
179            },
180        );
181        driver
182            .spawn(
183                format!("vmbus_relay_device {}", self.instance_id),
184                async move {
185                    self.vmbus_listener.start();
186                    let _ = recv_disconnected.await;
187                    assert!(!self.vmbus_listener.stop().await);
188                    if self.vmbus_listener.state().unwrap().vtl_pages.is_some() {
189                        // The VTL pages were not freed. This can occur if an
190                        // error is hit that drops the vmbus parent tasks. Just
191                        // pend here and let the outer error cause the VM to
192                        // exit.
193                        pending::<()>().await;
194                    }
195                },
196            )
197            .detach();
198        Ok(())
199    }
200}
201
/// Adapts a [`SimpleVmbusClientDeviceAsync`] device to the `task_control`
/// traits, so it can run in a [`TaskControl`] with its runner as the task
/// state.
struct RelayDeviceTask<T>(T);
203
204impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<T::Runner> for RelayDeviceTask<T> {
205    async fn run(
206        &mut self,
207        stop: &mut StopTask<'_>,
208        runner: &mut T::Runner,
209    ) -> Result<(), Cancelled> {
210        self.0.run(stop, runner).await
211    }
212}
213
214impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<T::Runner> for RelayDeviceTask<T> {
215    fn inspect_mut(&mut self, req: inspect::Request<'_>, runner: Option<&mut T::Runner>) {
216        self.0.inspect(req, runner)
217    }
218}
219
/// State of the listener task, inserted when the wrapper is detached.
#[derive(InspectMut)]
struct SimpleVmbusClientDeviceTaskState {
    /// The offer for the channel this task has opened (or tried to open).
    /// Cleared when the channel is revoked.
    offer: Option<OfferInfo>,
    /// Requests forwarded by the vmbus relay for this device.
    #[inspect(skip)]
    recv_relay: mesh::Receiver<InterceptChannelRequest>,
    /// Signaled once request processing ends because `recv_relay` closed.
    #[inspect(skip)]
    send_disconnected: Option<mesh::OneshotSender<()>>,
    /// Ring buffer memory registered with the host as a GPADL. It stays
    /// `Some` if tearing the GPADL down failed, since the host may still be
    /// using the pages.
    #[inspect(hex, with = "|x| x.as_ref().map(|x| inspect::iter_by_index(x.pfns()))")]
    vtl_pages: Option<MemoryBlock>,
}
230
/// Listener task that opens, stops, restarts, and saves the intercepted
/// channel in response to relay requests.
struct SimpleVmbusClientDeviceTask<T: SimpleVmbusClientDeviceAsync> {
    /// The task running the device; its state is the open channel's runner.
    device: TaskControl<RelayDeviceTask<T>, T::Runner>,
    /// Device state saved when the channel was stopped, or restored from a
    /// blob. It is consumed by the next open.
    saved_state: Option<T::SavedState>,
    /// Driver used to spawn the device task and poll the channel event.
    spawner: Arc<dyn SpawnDriver>,
    /// Allocator for the ring buffer memory shared with the host.
    dma_alloc: Arc<dyn DmaClient>,
}
237
238impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<SimpleVmbusClientDeviceTaskState>
239    for SimpleVmbusClientDeviceTask<T>
240{
241    async fn run(
242        &mut self,
243        stop: &mut StopTask<'_>,
244        state: &mut SimpleVmbusClientDeviceTaskState,
245    ) -> Result<(), Cancelled> {
246        stop.until_stopped(self.process_messages(state)).await?;
247        state
248            .send_disconnected
249            .take()
250            .expect("task should not be restarted")
251            .send(());
252        Ok(())
253    }
254}
255
256impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<SimpleVmbusClientDeviceTaskState>
257    for SimpleVmbusClientDeviceTask<T>
258{
259    fn inspect_mut(
260        &mut self,
261        req: inspect::Request<'_>,
262        state: Option<&mut SimpleVmbusClientDeviceTaskState>,
263    ) {
264        req.respond()
265            .merge(state)
266            .field_mut("device", &mut self.device)
267            .field("dma_alloc", &self.dma_alloc);
268    }
269}
270
271impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceTask<T> {
272    pub fn new(device: T, spawner: Arc<dyn SpawnDriver>, dma_alloc: Arc<dyn DmaClient>) -> Self {
273        Self {
274            device: TaskControl::new(RelayDeviceTask(device)),
275            saved_state: None,
276            spawner,
277            dma_alloc,
278        }
279    }
280
281    fn insert_runner(&mut self, state: &SimpleVmbusClientDeviceTaskState, runner: T::Runner) {
282        let offer = state.offer.as_ref().unwrap().offer;
283        self.device.insert(
284            &self.spawner,
285            format!("{}-{}", offer.interface_id, offer.instance_id),
286            runner,
287        );
288    }
289
290    /// Configures channel.
291    async fn handle_offer(
292        &mut self,
293        offer: OfferInfo,
294        state: &mut SimpleVmbusClientDeviceTaskState,
295    ) -> Result<()> {
296        tracing::info!(?offer, "matching channel offered");
297
298        if offer.offer.is_dedicated != 1 {
299            tracing::warn!(offer = ?offer.offer, "All offers should be dedicated with Win8+ host")
300        }
301
302        if matches!(
303            self.device.task_mut().0.offer(&offer.offer),
304            OfferResponse::Ignore
305        ) {
306            return Ok(());
307        }
308
309        let interrupt_event = pal_event::Event::new();
310        let (memory, ring_gpadl_id) = self
311            .reserve_memory(state, &offer.request_send, 4)
312            .await
313            .context("reserve memory")?;
314        let guest_to_host_interrupt = offer.guest_to_host_interrupt.clone();
315        state.offer = Some(offer);
316        let offer = state.offer.as_ref().unwrap();
317        self.open_channel(&offer.request_send, ring_gpadl_id, &interrupt_event)
318            .await
319            .context("open channel")?;
320        let channel = self
321            .create_vmbus_channel(&memory, &interrupt_event, guest_to_host_interrupt)
322            .context("create vmbus queue")?;
323
324        let save_restore = self.device.task_mut().0.supports_save_restore();
325        let saved_state = self.saved_state.take();
326        let device_runner = if let Some(save_restore) = save_restore
327            && let Some(saved_state) = saved_state
328        {
329            save_restore
330                .restore_open(saved_state, channel)
331                .context("device restore_open callback")?
332        } else {
333            self.device
334                .task_mut()
335                .0
336                .open(offer.offer.subchannel_index, channel)
337                .context("device open callback")?
338        };
339        self.insert_runner(state, device_runner);
340        self.device.start();
341        Ok(())
342    }
343
344    /// Start channel after it has been stopped.
345    async fn handle_start(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
346        if self.device.is_running() {
347            return;
348        }
349
350        let offer = state.offer.take();
351        if offer.is_none() {
352            return;
353        }
354
355        // If there is a previous valid offer, open the channel again.
356        if let Err(err) = self.handle_offer(offer.unwrap(), state).await {
357            tracing::error!(
358                err = err.as_ref() as &dyn std::error::Error,
359                "Failed to reconnect vmbus channel"
360            );
361        }
362    }
363
364    async fn cleanup_device_resources(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
365        let Some(offer) = state.offer.as_mut() else {
366            return;
367        };
368
369        if let Some(vtl_pages) = &state.vtl_pages {
370            match offer
371                .request_send
372                .call(
373                    ChannelRequest::TeardownGpadl,
374                    GpadlId(vtl_pages.pfns()[1] as u32),
375                )
376                .await
377            {
378                Ok(()) => {
379                    state.vtl_pages = None;
380                }
381                Err(err) => {
382                    // If the ring buffer pages are still in use by the host, which
383                    // has to be assumed, the memory pages cannot be used again as
384                    // they have been marked as visible to VTL0.
385                    tracing::error!(
386                        error = &err as &dyn std::error::Error,
387                        "Failed to teardown gpadl -- leaking memory."
388                    );
389                }
390            }
391        }
392    }
393
394    /// Stop channel
395    async fn handle_stop(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
396        if !self.device.stop().await {
397            return;
398        }
399
400        // Close the channel on every stop. Overlay devices cannot be saved /
401        // restored because the physical pages used for the ring buffer, et al.
402        // would need to be reserved at boot, otherwise the host may end up
403        // scribbling on random memory as it continues updating a ring buffer it
404        // assumes it has ownership of.
405        //
406        // TODO: We could support save restore, if we had a pool of memory that
407        // supports that. This should be possible once the page_pool_alloc is
408        // available everywhere.
409        {
410            let offer = state.offer.as_ref().expect("device opened");
411            offer
412                .request_send
413                .call(ChannelRequest::Close, ())
414                .await
415                .ok();
416        }
417        // N.B. This will wait for a TeardownGpadl response which can be used
418        // as a signal that the channel is closed and the ring buffers are no
419        // longer in use.
420        self.cleanup_device_resources(state).await;
421        let runner = self.device.remove();
422        let device = self.device.task_mut();
423        if let Some(save_restore) = device.0.supports_save_restore() {
424            self.saved_state = Some(save_restore.save_open(&runner));
425        }
426        drop(runner);
427        let offer = state.offer.as_ref().expect("device opened");
428        device.0.close(offer.offer.subchannel_index);
429    }
430
431    /// Allocates memory to be shared with the host and registers it with a
432    /// GPADL ID.
433    async fn reserve_memory(
434        &mut self,
435        state: &mut SimpleVmbusClientDeviceTaskState,
436        request_send: &mesh::Sender<ChannelRequest>,
437        page_count: usize,
438    ) -> Result<(MemoryBlock, GpadlId)> {
439        // Incoming and outgoing rings require a minimum of two pages apiece:
440        // one for the control bytes and at least one for the ring.
441        assert!(page_count >= 4);
442
443        let mem = self
444            .dma_alloc
445            .allocate_dma_buffer(page_count * PAGE_SIZE)
446            .context("allocating memory for vmbus rings")?;
447        state.vtl_pages = Some(mem.clone());
448        let buf: Vec<_> = [mem.len() as u64]
449            .iter()
450            .chain(mem.pfns())
451            .copied()
452            .collect();
453
454        let gpadl_id = GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32);
455        request_send
456            .call_failable(
457                ChannelRequest::Gpadl,
458                GpadlRequest {
459                    id: gpadl_id,
460                    count: 1,
461                    buf,
462                },
463            )
464            .await
465            .context("registering gpadl")?;
466        Ok((mem, gpadl_id))
467    }
468
469    /// Open the channel offered by the host.
470    async fn open_channel(
471        &self,
472        request_send: &mesh::Sender<ChannelRequest>,
473        ring_gpadl_id: GpadlId,
474        event: &pal_event::Event,
475    ) -> Result<OpenOutput> {
476        let open_request = OpenRequest {
477            open_data: OpenData {
478                target_vp: Some(0),
479                ring_offset: 2,
480                ring_gpadl_id,
481                event_flag: !0,
482                connection_id: !0,
483                user_data: UserDefinedData::new_zeroed(),
484            },
485            incoming_event: Some(event.clone()),
486            use_vtl2_connection_id: true,
487        };
488
489        request_send
490            .call_failable(ChannelRequest::Open, open_request)
491            .instrument(tracing::info_span!(
492                "opening vmbus channel for intercepted device"
493            ))
494            .await
495            .context("open vmbus channel")
496    }
497
498    /// Create a raw vmbus channel.
499    fn create_vmbus_channel(
500        &self,
501        mem: &MemoryBlock,
502        host_to_guest_event: &pal_event::Event,
503        guest_to_host_interrupt: Interrupt,
504    ) -> Result<RawAsyncChannel<MemoryBlockRingBuffer>> {
505        let (out_ring_mem, in_ring_mem) = (
506            mem.subblock(0, 2 * PAGE_SIZE),
507            mem.subblock(2 * PAGE_SIZE, 2 * PAGE_SIZE),
508        );
509        let (in_ring, out_ring) = (
510            IncomingRing::new(in_ring_mem.into()).unwrap(),
511            OutgoingRing::new(out_ring_mem.into()).unwrap(),
512        );
513
514        let signal = MemoryBlockChannelSignal {
515            event: Notify::from_event(host_to_guest_event.clone())
516                .pollable(self.spawner.as_ref())
517                .unwrap(),
518            interrupt: guest_to_host_interrupt,
519        };
520        Ok(RawAsyncChannel {
521            in_ring,
522            out_ring,
523            signal: Box::new(signal),
524        })
525    }
526
527    /// Responds to the channel being revoked by the host.
528    async fn handle_revoke(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
529        let Some(offer) = state.offer.as_ref() else {
530            return;
531        };
532        tracing::info!("device revoked");
533        if self.device.stop().await {
534            drop(self.device.remove());
535            self.device.task_mut().0.close(offer.offer.subchannel_index);
536        }
537        self.cleanup_device_resources(state).await;
538        drop(state.offer.take());
539    }
540
541    fn handle_save(&mut self) -> SavedStateBlob {
542        let saved_state = self.saved_state.take();
543        if let Some(saved_state) = saved_state {
544            let blob = SavedStateBlob::new(saved_state);
545            self.handle_restore(&blob);
546            blob
547        } else {
548            SavedStateBlob::new(NoSavedState)
549        }
550    }
551
552    fn handle_restore(&mut self, saved_state_blob: &SavedStateBlob) {
553        self.saved_state = match saved_state_blob.parse() {
554            Ok(saved_state) => Some(saved_state),
555            Err(err) => {
556                tracing::error!(
557                    err = &err as &dyn std::error::Error,
558                    "Protobuf conversion error saving state"
559                );
560                None
561            }
562        };
563    }
564
565    /// Handle vmbus messages from the host and control messages from the
566    /// device wrapper.
567    pub async fn process_messages(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
568        loop {
569            #[expect(clippy::large_enum_variant)]
570            enum Event {
571                Request(InterceptChannelRequest),
572                Revoke,
573            }
574            let r = if let Some(offer) = &mut state.offer {
575                (
576                    (&mut state.recv_relay).map(Event::Request),
577                    futures::stream::once(&mut offer.revoke_recv).map(|_| Event::Revoke),
578                )
579                    .merge()
580                    .next()
581                    .await
582            } else {
583                let mut recv_relay = pin!(&mut state.recv_relay);
584                recv_relay.next().await.map(Event::Request)
585            };
586            let Some(r) = r else {
587                break;
588            };
589            match r {
590                Event::Revoke => {
591                    self.handle_revoke(state).await;
592                }
593                Event::Request(InterceptChannelRequest::Offer(offer)) => {
594                    // Any extraneous offer notifications (e.g. from a request offers
595                    // query) are ignored.
596                    if !self.device.is_running() {
597                        if let Err(err) = self.handle_offer(offer, state).await {
598                            tracing::error!(
599                                error = err.as_ref() as &dyn std::error::Error,
600                                "failed offer handling"
601                            );
602                        }
603                    }
604                }
605                Event::Request(InterceptChannelRequest::Start) => {
606                    self.handle_start(state).await;
607                }
608                Event::Request(InterceptChannelRequest::Stop(rpc)) => {
609                    rpc.handle(async |()| self.handle_stop(state).await).await;
610                }
611                Event::Request(InterceptChannelRequest::Save(rpc)) => {
612                    rpc.handle_sync(|()| self.handle_save());
613                }
614                Event::Request(InterceptChannelRequest::Restore(saved_state)) => {
615                    self.handle_restore(&saved_state);
616                }
617            }
618        }
619    }
620}
621
/// Signaling for the intercepted channel's [`RawAsyncChannel`].
struct MemoryBlockChannelSignal {
    /// Polled for signals from the host (built from the host-to-guest event).
    event: PolledNotify,
    /// Delivered to signal the host (the offer's guest-to-host interrupt).
    interrupt: Interrupt,
}
626
627impl SignalVmbusChannel for MemoryBlockChannelSignal {
628    fn signal_remote(&self) {
629        self.interrupt.deliver();
630    }
631
632    fn poll_for_signal(
633        &self,
634        cx: &mut std::task::Context<'_>,
635    ) -> std::task::Poll<Result<(), ChannelClosed>> {
636        self.event.poll_wait(cx).map(Ok)
637    }
638}