vmbus_relay_intercept_device/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
//! This module contains logic used by VTL2 to intercept a vmbus device that
//! the host provides for a VTL0 guest. This requires the vmbus relay to be
//! active: the relay filters the device out of the list provided to the VTL0
//! guest and forwards any vmbus notifications for that device to the
//! SimpleVmbusClientDeviceWrapper instance.
9
10#![expect(missing_docs)]
11#![forbid(unsafe_code)]
12
13pub mod ring_buffer;
14
15use crate::ring_buffer::MemoryBlockRingBuffer;
16use anyhow::Context;
17use anyhow::Result;
18use futures::StreamExt;
19use futures_concurrency::stream::Merge;
20use guid::Guid;
21use inspect::InspectMut;
22use mesh::rpc::RpcSend;
23use pal_async::driver::SpawnDriver;
24use std::future::Future;
25use std::future::pending;
26use std::pin::pin;
27use std::sync::Arc;
28use task_control::AsyncRun;
29use task_control::Cancelled;
30use task_control::InspectTaskMut;
31use task_control::StopTask;
32use task_control::TaskControl;
33use tracing::Instrument;
34use user_driver::DmaClient;
35use user_driver::memory::MemoryBlock;
36use vmbus_channel::ChannelClosed;
37use vmbus_channel::RawAsyncChannel;
38use vmbus_channel::SignalVmbusChannel;
39use vmbus_channel::bus::GpadlRequest;
40use vmbus_channel::bus::OpenData;
41use vmbus_client::ChannelRequest;
42use vmbus_client::OfferInfo;
43use vmbus_client::OpenOutput;
44use vmbus_client::OpenRequest;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::UserDefinedData;
47use vmbus_relay::InterceptChannelRequest;
48use vmbus_ring::IncomingRing;
49use vmbus_ring::OutgoingRing;
50use vmbus_ring::PAGE_SIZE;
51use vmcore::interrupt::Interrupt;
52use vmcore::notify::Notify;
53use vmcore::notify::PolledNotify;
54use vmcore::save_restore::NoSavedState;
55use vmcore::save_restore::SavedStateBlob;
56use vmcore::save_restore::SavedStateRoot;
57use zerocopy::FromZeros;
58
/// How a [`SimpleVmbusClientDevice`] responds to a channel offer for its
/// instance ID.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OfferResponse {
    /// Leave the offered channel unopened.
    Ignore,
    /// Open the offered channel and hand it to the device.
    Open,
}
63
64pub trait SimpleVmbusClientDevice {
65    /// The saved state type.
66    type SavedState: SavedStateRoot + Send + Sync;
67
68    /// The type used to run an open channel.
69    type Runner: 'static + Send + Sync;
70
71    /// Inspects a channel.
72    fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>);
73
74    /// Returns the instance ID of the matching device.
75    fn instance_id(&self) -> Guid;
76
77    /// Respond to a new channel offer for a device matching instance_id().
78    fn offer(&self, offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse;
79
80    /// Open successful for the channel number `channel_idx`.
81    ///
82    /// When the channel is closed, the runner will be dropped.
83    fn open(
84        &mut self,
85        channel_idx: u16,
86        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
87    ) -> Result<Self::Runner>;
88
89    /// Closes the channel number `channel_idx` after the runner has been
90    /// dropped.
91    fn close(&mut self, channel_idx: u16);
92
93    /// Returns a trait used to save/restore the device.
94    fn supports_save_restore(
95        &mut self,
96    ) -> Option<
97        &mut dyn SaveRestoreSimpleVmbusClientDevice<
98            SavedState = Self::SavedState,
99            Runner = Self::Runner,
100        >,
101    >;
102}
103
104pub trait SimpleVmbusClientDeviceAsync: SimpleVmbusClientDevice + 'static + Send + Sync {
105    /// Runs an open channel until `stop` is signaled.
106    fn run(
107        &mut self,
108        stop: &mut StopTask<'_>,
109        runner: &mut Self::Runner,
110    ) -> impl Send + Future<Output = Result<(), Cancelled>>;
111}
112
113/// Trait implemented by simple vmbus client devices that support save/restore.
114///
115/// If you implement this, make sure to return `Some(self)` from
116/// [`SimpleVmbusClientDevice::supports_save_restore`].
117pub trait SaveRestoreSimpleVmbusClientDevice: SimpleVmbusClientDevice {
118    /// Saves the channel.
119    ///
120    /// Will only be called if the channel is open.
121    fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState;
122
123    /// Restores the channel.
124    ///
125    /// Will only be called if the channel was saved open.
126    fn restore_open(
127        &mut self,
128        state: Self::SavedState,
129        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
130    ) -> Result<Self::Runner>;
131}
132
133#[derive(InspectMut)]
134pub struct SimpleVmbusClientDeviceWrapper<T: SimpleVmbusClientDeviceAsync> {
135    instance_id: Guid,
136    #[inspect(skip)]
137    spawner: Arc<dyn SpawnDriver>,
138    #[inspect(mut)]
139    vmbus_listener: TaskControl<SimpleVmbusClientDeviceTask<T>, SimpleVmbusClientDeviceTaskState>,
140}
141
142impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceWrapper<T> {
143    /// Create a new instance.
144    pub fn new(
145        driver: impl SpawnDriver + Clone,
146        dma_alloc: Arc<dyn DmaClient>,
147        device: T,
148    ) -> Result<Self> {
149        let spawner = Arc::new(driver.clone());
150        Ok(Self {
151            instance_id: device.instance_id(),
152            vmbus_listener: TaskControl::new(SimpleVmbusClientDeviceTask::new(
153                device,
154                spawner.clone(),
155                dma_alloc,
156            )),
157            spawner,
158        })
159    }
160
161    pub fn instance_id(&self) -> Guid {
162        self.instance_id
163    }
164
165    pub fn detach(
166        mut self,
167        driver: impl SpawnDriver,
168        recv_relay: mesh::Receiver<InterceptChannelRequest>,
169    ) -> Result<()> {
170        let (send_disconnected, recv_disconnected) = mesh::oneshot();
171        self.vmbus_listener.insert(
172            &self.spawner,
173            format!("{}", self.instance_id),
174            SimpleVmbusClientDeviceTaskState {
175                offer: None,
176                recv_relay,
177                send_disconnected: Some(send_disconnected),
178                vtl_pages: None,
179            },
180        );
181        driver
182            .spawn(
183                format!("vmbus_relay_device {}", self.instance_id),
184                async move {
185                    self.vmbus_listener.start();
186                    let _ = recv_disconnected.await;
187                    assert!(!self.vmbus_listener.stop().await);
188                    if self.vmbus_listener.state().unwrap().vtl_pages.is_some() {
189                        // The VTL pages were not freed. This can occur if an
190                        // error is hit that drops the vmbus parent tasks. Just
191                        // pend here and let the outer error cause the VM to
192                        // exit.
193                        pending::<()>().await;
194                    }
195                },
196            )
197            .detach();
198        Ok(())
199    }
200}
201
202struct RelayDeviceTask<T>(T);
203
204impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<T::Runner> for RelayDeviceTask<T> {
205    async fn run(
206        &mut self,
207        stop: &mut StopTask<'_>,
208        runner: &mut T::Runner,
209    ) -> Result<(), Cancelled> {
210        self.0.run(stop, runner).await
211    }
212}
213
214impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<T::Runner> for RelayDeviceTask<T> {
215    fn inspect_mut(&mut self, req: inspect::Request<'_>, runner: Option<&mut T::Runner>) {
216        self.0.inspect(req, runner)
217    }
218}
219
220#[derive(InspectMut)]
221struct SimpleVmbusClientDeviceTaskState {
222    offer: Option<OfferInfo>,
223    #[inspect(skip)]
224    recv_relay: mesh::Receiver<InterceptChannelRequest>,
225    #[inspect(skip)]
226    send_disconnected: Option<mesh::OneshotSender<()>>,
227    #[inspect(hex, with = "|x| x.as_ref().map(|x| inspect::iter_by_index(x.pfns()))")]
228    vtl_pages: Option<MemoryBlock>,
229}
230
231struct SimpleVmbusClientDeviceTask<T: SimpleVmbusClientDeviceAsync> {
232    device: TaskControl<RelayDeviceTask<T>, T::Runner>,
233    saved_state: Option<T::SavedState>,
234    spawner: Arc<dyn SpawnDriver>,
235    dma_alloc: Arc<dyn DmaClient>,
236}
237
238impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<SimpleVmbusClientDeviceTaskState>
239    for SimpleVmbusClientDeviceTask<T>
240{
241    async fn run(
242        &mut self,
243        stop: &mut StopTask<'_>,
244        state: &mut SimpleVmbusClientDeviceTaskState,
245    ) -> Result<(), Cancelled> {
246        stop.until_stopped(self.process_messages(state)).await?;
247        state
248            .send_disconnected
249            .take()
250            .expect("task should not be restarted")
251            .send(());
252        Ok(())
253    }
254}
255
256impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<SimpleVmbusClientDeviceTaskState>
257    for SimpleVmbusClientDeviceTask<T>
258{
259    fn inspect_mut(
260        &mut self,
261        req: inspect::Request<'_>,
262        state: Option<&mut SimpleVmbusClientDeviceTaskState>,
263    ) {
264        req.respond()
265            .merge(state)
266            .field_mut("device", &mut self.device)
267            .field("dma_alloc", &self.dma_alloc);
268    }
269}
270
271impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceTask<T> {
272    pub fn new(device: T, spawner: Arc<dyn SpawnDriver>, dma_alloc: Arc<dyn DmaClient>) -> Self {
273        Self {
274            device: TaskControl::new(RelayDeviceTask(device)),
275            saved_state: None,
276            spawner,
277            dma_alloc,
278        }
279    }
280
281    fn insert_runner(&mut self, state: &SimpleVmbusClientDeviceTaskState, runner: T::Runner) {
282        let offer = state.offer.as_ref().unwrap().offer;
283        self.device.insert(
284            &self.spawner,
285            format!("{}-{}", offer.interface_id, offer.instance_id),
286            runner,
287        );
288    }
289
290    /// Configures channel.
291    async fn handle_offer(
292        &mut self,
293        offer: OfferInfo,
294        state: &mut SimpleVmbusClientDeviceTaskState,
295    ) -> Result<()> {
296        tracing::info!(?offer, "matching channel offered");
297
298        if offer.offer.is_dedicated != 1 {
299            tracing::warn!(offer = ?offer.offer, "All offers should be dedicated with Win8+ host")
300        }
301
302        if matches!(
303            self.device.task_mut().0.offer(&offer.offer),
304            OfferResponse::Ignore
305        ) {
306            return Ok(());
307        }
308
309        let interrupt_event = pal_event::Event::new();
310        let (memory, ring_gpadl_id) = self
311            .reserve_memory(state, &offer.request_send, 4)
312            .await
313            .context("reserve memory")?;
314        let guest_to_host_interrupt = offer.guest_to_host_interrupt.clone();
315        state.offer = Some(offer);
316        let offer = state.offer.as_ref().unwrap();
317        self.open_channel(&offer.request_send, ring_gpadl_id, &interrupt_event)
318            .await
319            .context("open channel")?;
320        let channel = self
321            .create_vmbus_channel(&memory, &interrupt_event, guest_to_host_interrupt)
322            .context("create vmbus queue")?;
323
324        let save_restore = self.device.task_mut().0.supports_save_restore();
325        let saved_state = self.saved_state.take();
326        let device_runner = if save_restore.is_some() && saved_state.is_some() {
327            save_restore
328                .unwrap()
329                .restore_open(saved_state.unwrap(), channel)
330                .context("device restore_open callback")?
331        } else {
332            self.device
333                .task_mut()
334                .0
335                .open(offer.offer.subchannel_index, channel)
336                .context("device open callback")?
337        };
338        self.insert_runner(state, device_runner);
339        self.device.start();
340        Ok(())
341    }
342
343    /// Start channel after it has been stopped.
344    async fn handle_start(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
345        if self.device.is_running() {
346            return;
347        }
348
349        let offer = state.offer.take();
350        if offer.is_none() {
351            return;
352        }
353
354        // If there is a previous valid offer, open the channel again.
355        if let Err(err) = self.handle_offer(offer.unwrap(), state).await {
356            tracing::error!(
357                err = err.as_ref() as &dyn std::error::Error,
358                "Failed to reconnect vmbus channel"
359            );
360        }
361    }
362
363    async fn cleanup_device_resources(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
364        let Some(offer) = state.offer.as_mut() else {
365            return;
366        };
367
368        if state.vtl_pages.is_some() {
369            match offer
370                .request_send
371                .call(
372                    ChannelRequest::TeardownGpadl,
373                    GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32),
374                )
375                .await
376            {
377                Ok(()) => {
378                    state.vtl_pages = None;
379                }
380                Err(err) => {
381                    // If the ring buffer pages are still in use by the host, which
382                    // has to be assumed, the memory pages cannot be used again as
383                    // they have been marked as visible to VTL0.
384                    tracing::error!(
385                        error = &err as &dyn std::error::Error,
386                        "Failed to teardown gpadl -- leaking memory."
387                    );
388                }
389            }
390        }
391    }
392
393    /// Stop channel
394    async fn handle_stop(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
395        if !self.device.stop().await {
396            return;
397        }
398
399        // Close the channel on every stop. Overlay devices cannot be saved /
400        // restored because the physical pages used for the ring buffer, et al.
401        // would need to be reserved at boot, otherwise the host may end up
402        // scribbling on random memory as it continues updating a ring buffer it
403        // assumes it has ownership of.
404        //
405        // TODO: We could support save restore, if we had a pool of memory that
406        // supports that. This should be possible once the page_pool_alloc is
407        // available everywhere.
408        {
409            let offer = state.offer.as_ref().expect("device opened");
410            offer
411                .request_send
412                .call(ChannelRequest::Close, ())
413                .await
414                .ok();
415        }
416        // N.B. This will wait for a TeardownGpadl response which can be used
417        // as a signal that the channel is closed and the ring buffers are no
418        // longer in use.
419        self.cleanup_device_resources(state).await;
420        let runner = self.device.remove();
421        let device = self.device.task_mut();
422        if let Some(save_restore) = device.0.supports_save_restore() {
423            self.saved_state = Some(save_restore.save_open(&runner));
424        }
425        drop(runner);
426        let offer = state.offer.as_ref().expect("device opened");
427        device.0.close(offer.offer.subchannel_index);
428    }
429
430    /// Allocates memory to be shared with the host and registers it with a
431    /// GPADL ID.
432    async fn reserve_memory(
433        &mut self,
434        state: &mut SimpleVmbusClientDeviceTaskState,
435        request_send: &mesh::Sender<ChannelRequest>,
436        page_count: usize,
437    ) -> Result<(MemoryBlock, GpadlId)> {
438        // Incoming and outgoing rings require a minimum of two pages apiece:
439        // one for the control bytes and at least one for the ring.
440        assert!(page_count >= 4);
441
442        let mem = self
443            .dma_alloc
444            .allocate_dma_buffer(page_count * PAGE_SIZE)
445            .context("allocating memory for vmbus rings")?;
446        state.vtl_pages = Some(mem.clone());
447        let buf: Vec<_> = [mem.len() as u64]
448            .iter()
449            .chain(mem.pfns())
450            .copied()
451            .collect();
452
453        let gpadl_id = GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32);
454        request_send
455            .call_failable(
456                ChannelRequest::Gpadl,
457                GpadlRequest {
458                    id: gpadl_id,
459                    count: 1,
460                    buf,
461                },
462            )
463            .await
464            .context("registering gpadl")?;
465        Ok((mem, gpadl_id))
466    }
467
468    /// Open the channel offered by the host.
469    async fn open_channel(
470        &self,
471        request_send: &mesh::Sender<ChannelRequest>,
472        ring_gpadl_id: GpadlId,
473        event: &pal_event::Event,
474    ) -> Result<OpenOutput> {
475        let open_request = OpenRequest {
476            open_data: OpenData {
477                target_vp: 0,
478                ring_offset: 2,
479                ring_gpadl_id,
480                event_flag: !0,
481                connection_id: !0,
482                user_data: UserDefinedData::new_zeroed(),
483            },
484            incoming_event: Some(event.clone()),
485            use_vtl2_connection_id: true,
486        };
487
488        request_send
489            .call_failable(ChannelRequest::Open, open_request)
490            .instrument(tracing::info_span!(
491                "opening vmbus channel for intercepted device"
492            ))
493            .await
494            .context("open vmbus channel")
495    }
496
497    /// Create a raw vmbus channel.
498    fn create_vmbus_channel(
499        &self,
500        mem: &MemoryBlock,
501        host_to_guest_event: &pal_event::Event,
502        guest_to_host_interrupt: Interrupt,
503    ) -> Result<RawAsyncChannel<MemoryBlockRingBuffer>> {
504        let (out_ring_mem, in_ring_mem) = (
505            mem.subblock(0, 2 * PAGE_SIZE),
506            mem.subblock(2 * PAGE_SIZE, 2 * PAGE_SIZE),
507        );
508        let (in_ring, out_ring) = (
509            IncomingRing::new(in_ring_mem.into()).unwrap(),
510            OutgoingRing::new(out_ring_mem.into()).unwrap(),
511        );
512
513        let signal = MemoryBlockChannelSignal {
514            event: Notify::from_event(host_to_guest_event.clone())
515                .pollable(self.spawner.as_ref())
516                .unwrap(),
517            interrupt: guest_to_host_interrupt,
518        };
519        Ok(RawAsyncChannel {
520            in_ring,
521            out_ring,
522            signal: Box::new(signal),
523        })
524    }
525
526    /// Responds to the channel being revoked by the host.
527    async fn handle_revoke(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
528        let Some(offer) = state.offer.as_ref() else {
529            return;
530        };
531        tracing::info!("device revoked");
532        if self.device.stop().await {
533            drop(self.device.remove());
534            self.device.task_mut().0.close(offer.offer.subchannel_index);
535        }
536        self.cleanup_device_resources(state).await;
537        drop(state.offer.take());
538    }
539
540    fn handle_save(&mut self) -> SavedStateBlob {
541        let saved_state = self.saved_state.take();
542        if let Some(saved_state) = saved_state {
543            let blob = SavedStateBlob::new(saved_state);
544            self.handle_restore(&blob);
545            blob
546        } else {
547            SavedStateBlob::new(NoSavedState)
548        }
549    }
550
551    fn handle_restore(&mut self, saved_state_blob: &SavedStateBlob) {
552        self.saved_state = match saved_state_blob.parse() {
553            Ok(saved_state) => Some(saved_state),
554            Err(err) => {
555                tracing::error!(
556                    err = &err as &dyn std::error::Error,
557                    "Protobuf conversion error saving state"
558                );
559                None
560            }
561        };
562    }
563
564    /// Handle vmbus messages from the host and control messages from the
565    /// device wrapper.
566    pub async fn process_messages(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
567        loop {
568            #[expect(clippy::large_enum_variant)]
569            enum Event {
570                Request(InterceptChannelRequest),
571                Revoke,
572            }
573            let r = if let Some(offer) = &mut state.offer {
574                (
575                    (&mut state.recv_relay).map(Event::Request),
576                    futures::stream::once(&mut offer.revoke_recv).map(|_| Event::Revoke),
577                )
578                    .merge()
579                    .next()
580                    .await
581            } else {
582                let mut recv_relay = pin!(&mut state.recv_relay);
583                recv_relay.next().await.map(Event::Request)
584            };
585            let Some(r) = r else {
586                break;
587            };
588            match r {
589                Event::Revoke => {
590                    self.handle_revoke(state).await;
591                }
592                Event::Request(InterceptChannelRequest::Offer(offer)) => {
593                    // Any extraneous offer notifications (e.g. from a request offers
594                    // query) are ignored.
595                    if !self.device.is_running() {
596                        if let Err(err) = self.handle_offer(offer, state).await {
597                            tracing::error!(
598                                error = err.as_ref() as &dyn std::error::Error,
599                                "failed offer handling"
600                            );
601                        }
602                    }
603                }
604                Event::Request(InterceptChannelRequest::Start) => {
605                    self.handle_start(state).await;
606                }
607                Event::Request(InterceptChannelRequest::Stop(rpc)) => {
608                    rpc.handle(async |()| self.handle_stop(state).await).await;
609                }
610                Event::Request(InterceptChannelRequest::Save(rpc)) => {
611                    rpc.handle_sync(|()| self.handle_save());
612                }
613                Event::Request(InterceptChannelRequest::Restore(saved_state)) => {
614                    self.handle_restore(&saved_state);
615                }
616            }
617        }
618    }
619}
620
621struct MemoryBlockChannelSignal {
622    event: PolledNotify,
623    interrupt: Interrupt,
624}
625
626impl SignalVmbusChannel for MemoryBlockChannelSignal {
627    fn signal_remote(&self) {
628        self.interrupt.deliver();
629    }
630
631    fn poll_for_signal(
632        &self,
633        cx: &mut std::task::Context<'_>,
634    ) -> std::task::Poll<Result<(), ChannelClosed>> {
635        self.event.poll_wait(cx).map(Ok)
636    }
637}