vmbus_client/
driver.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Support for authoring vmbus device drivers on top of the vmbus client
5//! driver.
6
7use crate::ChannelRequest;
8use crate::OfferInfo;
9use crate::OpenRequest;
10use anyhow::Context as _;
11use futures::FutureExt;
12use futures_concurrency::future::Race;
13use inspect::InspectMut;
14use mesh::rpc::RpcSend;
15use pal_async::driver::SpawnDriver;
16use std::mem::ManuallyDrop;
17use std::sync::Arc;
18use std::sync::atomic::AtomicBool;
19use std::sync::atomic::AtomicU8;
20use std::sync::atomic::Ordering::Relaxed;
21use std::task::Poll;
22use user_driver::DmaClient;
23use user_driver::memory::MemoryBlock;
24use vmbus_channel::ChannelClosed;
25use vmbus_channel::RawAsyncChannel;
26use vmbus_channel::SignalVmbusChannel;
27use vmbus_channel::bus::GpadlRequest;
28use vmbus_channel::bus::OpenData;
29use vmbus_core::protocol::GpadlId;
30use vmbus_core::protocol::UserDefinedData;
31use vmbus_ring::IncomingRing;
32use vmbus_ring::OutgoingRing;
33use vmbus_ring::SingleMappedRingMem;
34use vmcore::interrupt::Interrupt;
35use vmcore::notify::Notify;
36use vmcore::notify::PolledNotify;
37
/// Input parameters when opening a vmbus channel.
///
/// A single buffer of `ring_pages` pages holds both rings. The outgoing ring
/// occupies pages `0..ring_offset_in_pages` and the incoming ring occupies
/// pages `ring_offset_in_pages..ring_pages`.
pub struct OpenParams {
    /// The total number of pages to allocate for the ring buffer (both rings
    /// combined).
    pub ring_pages: u16,
    /// The offset in pages where the downstream (incoming) ring starts. This
    /// is also the size in pages of the outgoing ring.
    pub ring_offset_in_pages: u16,
}
45
/// The memory type used for the vmbus channel ring buffer: a
/// [`SingleMappedRingMem`] wrapping a [`MemoryBlockView`] window of the ring
/// allocation.
pub type MemoryBlockRingMem = SingleMappedRingMem<MemoryBlockView>;
48
49/// Opens a vmbus channel, returning the ring buffer parameters.
50pub async fn open_channel(
51    driver: impl SpawnDriver + Clone + 'static,
52    offer_info: OfferInfo,
53    params: OpenParams,
54    dma_client: &dyn DmaClient,
55) -> anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>> {
56    let gpadl =
57        dma_client.allocate_dma_buffer(vmbus_ring::PAGE_SIZE * params.ring_pages as usize)?;
58
59    let (resp_send, resp_recv) = mesh::oneshot();
60    // Detach the task so that it doesn't get dropped (and thereby leak the allocation).
61    driver
62        .clone()
63        .spawn("vmbus_client", async move {
64            ChannelWorker::run(driver, offer_info, params, gpadl, resp_send).await
65        })
66        .detach();
67
68    resp_recv.await.context("no response opening channel")?
69}
70
/// State for the task that opens a channel and later closes it and releases
/// its ring memory.
#[derive(InspectMut)]
struct ChannelWorker<D> {
    /// Driver used to make the host-to-guest notification pollable.
    #[inspect(skip)]
    driver: D,
    /// The offer this worker was created for.
    offer: vmbus_core::protocol::OfferChannel,
    /// Sender for the GPADL, open, close, and GPADL teardown requests.
    #[inspect(skip)]
    request_send: mesh::Sender<ChannelRequest>,
    /// The ring memory. Wrapped in `ManuallyDrop` so that this reference is
    /// only released explicitly, at the end of `shutdown`.
    #[inspect(skip)]
    gpadl: ManuallyDrop<Arc<MemoryBlock>>,
    /// The GPADL ID used for the ring memory.
    #[inspect(debug)]
    ring_gpadl_id: GpadlId,
    /// Set once the GPADL request succeeds; `shutdown` then tears it down.
    is_gpadl_created: bool,
    /// Set once the open request succeeds and cleared on revoke; `shutdown`
    /// sends a close request only while this is set.
    is_open: bool,
}
85
86impl<D: SpawnDriver> ChannelWorker<D> {
87    async fn open(
88        &mut self,
89        input: OpenParams,
90        close_send: mesh::OneshotSender<()>,
91        host_to_guest: pal_event::Event,
92        revoked: Arc<AtomicBool>,
93    ) -> anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>> {
94        let gpadl_buf = [self.gpadl.len() as u64]
95            .into_iter()
96            .chain(self.gpadl.pfns().iter().copied())
97            .collect::<Vec<_>>();
98
99        self.request_send
100            .call_failable(
101                ChannelRequest::Gpadl,
102                GpadlRequest {
103                    id: self.ring_gpadl_id,
104                    count: 1,
105                    buf: gpadl_buf,
106                },
107            )
108            .await?;
109
110        self.is_gpadl_created = true;
111
112        let open = self
113            .request_send
114            .call_failable(
115                ChannelRequest::Open,
116                OpenRequest {
117                    open_data: OpenData {
118                        target_vp: 0, // TODO: improve
119                        ring_offset: input.ring_offset_in_pages.into(),
120                        ring_gpadl_id: self.ring_gpadl_id,
121                        event_flag: !0,
122                        connection_id: !0,
123                        user_data: UserDefinedData::default(),
124                    },
125                    incoming_event: Some(host_to_guest.clone()),
126                    use_vtl2_connection_id: true,
127                },
128            )
129            .await?;
130
131        self.is_open = true;
132
133        let in_ring = MemoryBlockView {
134            mem: Arc::clone(&self.gpadl),
135            offset: input.ring_offset_in_pages as usize * vmbus_ring::PAGE_SIZE,
136            len: (input.ring_pages - input.ring_offset_in_pages) as usize * vmbus_ring::PAGE_SIZE,
137        };
138
139        let out_ring = MemoryBlockView {
140            mem: Arc::clone(&self.gpadl),
141            offset: 0,
142            len: input.ring_offset_in_pages as usize * vmbus_ring::PAGE_SIZE,
143        };
144
145        let signal = ClientSignaller {
146            guest_to_host: open.guest_to_host_signal,
147            host_to_guest: Notify::from_event(host_to_guest).pollable(&self.driver)?,
148            revoked: revoked.clone(),
149            _close: close_send,
150        };
151
152        Ok(RawAsyncChannel {
153            in_ring: IncomingRing::new(SingleMappedRingMem(in_ring))?,
154            out_ring: OutgoingRing::new(SingleMappedRingMem(out_ring))?,
155            signal: Box::new(signal),
156        })
157    }
158
159    async fn shutdown(self) {
160        if self.is_open {
161            self.request_send.call(ChannelRequest::Close, ()).await.ok();
162        }
163
164        if self.is_gpadl_created {
165            self.request_send
166                .call(ChannelRequest::TeardownGpadl, self.ring_gpadl_id)
167                .await
168                .ok();
169        }
170
171        // Now it is safe to deallocate the gpadl memory.
172        ManuallyDrop::into_inner(self.gpadl);
173    }
174
175    async fn run(
176        driver: D,
177        offer_info: OfferInfo,
178        input: OpenParams,
179        gpadl_mem: MemoryBlock,
180        resp: mesh::OneshotSender<anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>>>,
181    ) {
182        let instance_id = offer_info.offer.instance_id;
183        let ring_gpadl_id = GpadlId((1 << 31) | offer_info.offer.channel_id.0);
184
185        let mut worker = ChannelWorker {
186            driver,
187            offer: offer_info.offer,
188            request_send: offer_info.request_send,
189            gpadl: ManuallyDrop::new(Arc::new(gpadl_mem)),
190            ring_gpadl_id,
191            is_gpadl_created: false,
192            is_open: false,
193        };
194
195        let revoked = Arc::new(AtomicBool::new(false));
196        let (close_send, close_recv) = mesh::oneshot();
197        let host_to_guest = pal_event::Event::new();
198        match worker
199            .open(input, close_send, host_to_guest.clone(), revoked.clone())
200            .await
201        {
202            Ok(channel) => {
203                resp.send(Ok(channel));
204            }
205            Err(e) => {
206                resp.send(Err(e));
207                worker.shutdown().await;
208                return;
209            }
210        }
211
212        enum Event<T, U> {
213            Close(T),
214            Revoke(U),
215        }
216
217        let revoke = offer_info.revoke_recv.map(Event::Revoke);
218        let close = close_recv.map(Event::Close);
219        let event = (revoke, close).race().await;
220        match event {
221            Event::Close(_) => {
222                tracing::debug!(%instance_id, "channel close requested");
223            }
224            Event::Revoke(_) => {
225                tracing::debug!(%instance_id, "channel revoked");
226                revoked.store(true, Relaxed);
227                host_to_guest.signal();
228                worker.is_open = false;
229            }
230        }
231
232        worker.shutdown().await;
233    }
234}
235
/// A view into a [`MemoryBlock`].
///
/// Exposes the `len` bytes of the block starting at byte `offset` through
/// `AsRef<[AtomicU8]>`.
pub struct MemoryBlockView {
    /// The backing block, shared so that each view keeps it alive.
    mem: Arc<MemoryBlock>,
    /// Byte offset of the start of the view within `mem`.
    offset: usize,
    /// Length of the view in bytes.
    len: usize,
}
242
243impl AsRef<[AtomicU8]> for MemoryBlockView {
244    fn as_ref(&self) -> &[AtomicU8] {
245        &self.mem.as_slice()[self.offset..][..self.len]
246    }
247}
248
/// The [`SignalVmbusChannel`] implementation attached to channels opened by
/// this module.
struct ClientSignaller {
    /// Interrupt delivered by `signal_remote`.
    guest_to_host: Interrupt,
    /// Notification polled by `poll_for_signal`.
    host_to_guest: PolledNotify,
    /// Set by the channel worker when the channel is revoked; once set,
    /// `poll_for_signal` reports `ChannelClosed`.
    revoked: Arc<AtomicBool>,
    // This closes the channel on drop.
    _close: mesh::OneshotSender<()>,
}
256
257impl SignalVmbusChannel for ClientSignaller {
258    fn signal_remote(&self) {
259        self.guest_to_host.deliver();
260    }
261
262    fn poll_for_signal(&self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), ChannelClosed>> {
263        if self.revoked.load(Relaxed) {
264            return Poll::Ready(Err(ChannelClosed));
265        }
266        self.host_to_guest.poll_wait(cx).map(Ok)
267    }
268}