vmbus_relay_intercept_device/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! This module contains logic used to intercept from VTL2 a vmbus device
5//! provided for a VTL0 guest. This requires the vmbus relay to be active,
6//! which will filter the device out from the list provided to the VTL0 guest
7//! and send any vmbus notifications for that device to the
8//! SimpleVmbusClientDeviceWrapper instance.
9
10#![cfg(target_os = "linux")]
11#![expect(missing_docs)]
12#![forbid(unsafe_code)]
13
14pub mod ring_buffer;
15
16use crate::ring_buffer::MemoryBlockRingBuffer;
17use anyhow::Context;
18use anyhow::Result;
19use futures::StreamExt;
20use futures_concurrency::stream::Merge;
21use guid::Guid;
22use inspect::InspectMut;
23use mesh::rpc::RpcSend;
24use pal_async::driver::SpawnDriver;
25use std::future::Future;
26use std::future::pending;
27use std::pin::pin;
28use std::sync::Arc;
29use task_control::AsyncRun;
30use task_control::Cancelled;
31use task_control::InspectTaskMut;
32use task_control::StopTask;
33use task_control::TaskControl;
34use tracing::Instrument;
35use user_driver::DmaClient;
36use user_driver::memory::MemoryBlock;
37use vmbus_channel::ChannelClosed;
38use vmbus_channel::RawAsyncChannel;
39use vmbus_channel::SignalVmbusChannel;
40use vmbus_channel::bus::GpadlRequest;
41use vmbus_channel::bus::OpenData;
42use vmbus_client::ChannelRequest;
43use vmbus_client::OfferInfo;
44use vmbus_client::OpenOutput;
45use vmbus_client::OpenRequest;
46use vmbus_core::protocol::GpadlId;
47use vmbus_core::protocol::UserDefinedData;
48use vmbus_relay::InterceptChannelRequest;
49use vmbus_ring::IncomingRing;
50use vmbus_ring::OutgoingRing;
51use vmbus_ring::PAGE_SIZE;
52use vmcore::interrupt::Interrupt;
53use vmcore::notify::Notify;
54use vmcore::notify::PolledNotify;
55use vmcore::save_restore::NoSavedState;
56use vmcore::save_restore::SavedStateBlob;
57use vmcore::save_restore::SavedStateRoot;
58use zerocopy::FromZeros;
59
/// A device's decision about a channel offered for its instance ID, as
/// returned from [`SimpleVmbusClientDevice::offer`].
pub enum OfferResponse {
    /// Skip the offer; the channel is not opened.
    Ignore,
    /// Accept the offer; the channel is opened on the device's behalf.
    Open,
}
64
/// A device that services a vmbus channel intercepted on its behalf.
///
/// The callbacks are invoked by the task owned by
/// [`SimpleVmbusClientDeviceWrapper`] as offers, opens, and closes occur.
pub trait SimpleVmbusClientDevice {
    /// The saved state type.
    type SavedState: SavedStateRoot + Send + Sync;

    /// The type used to run an open channel.
    type Runner: 'static + Send + Sync;

