vmbus_channel/
offer.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Vmbus channel offer support.

6use crate::ChannelClosed;
7use crate::RawAsyncChannel;
8use crate::SignalVmbusChannel;
9use crate::bus::ChannelRequest;
10use crate::bus::ChannelServerRequest;
11use crate::bus::OfferInput;
12use crate::bus::OfferParams;
13use crate::bus::OfferResources;
14use crate::bus::OpenRequest;
15use crate::bus::ParentBus;
16use crate::gpadl::GpadlMap;
17use crate::gpadl::GpadlMapView;
18use crate::gpadl_ring;
19use crate::gpadl_ring::GpadlRingMem;
20use crate::gpadl_ring::make_rings;
21use futures::StreamExt;
22use mesh::rpc::Rpc;
23use pal_async::driver::Driver;
24use pal_async::task::Spawn;
25use pal_async::task::Task;
26use pal_event::Event;
27use std::sync::Arc;
28use std::sync::atomic::AtomicBool;
29use std::sync::atomic::Ordering;
30use vmbus_ring::gparange::MultiPagedRangeBuf;
31use vmcore::interrupt::Interrupt;
32use vmcore::notify::Notify;
33use vmcore::notify::PolledNotify;
34
35/// A channel accept error.
36#[derive(Debug, thiserror::Error)]
37pub enum Error {
38    /// channel revoked
39    #[error("the channel has been revoked")]
40    Revoked,
41    /// GPADL ring buffer error
42    #[error(transparent)]
43    GpadlRing(#[from] gpadl_ring::Error),
44    /// Driver error
45    #[error("io driver error")]
46    Driver(#[source] std::io::Error),
47}
48
49/// A channel offer.
50pub struct Offer {
51    task: Task<()>,
52    open_recv: mesh::Receiver<OpenMessage>,
53    gpadl_map: GpadlMapView,
54    event: Notify,
55    offer_resources: OfferResources,
56    _server_request_send: mesh::Sender<ChannelServerRequest>,
57}
58
59impl Offer {
60    /// Offers a new channel.
61    pub async fn new(
62        driver: impl Spawn,
63        bus: &dyn ParentBus,
64        offer_params: OfferParams,
65    ) -> anyhow::Result<Self> {
66        let instance_id = offer_params.instance_id;
67        let event = Event::new();
68        let (request_send, request_recv) = mesh::channel();
69        let (server_request_send, server_request_recv) = mesh::channel();
70        let result = bus
71            .add_child(OfferInput {
72                params: offer_params,
73                event: Interrupt::from_event(event.clone()),
74                request_send,
75                server_request_recv,
76            })
77            .await?;
78
79        let gpadls = GpadlMap::new();
80        let gpadl_map = gpadls.clone().view();
81        let (open_send, open_recv) = mesh::channel();
82        let task = driver.spawn(format!("vmbus-offer-{}", instance_id), {
83            let event = event.clone();
84            async move { Self::task(event, gpadls, request_recv, open_send).await }
85        });
86
87        let offer = Self {
88            offer_resources: result,
89            task,
90            open_recv,
91            gpadl_map,
92            event: Notify::from_event(event),
93            _server_request_send: server_request_send,
94        };
95        Ok(offer)
96    }
97
98    async fn task(
99        event: Event,
100        gpadls: Arc<GpadlMap>,
101        mut request_recv: mesh::Receiver<ChannelRequest>,
102        send: mesh::Sender<OpenMessage>,
103    ) {
104        let mut open_done = None;
105        while let Ok(request) = request_recv.recv().await {
106            match request {
107                ChannelRequest::Open(rpc) => {
108                    let (open_request, response_send) = rpc.split();
109                    let done = Arc::new(AtomicBool::new(false));
110                    send.send(OpenMessage {
111                        open_request,
112                        done: done.clone(),
113                        response: OpenResponse(Some(response_send)),
114                    });
115                    open_done = Some(done);
116                }
117                ChannelRequest::Close(rpc) => {
118                    let _response_send = rpc; // TODO: figure out if we should really just drop this here.
119                    open_done
120                        .take()
121                        .expect("channel must be open")
122                        .store(true, Ordering::Relaxed);
123                    event.signal();
124                }
125                ChannelRequest::Gpadl(rpc) => rpc.handle_sync(|gpadl| {
126                    match MultiPagedRangeBuf::from_range_buffer(gpadl.count.into(), gpadl.buf) {
127                        Ok(buf) => {
128                            gpadls.add(gpadl.id, buf);
129                            true
130                        }
131                        Err(err) => {
132                            tracelimit::error_ratelimited!(
133                                error = &err as &dyn std::error::Error,
134                                "failed to parse gpadl"
135                            );
136                            false
137                        }
138                    }
139                }),
140                ChannelRequest::TeardownGpadl(rpc) => {
141                    let (id, response_send) = rpc.split();
142                    if let Some(f) = gpadls.remove(
143                        id,
144                        Box::new(move || {
145                            response_send.complete(());
146                        }),
147                    ) {
148                        f();
149                    }
150                }
151                ChannelRequest::Modify(rpc) => rpc.handle_sync(|_| 0),
152            }
153        }
154    }
155
156    /// Accepts a channel open request from the guest.
157    pub async fn wait_for_open(
158        &mut self,
159        driver: &(impl Driver + ?Sized),
160    ) -> Result<OpeningChannel, Error> {
161        let message = self.open_recv.next().await.ok_or(Error::Revoked)?;
162
163        let (in_ring, out_ring) = make_rings(
164            self.offer_resources.ring_memory(&message.open_request),
165            &self.gpadl_map,
166            &message.open_request.open_data,
167        )?;
168        let event = OfferChannelSignal {
169            event: self.event.clone().pollable(driver).map_err(Error::Driver)?,
170            interrupt: message.open_request.interrupt.clone(),
171            done: message.done,
172        };
173        let channel = RawAsyncChannel {
174            in_ring,
175            out_ring,
176            signal: Box::new(event),
177        };
178        let resources = OpenChannelResources {
179            channel,
180            gpadl_map: self.gpadl_map.clone(),
181        };
182        Ok(OpeningChannel {
183            resources,
184            response: message.response,
185        })
186    }
187
188    /// Revokes the channel.
189    pub async fn revoke(self) {
190        drop(self.open_recv);
191        self.task.await;
192    }
193}
194
195/// An in-progress channel opening, returned by [`Offer::wait_for_open`].
196pub struct OpeningChannel {
197    resources: OpenChannelResources,
198    response: OpenResponse,
199}
200
201impl OpeningChannel {
202    /// Accepts the channel open request.
203    pub fn accept(self) -> OpenChannelResources {
204        self.response.respond(true);
205        self.resources
206    }
207
208    /// Rejects the channel open request.
209    pub fn reject(self) {
210        self.response.respond(false);
211    }
212}
213
214struct OfferChannelSignal {
215    event: PolledNotify,
216    interrupt: Interrupt,
217    done: Arc<AtomicBool>,
218}
219
220impl SignalVmbusChannel for OfferChannelSignal {
221    fn signal_remote(&self) {
222        self.interrupt.deliver();
223    }
224
225    fn poll_for_signal(
226        &self,
227        cx: &mut std::task::Context<'_>,
228    ) -> std::task::Poll<Result<(), ChannelClosed>> {
229        if self.done.load(Ordering::Relaxed) {
230            return Err(ChannelClosed).into();
231        }
232        self.event.poll_wait(cx).map(Ok)
233    }
234}
235
236struct OpenMessage {
237    open_request: OpenRequest,
238    done: Arc<AtomicBool>,
239    response: OpenResponse,
240}
241
242struct OpenResponse(Option<Rpc<(), bool>>);
243
244impl OpenResponse {
245    fn respond(mut self, open: bool) {
246        self.0.take().unwrap().complete(open)
247    }
248}
249
250impl Drop for OpenResponse {
251    fn drop(&mut self) {
252        if let Some(rpc) = self.0.take() {
253            rpc.complete(false);
254        }
255    }
256}
257
258/// Channel resources for an open channel.
259pub struct OpenChannelResources {
260    /// The channel ring buffer and interrupt state.
261    pub channel: RawAsyncChannel<GpadlRingMem>,
262    /// The channel's GPADL map.
263    pub gpadl_map: GpadlMapView,
264}