vmbus_relay_intercept_device/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! This module contains logic used to intercept from VTL2 a vmbus device
5//! provided for a VTL0 guest. This requires the vmbus relay to be active,
6//! which will filter the device out from the list provided to the VTL0 guest
7//! and send any vmbus notifications for that device to the
8//! SimpleVmbusClientDeviceWrapper instance.
9
10#![expect(missing_docs)]
11#![forbid(unsafe_code)]
12
13pub mod ring_buffer;
14
15use crate::ring_buffer::MemoryBlockRingBuffer;
16use anyhow::Context;
17use anyhow::Result;
18use futures::StreamExt;
19use futures_concurrency::stream::Merge;
20use guid::Guid;
21use inspect::InspectMut;
22use mesh::rpc::RpcSend;
23use pal_async::driver::SpawnDriver;
24use std::future::Future;
25use std::future::pending;
26use std::pin::pin;
27use std::sync::Arc;
28use task_control::AsyncRun;
29use task_control::Cancelled;
30use task_control::InspectTaskMut;
31use task_control::StopTask;
32use task_control::TaskControl;
33use tracing::Instrument;
34use user_driver::DmaClient;
35use user_driver::memory::MemoryBlock;
36use vmbus_channel::ChannelClosed;
37use vmbus_channel::RawAsyncChannel;
38use vmbus_channel::SignalVmbusChannel;
39use vmbus_channel::bus::GpadlRequest;
40use vmbus_channel::bus::OpenData;
41use vmbus_client::ChannelRequest;
42use vmbus_client::OfferInfo;
43use vmbus_client::OpenOutput;
44use vmbus_client::OpenRequest;
45use vmbus_core::protocol::GpadlId;
46use vmbus_core::protocol::UserDefinedData;
47use vmbus_relay::InterceptChannelRequest;
48use vmbus_ring::IncomingRing;
49use vmbus_ring::OutgoingRing;
50use vmbus_ring::PAGE_SIZE;
51use vmcore::interrupt::Interrupt;
52use vmcore::notify::Notify;
53use vmcore::notify::PolledNotify;
54use vmcore::save_restore::NoSavedState;
55use vmcore::save_restore::SavedStateBlob;
56use vmcore::save_restore::SavedStateRoot;
57use zerocopy::FromZeros;
58
/// How a [`SimpleVmbusClientDevice`] wants a matching channel offer handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OfferResponse {
    /// Leave the offered channel unopened.
    Ignore,
    /// Open the offered channel and hand it to the device.
    Open,
}
63
64pub trait SimpleVmbusClientDevice {
65    /// The saved state type.
66    type SavedState: SavedStateRoot + Send + Sync;
67
68    /// The type used to run an open channel.
69    type Runner: 'static + Send + Sync;
70
71    /// Inspects a channel.
72    fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>);
73
74    /// Returns the instance ID of the matching device.
75    fn instance_id(&self) -> Guid;
76
77    /// Respond to a new channel offer for a device matching instance_id().
78    fn offer(&self, offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse;
79
80    /// Open successful for the channel number `channel_idx`.
81    ///
82    /// When the channel is closed, the runner will be dropped.
83    fn open(
84        &mut self,
85        channel_idx: u16,
86        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
87    ) -> Result<Self::Runner>;
88
89    /// Closes the channel number `channel_idx` after the runner has been
90    /// dropped.
91    fn close(&mut self, channel_idx: u16);
92
93    /// Returns a trait used to save/restore the device.
94    fn supports_save_restore(
95        &mut self,
96    ) -> Option<
97        &mut dyn SaveRestoreSimpleVmbusClientDevice<
98            SavedState = Self::SavedState,
99            Runner = Self::Runner,
100        >,
101    >;
102}
103
104pub trait SimpleVmbusClientDeviceAsync: SimpleVmbusClientDevice + 'static + Send + Sync {
105    /// Runs an open channel until `stop` is signaled.
106    fn run(
107        &mut self,
108        stop: &mut StopTask<'_>,
109        runner: &mut Self::Runner,
110    ) -> impl Send + Future<Output = Result<(), Cancelled>>;
111}
112
113/// Trait implemented by simple vmbus client devices that support save/restore.
114///
115/// If you implement this, make sure to return `Some(self)` from
116/// [`SimpleVmbusClientDevice::supports_save_restore`].
117pub trait SaveRestoreSimpleVmbusClientDevice: SimpleVmbusClientDevice {
118    /// Saves the channel.
119    ///
120    /// Will only be called if the channel is open.
121    fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState;
122
123    /// Restores the channel.
124    ///
125    /// Will only be called if the channel was saved open.
126    fn restore_open(
127        &mut self,
128        state: Self::SavedState,
129        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
130    ) -> Result<Self::Runner>;
131}
132
133#[derive(InspectMut)]
134pub struct SimpleVmbusClientDeviceWrapper<T: SimpleVmbusClientDeviceAsync> {
135    instance_id: Guid,
136    #[inspect(skip)]
137    spawner: Arc<dyn SpawnDriver>,
138    #[inspect(mut)]
139    vmbus_listener: TaskControl<SimpleVmbusClientDeviceTask<T>, SimpleVmbusClientDeviceTaskState>,
140}
141
142impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceWrapper<T> {
143    /// Create a new instance.
144    pub fn new(
145        driver: impl SpawnDriver + Clone,
146        dma_alloc: Arc<dyn DmaClient>,
147        device: T,
148    ) -> Result<Self> {
149        let spawner = Arc::new(driver.clone());
150        Ok(Self {
151            instance_id: device.instance_id(),
152            vmbus_listener: TaskControl::new(SimpleVmbusClientDeviceTask::new(
153                device,
154                spawner.clone(),
155                dma_alloc,
156            )),
157            spawner,
158        })
159    }
160
161    pub fn instance_id(&self) -> Guid {
162        self.instance_id
163    }
164
165    pub fn detach(
166        mut self,
167        driver: impl SpawnDriver,
168        recv_relay: mesh::Receiver<InterceptChannelRequest>,
169    ) -> Result<mesh::OneshotSender<()>> {
170        self.vmbus_listener.insert(
171            &self.spawner,
172            format!("{}", self.instance_id),
173            SimpleVmbusClientDeviceTaskState {
174                offer: None,
175                recv_relay,
176                vtl_pages: None,
177            },
178        );
179        let (driver_send, driver_recv) = mesh::oneshot();
180        driver
181            .spawn(
182                format!("vmbus_relay_device {}", self.instance_id),
183                async move {
184                    self.vmbus_listener.start();
185                    let _ = driver_recv.await;
186                    self.vmbus_listener.stop().await;
187                },
188            )
189            .detach();
190        Ok(driver_send)
191    }
192}
193
/// Adapter that lets a [`TaskControl`] drive a [`SimpleVmbusClientDeviceAsync`]
/// device, with the device's runner as the task state.
struct RelayDeviceTask<T>(
    /// The wrapped device.
    T,
);
195
196impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<T::Runner> for RelayDeviceTask<T> {
197    async fn run(
198        &mut self,
199        stop: &mut StopTask<'_>,
200        runner: &mut T::Runner,
201    ) -> Result<(), Cancelled> {
202        self.0.run(stop, runner).await
203    }
204}
205
206impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<T::Runner> for RelayDeviceTask<T> {
207    fn inspect_mut(&mut self, req: inspect::Request<'_>, runner: Option<&mut T::Runner>) {
208        self.0.inspect(req, runner)
209    }
210}
211
212#[derive(InspectMut)]
213struct SimpleVmbusClientDeviceTaskState {
214    offer: Option<OfferInfo>,
215    #[inspect(skip)]
216    recv_relay: mesh::Receiver<InterceptChannelRequest>,
217    #[inspect(hex, with = "|x| x.as_ref().map(|x| inspect::iter_by_index(x.pfns()))")]
218    vtl_pages: Option<MemoryBlock>,
219}
220
221struct SimpleVmbusClientDeviceTask<T: SimpleVmbusClientDeviceAsync> {
222    device: TaskControl<RelayDeviceTask<T>, T::Runner>,
223    saved_state: Option<T::SavedState>,
224    spawner: Arc<dyn SpawnDriver>,
225    dma_alloc: Arc<dyn DmaClient>,
226}
227
228impl<T: SimpleVmbusClientDeviceAsync> AsyncRun<SimpleVmbusClientDeviceTaskState>
229    for SimpleVmbusClientDeviceTask<T>
230{
231    async fn run(
232        &mut self,
233        stop: &mut StopTask<'_>,
234        state: &mut SimpleVmbusClientDeviceTaskState,
235    ) -> Result<(), Cancelled> {
236        stop.until_stopped(self.process_messages(state)).await
237    }
238}
239
240impl<T: SimpleVmbusClientDeviceAsync> InspectTaskMut<SimpleVmbusClientDeviceTaskState>
241    for SimpleVmbusClientDeviceTask<T>
242{
243    fn inspect_mut(
244        &mut self,
245        req: inspect::Request<'_>,
246        state: Option<&mut SimpleVmbusClientDeviceTaskState>,
247    ) {
248        req.respond()
249            .merge(state)
250            .field_mut("device", &mut self.device)
251            .field("dma_alloc", &self.dma_alloc);
252    }
253}
254
255impl<T: SimpleVmbusClientDeviceAsync> SimpleVmbusClientDeviceTask<T> {
256    pub fn new(device: T, spawner: Arc<dyn SpawnDriver>, dma_alloc: Arc<dyn DmaClient>) -> Self {
257        Self {
258            device: TaskControl::new(RelayDeviceTask(device)),
259            saved_state: None,
260            spawner,
261            dma_alloc,
262        }
263    }
264
265    fn insert_runner(&mut self, state: &SimpleVmbusClientDeviceTaskState, runner: T::Runner) {
266        let offer = state.offer.as_ref().unwrap().offer;
267        self.device.insert(
268            &self.spawner,
269            format!("{}-{}", offer.interface_id, offer.instance_id),
270            runner,
271        );
272    }
273
274    /// Configures channel.
275    async fn handle_offer(
276        &mut self,
277        offer: OfferInfo,
278        state: &mut SimpleVmbusClientDeviceTaskState,
279    ) -> Result<()> {
280        tracing::info!(?offer, "matching channel offered");
281
282        if offer.offer.is_dedicated != 1 {
283            tracing::warn!(offer = ?offer.offer, "All offers should be dedicated with Win8+ host")
284        }
285
286        if matches!(
287            self.device.task_mut().0.offer(&offer.offer),
288            OfferResponse::Ignore
289        ) {
290            return Ok(());
291        }
292
293        let interrupt_event = pal_event::Event::new();
294        let (memory, ring_gpadl_id) = self
295            .reserve_memory(state, &offer.request_send, 4)
296            .await
297            .context("reserve memory")?;
298        let guest_to_host_interrupt = offer.guest_to_host_interrupt.clone();
299        state.offer = Some(offer);
300        let offer = state.offer.as_ref().unwrap();
301        self.open_channel(&offer.request_send, ring_gpadl_id, &interrupt_event)
302            .await
303            .context("open channel")?;
304        let channel = self
305            .create_vmbus_channel(&memory, &interrupt_event, guest_to_host_interrupt)
306            .context("create vmbus queue")?;
307
308        let save_restore = self.device.task_mut().0.supports_save_restore();
309        let saved_state = self.saved_state.take();
310        let device_runner = if save_restore.is_some() && saved_state.is_some() {
311            save_restore
312                .unwrap()
313                .restore_open(saved_state.unwrap(), channel)
314                .context("device restore_open callback")?
315        } else {
316            self.device
317                .task_mut()
318                .0
319                .open(offer.offer.subchannel_index, channel)
320                .context("device open callback")?
321        };
322        self.insert_runner(state, device_runner);
323        self.device.start();
324        Ok(())
325    }
326
327    /// Start channel after it has been stopped.
328    async fn handle_start(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
329        if self.device.is_running() {
330            return;
331        }
332
333        let offer = state.offer.take();
334        if offer.is_none() {
335            return;
336        }
337
338        // If there is a previous valid offer, open the channel again.
339        if let Err(err) = self.handle_offer(offer.unwrap(), state).await {
340            tracing::error!(
341                err = err.as_ref() as &dyn std::error::Error,
342                "Failed to reconnect vmbus channel"
343            );
344        }
345    }
346
347    async fn cleanup_device_resources(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
348        let Some(offer) = state.offer.as_mut() else {
349            return;
350        };
351
352        if state.vtl_pages.is_some() {
353            if let Err(err) = offer
354                .request_send
355                .call(
356                    ChannelRequest::TeardownGpadl,
357                    GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32),
358                )
359                .await
360            {
361                tracing::error!(
362                    error = &err as &dyn std::error::Error,
363                    "failed to teardown gpadl"
364                );
365            }
366
367            state.vtl_pages = None;
368        }
369    }
370
371    /// Stop channel
372    async fn handle_stop(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
373        if !self.device.stop().await {
374            return;
375        }
376
377        // Close the channel on every stop. Overlay devices cannot be saved /
378        // restored because the physical pages used for the ring buffer, et al.
379        // would need to be reserved at boot, otherwise the host may end up
380        // scribbling on random memory as it continues updating a ring buffer it
381        // assumes it has ownership of.
382        //
383        // TODO: We could support save restore, if we had a pool of memory that
384        // supports that. This should be possible once the page_pool_alloc is
385        // available everywhere.
386        {
387            let offer = state.offer.as_ref().expect("device opened");
388            offer
389                .request_send
390                .call(ChannelRequest::Close, ())
391                .await
392                .ok();
393        }
394        // N.B. This will wait for a TeardownGpadl response which can be used
395        // as a signal that the channel is closed and the ring buffers are no
396        // longer in use.
397        self.cleanup_device_resources(state).await;
398        let runner = self.device.remove();
399        let device = self.device.task_mut();
400        if let Some(save_restore) = device.0.supports_save_restore() {
401            self.saved_state = Some(save_restore.save_open(&runner));
402        }
403        drop(runner);
404        let offer = state.offer.as_ref().expect("device opened");
405        device.0.close(offer.offer.subchannel_index);
406    }
407
408    /// Allocates memory to be shared with the host and registers it with a
409    /// GPADL ID.
410    async fn reserve_memory(
411        &mut self,
412        state: &mut SimpleVmbusClientDeviceTaskState,
413        request_send: &mesh::Sender<ChannelRequest>,
414        page_count: usize,
415    ) -> Result<(MemoryBlock, GpadlId)> {
416        // Incoming and outgoing rings require a minimum of two pages apiece:
417        // one for the control bytes and at least one for the ring.
418        assert!(page_count >= 4);
419
420        let mem = self
421            .dma_alloc
422            .allocate_dma_buffer(page_count * PAGE_SIZE)
423            .context("allocating memory for vmbus rings")?;
424        state.vtl_pages = Some(mem.clone());
425        let buf: Vec<_> = [mem.len() as u64]
426            .iter()
427            .chain(mem.pfns())
428            .copied()
429            .collect();
430
431        let gpadl_id = GpadlId(state.vtl_pages.as_ref().unwrap().pfns()[1] as u32);
432        request_send
433            .call_failable(
434                ChannelRequest::Gpadl,
435                GpadlRequest {
436                    id: gpadl_id,
437                    count: 1,
438                    buf,
439                },
440            )
441            .await
442            .context("registering gpadl")?;
443        Ok((mem, gpadl_id))
444    }
445
446    /// Open the channel offered by the host.
447    async fn open_channel(
448        &self,
449        request_send: &mesh::Sender<ChannelRequest>,
450        ring_gpadl_id: GpadlId,
451        event: &pal_event::Event,
452    ) -> Result<OpenOutput> {
453        let open_request = OpenRequest {
454            open_data: OpenData {
455                target_vp: 0,
456                ring_offset: 2,
457                ring_gpadl_id,
458                event_flag: !0,
459                connection_id: !0,
460                user_data: UserDefinedData::new_zeroed(),
461            },
462            incoming_event: Some(event.clone()),
463            use_vtl2_connection_id: true,
464        };
465
466        request_send
467            .call_failable(ChannelRequest::Open, open_request)
468            .instrument(tracing::info_span!(
469                "opening vmbus channel for intercepted device"
470            ))
471            .await
472            .context("open vmbus channel")
473    }
474
475    /// Create a raw vmbus channel.
476    fn create_vmbus_channel(
477        &self,
478        mem: &MemoryBlock,
479        host_to_guest_event: &pal_event::Event,
480        guest_to_host_interrupt: Interrupt,
481    ) -> Result<RawAsyncChannel<MemoryBlockRingBuffer>> {
482        let (out_ring_mem, in_ring_mem) = (
483            mem.subblock(0, 2 * PAGE_SIZE),
484            mem.subblock(2 * PAGE_SIZE, 2 * PAGE_SIZE),
485        );
486        let (in_ring, out_ring) = (
487            IncomingRing::new(in_ring_mem.into()).unwrap(),
488            OutgoingRing::new(out_ring_mem.into()).unwrap(),
489        );
490
491        let signal = MemoryBlockChannelSignal {
492            event: Notify::from_event(host_to_guest_event.clone())
493                .pollable(self.spawner.as_ref())
494                .unwrap(),
495            interrupt: guest_to_host_interrupt,
496        };
497        Ok(RawAsyncChannel {
498            in_ring,
499            out_ring,
500            signal: Box::new(signal),
501        })
502    }
503
504    /// Responds to the channel being revoked by the host.
505    async fn handle_revoke(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
506        let Some(offer) = state.offer.take() else {
507            return;
508        };
509        tracing::info!("device revoked");
510        if self.device.stop().await {
511            drop(self.device.remove());
512            self.device.task_mut().0.close(offer.offer.subchannel_index);
513        }
514        self.cleanup_device_resources(state).await;
515    }
516
517    fn handle_save(&mut self) -> SavedStateBlob {
518        let saved_state = self.saved_state.take();
519        if let Some(saved_state) = saved_state {
520            let blob = SavedStateBlob::new(saved_state);
521            self.handle_restore(&blob);
522            blob
523        } else {
524            SavedStateBlob::new(NoSavedState)
525        }
526    }
527
528    fn handle_restore(&mut self, saved_state_blob: &SavedStateBlob) {
529        self.saved_state = match saved_state_blob.parse() {
530            Ok(saved_state) => Some(saved_state),
531            Err(err) => {
532                tracing::error!(
533                    err = &err as &dyn std::error::Error,
534                    "Protobuf conversion error saving state"
535                );
536                None
537            }
538        };
539    }
540
541    /// Handle vmbus messages from the host and control messages from the
542    /// device wrapper.
543    pub async fn process_messages(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) {
544        loop {
545            #[expect(clippy::large_enum_variant)]
546            enum Event {
547                Request(InterceptChannelRequest),
548                Revoke(()),
549            }
550            let revoke = pin!(async {
551                if let Some(offer) = &mut state.offer {
552                    (&mut offer.revoke_recv).await.ok();
553                } else {
554                    pending().await
555                }
556            });
557            let Some(r) = (
558                (&mut state.recv_relay).map(Event::Request),
559                futures::stream::once(revoke).map(Event::Revoke),
560            )
561                .merge()
562                .next()
563                .await
564            else {
565                break;
566            };
567            match r {
568                Event::Revoke(()) => {
569                    self.handle_revoke(state).await;
570                }
571                Event::Request(InterceptChannelRequest::Offer(offer)) => {
572                    // Any extraneous offer notifications (e.g. from a request offers
573                    // query) are ignored.
574                    if !self.device.is_running() {
575                        if let Err(err) = self.handle_offer(offer, state).await {
576                            tracing::error!(
577                                error = err.as_ref() as &dyn std::error::Error,
578                                "failed offer handling"
579                            );
580                        }
581                    }
582                }
583                Event::Request(InterceptChannelRequest::Start) => {
584                    self.handle_start(state).await;
585                }
586                Event::Request(InterceptChannelRequest::Stop(rpc)) => {
587                    rpc.handle(async |()| self.handle_stop(state).await).await;
588                }
589                Event::Request(InterceptChannelRequest::Save(rpc)) => {
590                    rpc.handle_sync(|()| self.handle_save());
591                }
592                Event::Request(InterceptChannelRequest::Restore(saved_state)) => {
593                    self.handle_restore(&saved_state);
594                }
595            }
596        }
597    }
598}
599
600struct MemoryBlockChannelSignal {
601    event: PolledNotify,
602    interrupt: Interrupt,
603}
604
605impl SignalVmbusChannel for MemoryBlockChannelSignal {
606    fn signal_remote(&self) {
607        self.interrupt.deliver();
608    }
609
610    fn poll_for_signal(
611        &self,
612        cx: &mut std::task::Context<'_>,
613    ) -> std::task::Poll<Result<(), ChannelClosed>> {
614        self.event.poll_wait(cx).map(Ok)
615    }
616}