    /// Inspects a channel.
    ///
    /// `runner` is the runner for the channel when one is available.
    fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>);

    /// Returns the instance ID of the matching device.
    fn instance_id(&self) -> Guid;

    /// Respond to a new channel offer for a device matching instance_id().
    ///
    /// Returning [`OfferResponse::Ignore`] leaves the channel unopened.
    fn offer(&self, offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse;

    /// Open successful for the channel number `channel_idx`.
    ///
    /// Returns the runner that will be handed to
    /// [`SimpleVmbusClientDeviceAsync::run`].
    ///
    /// When the channel is closed, the runner will be dropped.
    fn open(
        &mut self,
        channel_idx: u16,
        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
    ) -> Result<Self::Runner>;

    /// Closes the channel number `channel_idx` after the runner has been
    /// dropped.
    fn close(&mut self, channel_idx: u16);

    /// Returns a trait used to save/restore the device.
    ///
    /// Return `None` if the device does not support save/restore; in that
    /// case no state is captured when the channel is stopped.
    fn supports_save_restore(
        &mut self,
    ) -> Option<
        &mut dyn SaveRestoreSimpleVmbusClientDevice<
            SavedState = Self::SavedState,
            Runner = Self::Runner,
        >,
    >;
}
104
/// The asynchronous half of a [`SimpleVmbusClientDevice`]: drives the runner
/// of an open channel.
pub trait SimpleVmbusClientDeviceAsync: SimpleVmbusClientDevice + 'static + Send + Sync {
    /// Runs an open channel until `stop` is signaled.
    ///
    /// `runner` is the value previously returned from
    /// [`SimpleVmbusClientDevice::open`] or
    /// [`SaveRestoreSimpleVmbusClientDevice::restore_open`].
    fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        runner: &mut Self::Runner,
    ) -> impl Send + Future<Output = Result<(), Cancelled>>;
}
113
/// Trait implemented by simple vmbus client devices that support save/restore.
///
/// If you implement this, make sure to return `Some(self)` from
/// [`SimpleVmbusClientDevice::supports_save_restore`].
pub trait SaveRestoreSimpleVmbusClientDevice: SimpleVmbusClientDevice {
    /// Saves the channel.
    ///
    /// Will only be called if the channel is open. It is invoked while the
    /// channel is being stopped, just before the runner is dropped.
    fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState;

