vmbus_client/
driver.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Support for authoring vmbus device drivers on top of the vmbus client
5//! driver.
6
7use crate::ChannelRequest;
8use crate::OfferInfo;
9use crate::OpenRequest;
10use anyhow::Context as _;
11use futures::FutureExt;
12use futures_concurrency::future::Race;
13use inspect::InspectMut;
14use mesh::rpc::RpcSend;
15use pal_async::driver::SpawnDriver;
16use std::mem::ManuallyDrop;
17use std::sync::Arc;
18use std::sync::atomic::AtomicBool;
19use std::sync::atomic::AtomicU8;
20use std::sync::atomic::Ordering::Relaxed;
21use std::task::Poll;
22use user_driver::DmaClient;
23use user_driver::memory::MemoryBlock;
24use vmbus_channel::ChannelClosed;
25use vmbus_channel::RawAsyncChannel;
26use vmbus_channel::SignalVmbusChannel;
27use vmbus_channel::bus::GpadlRequest;
28use vmbus_channel::bus::OpenData;
29use vmbus_core::protocol::GpadlId;
30use vmbus_core::protocol::UserDefinedData;
31use vmbus_ring::IncomingRing;
32use vmbus_ring::OutgoingRing;
33use vmbus_ring::SingleMappedRingMem;
34use vmcore::interrupt::Interrupt;
35use vmcore::notify::Notify;
36use vmcore::notify::PolledNotify;
37
/// Input parameters when opening a vmbus channel.
///
/// The ring buffer memory is a single allocation of `ring_pages` pages that
/// is split in two at `ring_offset_in_pages`: the pages before the offset
/// back the outgoing ring and the remaining pages back the incoming ring.
#[derive(Debug, Clone, Copy)]
pub struct OpenParams {
    /// The total number of pages to allocate for the ring buffer memory,
    /// covering both the outgoing and the incoming ring.
    pub ring_pages: u16,
    /// The offset in pages where the downstream (incoming) ring starts.
    ///
    /// This must not exceed `ring_pages`.
    pub ring_offset_in_pages: u16,
}
45
46/// The memory type used for the vmbus channel ring buffer.
47pub type MemoryBlockRingMem = SingleMappedRingMem<MemoryBlockView>;
48
49/// Opens a vmbus channel, returning the ring buffer parameters.
50pub async fn open_channel(
51    driver: impl SpawnDriver + Clone + 'static,
52    offer_info: OfferInfo,
53    params: OpenParams,
54    dma_client: &dyn DmaClient,
55) -> anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>> {
56    let gpadl =
57        dma_client.allocate_dma_buffer(vmbus_ring::PAGE_SIZE * params.ring_pages as usize)?;
58
59    let (resp_send, resp_recv) = mesh::oneshot();
60    // Detach the task so that it doesn't get dropped (and thereby leak the allocation).
61    driver
62        .clone()
63        .spawn("vmbus_client", async move {
64            ChannelWorker::run(driver, offer_info, params, gpadl, resp_send).await
65        })
66        .detach();
67
68    resp_recv.await.context("no response opening channel")?
69}
70
71#[derive(InspectMut)]
72struct ChannelWorker<D> {
73    #[inspect(skip)]
74    driver: D,
75    offer: vmbus_core::protocol::OfferChannel,
76    #[inspect(skip)]
77    request_send: mesh::Sender<ChannelRequest>,
78    #[inspect(skip)]
79    gpadl: ManuallyDrop<Arc<MemoryBlock>>,
80    #[inspect(debug)]
81    ring_gpadl_id: GpadlId,
82    is_gpadl_created: bool,
83    is_open: bool,
84}
85
86impl<D: SpawnDriver> ChannelWorker<D> {
87    async fn open(
88        &mut self,
89        input: OpenParams,
90        close_send: mesh::OneshotSender<()>,
91        guest_to_host: Interrupt,
92        host_to_guest: pal_event::Event,
93        revoked: Arc<AtomicBool>,
94    ) -> anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>> {
95        let gpadl_buf = [self.gpadl.len() as u64]
96            .into_iter()
97            .chain(self.gpadl.pfns().iter().copied())
98            .collect::<Vec<_>>();
99
100        self.request_send
101            .call_failable(
102                ChannelRequest::Gpadl,
103                GpadlRequest {
104                    id: self.ring_gpadl_id,
105                    count: 1,
106                    buf: gpadl_buf,
107                },
108            )
109            .await?;
110
111        self.is_gpadl_created = true;
112
113        self.request_send
114            .call_failable(
115                ChannelRequest::Open,
116                OpenRequest {
117                    open_data: OpenData {
118                        target_vp: 0, // TODO: improve
119                        ring_offset: input.ring_offset_in_pages.into(),
120                        ring_gpadl_id: self.ring_gpadl_id,
121                        event_flag: !0,
122                        connection_id: !0,
123                        user_data: UserDefinedData::default(),
124                    },
125                    incoming_event: Some(host_to_guest.clone()),
126                    use_vtl2_connection_id: true,
127                },
128            )
129            .await?;
130
131        self.is_open = true;
132
133        let in_ring = MemoryBlockView {
134            mem: Arc::clone(&self.gpadl),
135            offset: input.ring_offset_in_pages as usize * vmbus_ring::PAGE_SIZE,
136            len: (input.ring_pages - input.ring_offset_in_pages) as usize * vmbus_ring::PAGE_SIZE,
137        };
138
139        let out_ring = MemoryBlockView {
140            mem: Arc::clone(&self.gpadl),
141            offset: 0,
142            len: input.ring_offset_in_pages as usize * vmbus_ring::PAGE_SIZE,
143        };
144
145        let signal = ClientSignaller {
146            guest_to_host,
147            host_to_guest: Notify::from_event(host_to_guest).pollable(&self.driver)?,
148            revoked: revoked.clone(),
149            _close: close_send,
150        };
151
152        Ok(RawAsyncChannel {
153            in_ring: IncomingRing::new(SingleMappedRingMem(in_ring))?,
154            out_ring: OutgoingRing::new(SingleMappedRingMem(out_ring))?,
155            signal: Box::new(signal),
156        })
157    }
158
159    async fn shutdown(self) {
160        if self.is_open {
161            self.request_send.call(ChannelRequest::Close, ()).await.ok();
162        }
163
164        if self.is_gpadl_created {
165            self.request_send
166                .call(ChannelRequest::TeardownGpadl, self.ring_gpadl_id)
167                .await
168                .ok();
169        }
170
171        // Now it is safe to deallocate the gpadl memory.
172        ManuallyDrop::into_inner(self.gpadl);
173    }
174
175    async fn run(
176        driver: D,
177        offer_info: OfferInfo,
178        input: OpenParams,
179        gpadl_mem: MemoryBlock,
180        resp: mesh::OneshotSender<anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>>>,
181    ) {
182        let instance_id = offer_info.offer.instance_id;
183        let ring_gpadl_id = GpadlId((1 << 31) | offer_info.offer.channel_id.0);
184
185        let mut worker = ChannelWorker {
186            driver,
187            offer: offer_info.offer,
188            request_send: offer_info.request_send,
189            gpadl: ManuallyDrop::new(Arc::new(gpadl_mem)),
190            ring_gpadl_id,
191            is_gpadl_created: false,
192            is_open: false,
193        };
194
195        let revoked = Arc::new(AtomicBool::new(false));
196        let (close_send, close_recv) = mesh::oneshot();
197        let host_to_guest = pal_event::Event::new();
198        match worker
199            .open(
200                input,
201                close_send,
202                offer_info.guest_to_host_interrupt,
203                host_to_guest.clone(),
204                revoked.clone(),
205            )
206            .await
207        {
208            Ok(channel) => {
209                resp.send(Ok(channel));
210            }
211            Err(e) => {
212                resp.send(Err(e));
213                worker.shutdown().await;
214                return;
215            }
216        }
217
218        enum Event<T, U> {
219            Close(T),
220            Revoke(U),
221        }
222
223        let revoke = offer_info.revoke_recv.map(Event::Revoke);
224        let close = close_recv.map(Event::Close);
225        let event = (revoke, close).race().await;
226        match event {
227            Event::Close(_) => {
228                tracing::debug!(%instance_id, "channel close requested");
229            }
230            Event::Revoke(_) => {
231                tracing::debug!(%instance_id, "channel revoked");
232                revoked.store(true, Relaxed);
233                host_to_guest.signal();
234                worker.is_open = false;
235            }
236        }
237
238        worker.shutdown().await;
239    }
240}
241
242/// A view into a [`MemoryBlock`].
243pub struct MemoryBlockView {
244    mem: Arc<MemoryBlock>,
245    offset: usize,
246    len: usize,
247}
248
249impl AsRef<[AtomicU8]> for MemoryBlockView {
250    fn as_ref(&self) -> &[AtomicU8] {
251        &self.mem.as_slice()[self.offset..][..self.len]
252    }
253}
254
255struct ClientSignaller {
256    guest_to_host: Interrupt,
257    host_to_guest: PolledNotify,
258    revoked: Arc<AtomicBool>,
259    // This closes the channel on drop.
260    _close: mesh::OneshotSender<()>,
261}
262
263impl SignalVmbusChannel for ClientSignaller {
264    fn signal_remote(&self) {
265        self.guest_to_host.deliver();
266    }
267
268    fn poll_for_signal(&self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), ChannelClosed>> {
269        if self.revoked.load(Relaxed) {
270            return Poll::Ready(Err(ChannelClosed));
271        }
272        self.host_to_guest.poll_wait(cx).map(Ok)
273    }
274}