    /// Restores the channel.
    ///
    /// Will only be called if the channel was saved open. It is invoked in
    /// place of [`SimpleVmbusClientDevice::open`] and must return the runner
    /// for the newly opened `channel`.
    fn restore_open(
        &mut self,
        state: Self::SavedState,
        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
    ) -> Result<Self::Runner>;
}
133
/// Owns a [`SimpleVmbusClientDeviceAsync`] device together with the task that
/// processes intercepted vmbus relay requests for it.
#[derive(InspectMut)]
pub struct SimpleVmbusClientDeviceWrapper<T: SimpleVmbusClientDeviceAsync> {
    /// Instance ID reported by the device when the wrapper was created.
    instance_id: Guid,
    /// Driver used to spawn the listener task.
    #[inspect(skip)]
    spawner: Arc<dyn SpawnDriver>,
    /// Task that handles relay requests and revocations for the device.
    #[inspect(mut)]
    vmbus_listener: TaskControl<SimpleVmbusClientDeviceTask<T>, SimpleVmbusClientDeviceTaskState>,
}
142
143impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceWrapper<T> {
144    /// Create a new instance.
145    pub fn new(
146        driver: impl SpawnDriver + Clone,
147        dma_alloc: Arc<dyn DmaClient>,
148        device: T,
149    ) -> Result<Self> {
150        let spawner = Arc::new(driver.clone());
151        Ok(Self {
152            instance_id: device.instance_id(),
153            vmbus_listener: TaskControl::new(SimpleVmbusClientDeviceTask::new(
154                device,
155                spawner.clone(),
156                dma_alloc,
157            )),
158            spawner,
159        })
160    }
161
162    pub fn instance_id(&self) -> Guid {
163        self.instance_id
164    }
165
166    pub fn detach(
167        mut self,
168        driver: impl SpawnDriver,
169        recv_relay: mesh::Receiver<InterceptChannelRequest>,
170    ) -> Result<mesh::OneshotSender<()>> {
171        self.vmbus_listener.insert(
172            &self.spawner,
173            format!("{}", self.instance_id),
174            SimpleVmbusClientDeviceTaskState {
175                offer: None,
176                recv_relay,
177                vtl_pages: None,
178            },
179        );
180        let (driver_send, driver_recv) = mesh::oneshot();
181        driver
182            .spawn(
183                format!("vmbus_relay_device {}", self.instance_id),
184                async move {
185                    self.vmbus_listener.start();
186                    let _ = driver_recv.await;
187                    self.vmbus_listener.stop().await;
188                },
189            )
190            .detach();
191        Ok(driver_send)
192    }
193}
194
/// Newtype around the device so that the `task_control` traits (`AsyncRun`,
/// `InspectTaskMut`) can be implemented for it.
struct RelayDeviceTask<T>(T);
196
197impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<T::Runner> for RelayDeviceTask<T> {
198    async fn run(
199        &mut self,
200        stop: &mut StopTask<'_>,
201        runner: &mut T::Runner,
202    ) -> Result<(), Cancelled> {
203        self.0.run(stop, runner).await
204    }
205}
206
207impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<T::Runner> for RelayDeviceTask<T> {
208    fn inspect_mut(&mut self, req: inspect::Request<'_>, runner: Option<&mut T::Runner>) {
209        self.0.inspect(req, runner)
210    }
211}
212
/// Mutable state handed to the listener task while it runs.
#[derive(InspectMut)]
struct SimpleVmbusClientDeviceTaskState {
    /// The offer currently being serviced, if any.
    offer: Option<OfferInfo>,
    /// Incoming requests from the vmbus relay.
    #[inspect(skip)]
    recv_relay: mesh::Receiver<InterceptChannelRequest>,
    /// Memory allocated for the channel's ring buffers. Set by
    /// `reserve_memory` and cleared by `cleanup_device_resources`.
    #[inspect(hex, with = "|x| x.as_ref().map(|x| inspect::iter_by_index(x.pfns()))")]
    vtl_pages: Option<MemoryBlock>,
}
221
/// The listener task: reacts to relay requests and revocations and manages
/// the nested task that runs the device's open channel.
struct SimpleVmbusClientDeviceTask<T: SimpleVmbusClientDeviceAsync> {
    /// Nested task running the device; its state is the channel runner.
    device: TaskControl<RelayDeviceTask<T>, T::Runner>,
    /// Device state captured on stop or supplied by a restore request,
    /// consumed the next time the channel is opened.
    saved_state: Option<T::SavedState>,
    /// Driver used to spawn the device task and poll the channel event.
    spawner: Arc<dyn SpawnDriver>,
    /// Allocator for the ring buffer memory.
    dma_alloc: Arc<dyn DmaClient>,
}
228
229impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<SimpleVmbusClientDeviceTaskState>
230    for SimpleVmbusClientDeviceTask<T>
231{
232    async fn run(
233        &mut self,
234        stop: &mut StopTask<'_>,
235        state: &mut SimpleVmbusClientDeviceTaskState,
236    ) -> Result<(), Cancelled> {
237        stop.until_stopped(self.process_messages(state)).await
238    }
239}
240
241impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<SimpleVmbusClientDeviceTaskState>
242    for SimpleVmbusClientDeviceTask<T>
243{
244    fn inspect_mut(
245        &mut self,
246        req: inspect::Request<'_>,
247        state: Option<&mut SimpleVmbusClientDeviceTaskState>,
248    ) {
249        req.respond()
250            .merge(state)
251            .field_mut("device", &mut self.device)
252            .field("dma_alloc", &self.dma_alloc);
253    }
254}
255
256impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceTask<T> {
257    pub fn new(device: T, spawner: Arc<dyn SpawnDriver>, dma_alloc: Arc<dyn DmaClient>) -> Self {
258        Self {
259            device: TaskControl::new(RelayDeviceTask(device)),
260            saved_state: None,
261            spawner,
262            dma_alloc,
263        }
264    }
265
266    fn insert_runner(&mut self, state: &SimpleVmbusClientDeviceTaskState, runner: T::Runner) {
267        let offer = state.offer.as_ref().unwrap().offer;
268        self.device.insert(
269            &self.spawner,
270            format!("{}-{}", offer.interface_id, offer.instance_id),
271            runner,
272        );
273    }
274
275    /// Configures channel.
276    async fn handle_offer(
277        &mut self,
278        offer: OfferInfo,
279        state: &mut SimpleVmbusClientDeviceTaskState,
280    ) -> Result<()> {
281        tracing::info!(?offer, "matching channel offered");
282
283        if offer.offer.is_dedicated != 1 {
284            tracing::warn!(offer = ?offer.offer, "All offers should be dedicated with Win8+ host")
285        }
286
287        if matches!(
288            self.device.task_mut().0.offer(&offer.offer),
289            OfferResponse::Ignore
290        ) {
291            return Ok(());
292        }
293
294        let interrupt_event = pal_event::Event::new();
295        let (memory, ring_gpadl_id) = self
296            .reserve_memory(state, &offer.request_send, 4)
297            .await
298            .context("reserve memory")?;
299        state.offer = Some(offer);
300        let offer = state.offer.as_ref().unwrap();
301        let opened = self
302            .open_channel(&offer.request_send, ring_gpadl_id, &interrupt_event)
303            .await
304            .context("open channel")?;
305        let channel = self
306            .create_vmbus_channel(&memory, &interrupt_event, opened.guest_to_host_signal)
307            .context("create vmbus queue")?;
308
309        let save_restore = self.device.task_mut().0.supports_save_restore();
310        let saved_state = self.saved_state.take();
311        let device_runner = if save_restore.is_some() && saved_state.is_some() {
312            save_restore
313                .unwrap()
314                .restore_open(saved_state.unwrap(), channel)
315                .context("device restore_open callback")?
316        } else {
317            self.device
318                .task_mut()
319                .0
320                .open(offer.offer.subchannel_index, channel)
321                .context("device open callback")?
322        };
323        self.insert_runner(state, device_runner);
324        self.device.start();
325        Ok(())
326    }
327
328    /// Start channel after it has been stopped.
329    async fn handle_start(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
330        if self.device.is_running() {
331            return;
332        }
333
334        let offer = state.offer.take();
335        if offer.is_none() {
336            return;
337        }
338
339        // If there is a previous valid offer, open the channel again.
340        if let Err(err) = self.handle_offer(offer.unwrap(), state).await {
341            tracing::error!(
342                err = err.as_ref() as &dyn std::error::Error,
343                "Failed to reconnect vmbus channel"
344            );
345        }
346    }
347
348    async fn cleanup_device_resources(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
349        let Some(offer) = state.offer.as_mut() else {
350            return;
351        };
352
353        if state.vtl_pages.is_some() {
354            if let Err(err) = offer
355                .request_send
356                .call(
357                    ChannelRequest::TeardownGpadl,
358                    GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32),
359                )
360                .await
361            {
362                tracing::error!(
363                    error = &err as &dyn std::error::Error,
364                    "failed to teardown gpadl"
365                );
366            }
367
368            state.vtl_pages = None;
369        }
370    }
371
    /// Stop channel
    ///
    /// Stops the device task, closes the channel, tears down the ring GPADL,
    /// captures saved state from the runner when the device supports
    /// save/restore, drops the runner, and finally calls the device's `close`
    /// callback. Returns immediately if `self.device.stop()` returns `false`.
    ///
    /// Panics if `state.offer` is `None` once the device task has stopped.
    async fn handle_stop(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
        if !self.device.stop().await {
            return;
        }

        // Close the channel on every stop. Overlay devices cannot be saved /
        // restored because the physical pages used for the ring buffer, et al.
        // would need to be reserved at boot, otherwise the host may end up
        // scribbling on random memory as it continues updating a ring buffer it
        // assumes it has ownership of.
        //
        // TODO: We could support save restore, if we had a pool of memory that
        // supports that. This should be possible once the page_pool_alloc is
        // available everywhere.
        {
            let offer = state.offer.as_ref().expect("device opened");
            // The result of the close request is deliberately ignored.
            offer
                .request_send
                .call(ChannelRequest::Close, ())
                .await
                .ok();
        }
        // N.B. This will wait for a TeardownGpadl response which can be used
        // as a signal that the channel is closed and the ring buffers are no
        // longer in use.
        self.cleanup_device_resources(state).await;
        // Take the runner back from the task so it can be saved and then
        // dropped before `close` is called.
        let runner = self.device.remove();
        let device = self.device.task_mut();
        if let Some(save_restore) = device.0.supports_save_restore() {
            self.saved_state = Some(save_restore.save_open(&runner));
        }
        drop(runner);
        let offer = state.offer.as_ref().expect("device opened");
        device.0.close(offer.offer.subchannel_index);
    }
408
409    /// Allocates memory to be shared with the host and registers it with a
410    /// GPADL ID.
411    async fn reserve_memory(
412        &mut self,
413        state: &mut SimpleVmbusClientDeviceTaskState,
414        request_send: &mesh::Sender<ChannelRequest>,
415        page_count: usize,
416    ) -> Result<(MemoryBlock, GpadlId)> {
417        // Incoming and outgoing rings require a minimum of two pages apiece:
418        // one for the control bytes and at least one for the ring.
419        assert!(page_count >= 4);
420
421        let mem = self
422            .dma_alloc
423            .allocate_dma_buffer(page_count * PAGE_SIZE)
424            .context("allocating memory for vmbus rings")?;
425        state.vtl_pages = Some(mem.clone());
426        let buf: Vec<_> = [mem.len() as u64]
427            .iter()
428            .chain(mem.pfns())
429            .copied()
430            .collect();
431
432        let gpadl_id = GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32);
433        request_send
434            .call_failable(
435                ChannelRequest::Gpadl,
436                GpadlRequest {
437                    id: gpadl_id,
438                    count: 1,
439                    buf,
440                },
441            )
442            .await
443            .context("registering gpadl")?;
444        Ok((mem, gpadl_id))
445    }
446
447    /// Open the channel offered by the host.
448    async fn open_channel(
449        &self,
450        request_send: &mesh::Sender<ChannelRequest>,
451        ring_gpadl_id: GpadlId,
452        event: &pal_event::Event,
453    ) -> Result<OpenOutput> {
454        let open_request = OpenRequest {
455            open_data: OpenData {
456                target_vp: 0,
457                ring_offset: 2,
458                ring_gpadl_id,
459                event_flag: !0,
460                connection_id: !0,
461                user_data: UserDefinedData::new_zeroed(),
462            },
463            incoming_event: Some(event.clone()),
464            use_vtl2_connection_id: true,
465        };
466
467        request_send
468            .call_failable(ChannelRequest::Open, open_request)
469            .instrument(tracing::info_span!(
470                "opening vmbus channel for intercepted device"
471            ))
472            .await
473            .context("open vmbus channel")
474    }
475
476    /// Create a raw vmbus channel.
477    fn create_vmbus_channel(
478        &self,
479        mem: &MemoryBlock,
480        host_to_guest_event: &pal_event::Event,
481        guest_to_host_interrupt: Interrupt,
482    ) -> Result<RawAsyncChannel<MemoryBlockRingBuffer>> {
483        let (out_ring_mem, in_ring_mem) = (
484            mem.subblock(0, 2 * PAGE_SIZE),
485            mem.subblock(2 * PAGE_SIZE, 2 * PAGE_SIZE),
486        );
487        let (in_ring, out_ring) = (
488            IncomingRing::new(in_ring_mem.into()).unwrap(),
489            OutgoingRing::new(out_ring_mem.into()).unwrap(),
490        );
491
492        let signal = MemoryBlockChannelSignal {
493            event: Notify::from_event(host_to_guest_event.clone())
494                .pollable(self.spawner.as_ref())
495                .unwrap(),
496            interrupt: guest_to_host_interrupt,
497        };
498        Ok(RawAsyncChannel {
499            in_ring,
500            out_ring,
501            signal: Box::new(signal),
502        })
503    }
504
    /// Responds to the channel being revoked by the host.
    ///
    /// Takes the current offer out of `state` (returning early if there is
    /// none) and stops the device task. If `stop()` returns `true`, the runner
    /// is dropped and the device's `close` callback is invoked.
    async fn handle_revoke(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
        let Some(offer) = state.offer.take() else {
            return;
        };
        tracing::info!("device revoked");
        if self.device.stop().await {
            drop(self.device.remove());
            self.device.task_mut().0.close(offer.offer.subchannel_index);
        }
        // NOTE(review): `state.offer` was taken above, so
        // `cleanup_device_resources` returns at its first check and leaves
        // `state.vtl_pages` untouched here. Confirm this is intended.
        self.cleanup_device_resources(state).await;
    }
517
518    fn handle_save(&mut self) -> SavedStateBlob {
519        let saved_state = self.saved_state.take();
520        if saved_state.is_some() {
521            let blob = SavedStateBlob::new(saved_state.unwrap());
522            self.handle_restore(&blob);
523            blob
524        } else {
525            SavedStateBlob::new(NoSavedState)
526        }
527    }
528
529    fn handle_restore(&mut self, saved_state_blob: &SavedStateBlob) {
530        self.saved_state = match saved_state_blob.parse() {
531            Ok(saved_state) => Some(saved_state),
532            Err(err) => {
533                tracing::error!(
534                    err = &err as &dyn std::error::Error,
535                    "Protobuf conversion error saving state"
536                );
537                None
538            }
539        };
540    }
541
    /// Handle vmbus messages from the host and control messages from the
    /// device wrapper.
    ///
    /// Each iteration waits for either the next relay request or a revoke of
    /// the current offer, then dispatches it to the matching handler. The loop
    /// ends when the merged stream yields no further items.
    pub async fn process_messages(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
        loop {
            #[expect(clippy::large_enum_variant)]
            enum Event {
                Request(InterceptChannelRequest),
                Revoke(()),
            }
            // Resolves when the current offer's revoke receiver completes;
            // never resolves while there is no offer.
            let revoke = pin!(async {
                if let Some(offer) = &mut state.offer {
                    (&mut offer.revoke_recv).await.ok();
                } else {
                    pending().await
                }
            });
            // Take whichever of the two sources produces an item first.
            let Some(r) = (
                (&mut state.recv_relay).map(Event::Request),
                futures::stream::once(revoke).map(Event::Revoke),
            )
                .merge()
                .next()
                .await
            else {
                break;
            };
            match r {
                Event::Revoke(()) => {
                    self.handle_revoke(state).await;
                }
                Event::Request(InterceptChannelRequest::Offer(offer)) => {
                    // Any extraneous offer notifications (e.g. from a request offers
                    // query) are ignored.
                    if !self.device.is_running() {
                        if let Err(err) = self.handle_offer(offer, state).await {
                            tracing::error!(
                                error = err.as_ref() as &dyn std::error::Error,
                                "failed offer handling"
                            );
                        }
                    }
                }
                Event::Request(InterceptChannelRequest::Start) => {
                    self.handle_start(state).await;
                }
                Event::Request(InterceptChannelRequest::Stop(rpc)) => {
                    // The RPC completes once the stop handling has finished.
                    rpc.handle(async |()| self.handle_stop(state).await).await;
                }
                Event::Request(InterceptChannelRequest::Save(rpc)) => {
                    rpc.handle_sync(|()| self.handle_save());
                }
                Event::Request(InterceptChannelRequest::Restore(saved_state)) => {
                    self.handle_restore(&saved_state);
                }
            }
        }
    }
599}
600
/// Signaling for a channel whose rings live in a [`MemoryBlock`].
struct MemoryBlockChannelSignal {
    /// Pollable notification built from the channel's host-to-guest event.
    event: PolledNotify,
    /// Interrupt delivered to signal the remote side of the channel.
    interrupt: Interrupt,
}
605
606impl SignalVmbusChannel for MemoryBlockChannelSignal {
607    fn signal_remote(&self) {
608        self.interrupt.deliver();
609    }
610
611    fn poll_for_signal(
612        &self,
613        cx: &mut std::task::Context<'_>,
614    ) -> std::task::Poll<Result<(), ChannelClosed>> {
615        self.event.poll_wait(cx).map(Ok)
616    }
617}