1use crate::ChannelRequest;
8use crate::OfferInfo;
9use crate::OpenRequest;
10use anyhow::Context as _;
11use futures::FutureExt;
12use futures_concurrency::future::Race;
13use inspect::InspectMut;
14use mesh::rpc::RpcSend;
15use pal_async::driver::SpawnDriver;
16use std::mem::ManuallyDrop;
17use std::sync::Arc;
18use std::sync::atomic::AtomicBool;
19use std::sync::atomic::AtomicU8;
20use std::sync::atomic::Ordering::Relaxed;
21use std::task::Poll;
22use user_driver::DmaClient;
23use user_driver::memory::MemoryBlock;
24use vmbus_channel::ChannelClosed;
25use vmbus_channel::RawAsyncChannel;
26use vmbus_channel::SignalVmbusChannel;
27use vmbus_channel::bus::GpadlRequest;
28use vmbus_channel::bus::OpenData;
29use vmbus_core::protocol::GpadlId;
30use vmbus_core::protocol::UserDefinedData;
31use vmbus_ring::IncomingRing;
32use vmbus_ring::OutgoingRing;
33use vmbus_ring::SingleMappedRingMem;
34use vmcore::interrupt::Interrupt;
35use vmcore::notify::Notify;
36use vmcore::notify::PolledNotify;
37
/// Ring buffer layout used when opening a channel with [`open_channel`].
pub struct OpenParams {
    /// Total size, in pages, of the ring buffer memory. This memory is split
    /// between the outgoing and incoming rings.
    pub ring_pages: u16,
    /// Page offset at which the incoming ring starts (also sent to the host
    /// as the open request's `ring_offset`). Pages before this offset form the
    /// outgoing ring. For both rings to be non-empty, it must be nonzero and
    /// less than `ring_pages`.
    pub ring_offset_in_pages: u16,
}
45
/// Ring memory type for channels opened by [`open_channel`]: each ring is a
/// [`MemoryBlockView`] into the shared ring buffer allocation.
pub type MemoryBlockRingMem = SingleMappedRingMem<MemoryBlockView>;
48
49pub async fn open_channel(
51 driver: impl SpawnDriver + Clone + 'static,
52 offer_info: OfferInfo,
53 params: OpenParams,
54 dma_client: &dyn DmaClient,
55) -> anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>> {
56 let gpadl =
57 dma_client.allocate_dma_buffer(vmbus_ring::PAGE_SIZE * params.ring_pages as usize)?;
58
59 let (resp_send, resp_recv) = mesh::oneshot();
60 driver
62 .clone()
63 .spawn("vmbus_client", async move {
64 ChannelWorker::run(driver, offer_info, params, gpadl, resp_send).await
65 })
66 .detach();
67
68 resp_recv.await.context("no response opening channel")?
69}
70
/// State owned by the background task that manages one opened channel.
///
/// Tracks which resources have been created so that `shutdown` can undo
/// exactly those.
#[derive(InspectMut)]
struct ChannelWorker<D> {
    /// Driver used to make the host-to-guest notification pollable.
    #[inspect(skip)]
    driver: D,
    /// The channel offer this worker was created for.
    offer: vmbus_core::protocol::OfferChannel,
    /// Sends channel requests (GPADL creation/teardown, open, close).
    #[inspect(skip)]
    request_send: mesh::Sender<ChannelRequest>,
    /// Ring buffer memory.
    ///
    /// Wrapped in `ManuallyDrop` so that only an explicit
    /// `ManuallyDrop::into_inner` in `shutdown` releases it. If the worker is
    /// dropped any other way, this reference is leaked. Presumably this is so
    /// the memory is never freed while the GPADL may still be in use.
    #[inspect(skip)]
    gpadl: ManuallyDrop<Arc<MemoryBlock>>,
    /// GPADL id used for the ring buffer.
    #[inspect(debug)]
    ring_gpadl_id: GpadlId,
    /// Whether the GPADL request for the ring buffer succeeded.
    is_gpadl_created: bool,
    /// Whether the channel is open. It is cleared on revoke so that no close
    /// request is sent.
    is_open: bool,
}
85
86impl<D: SpawnDriver> ChannelWorker<D> {
87 async fn open(
88 &mut self,
89 input: OpenParams,
90 close_send: mesh::OneshotSender<()>,
91 host_to_guest: pal_event::Event,
92 revoked: Arc<AtomicBool>,
93 ) -> anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>> {
94 let gpadl_buf = [self.gpadl.len() as u64]
95 .into_iter()
96 .chain(self.gpadl.pfns().iter().copied())
97 .collect::<Vec<_>>();
98
99 self.request_send
100 .call_failable(
101 ChannelRequest::Gpadl,
102 GpadlRequest {
103 id: self.ring_gpadl_id,
104 count: 1,
105 buf: gpadl_buf,
106 },
107 )
108 .await?;
109
110 self.is_gpadl_created = true;
111
112 let open = self
113 .request_send
114 .call_failable(
115 ChannelRequest::Open,
116 OpenRequest {
117 open_data: OpenData {
118 target_vp: 0, ring_offset: input.ring_offset_in_pages.into(),
120 ring_gpadl_id: self.ring_gpadl_id,
121 event_flag: !0,
122 connection_id: !0,
123 user_data: UserDefinedData::default(),
124 },
125 incoming_event: Some(host_to_guest.clone()),
126 use_vtl2_connection_id: true,
127 },
128 )
129 .await?;
130
131 self.is_open = true;
132
133 let in_ring = MemoryBlockView {
134 mem: Arc::clone(&self.gpadl),
135 offset: input.ring_offset_in_pages as usize * vmbus_ring::PAGE_SIZE,
136 len: (input.ring_pages - input.ring_offset_in_pages) as usize * vmbus_ring::PAGE_SIZE,
137 };
138
139 let out_ring = MemoryBlockView {
140 mem: Arc::clone(&self.gpadl),
141 offset: 0,
142 len: input.ring_offset_in_pages as usize * vmbus_ring::PAGE_SIZE,
143 };
144
145 let signal = ClientSignaller {
146 guest_to_host: open.guest_to_host_signal,
147 host_to_guest: Notify::from_event(host_to_guest).pollable(&self.driver)?,
148 revoked: revoked.clone(),
149 _close: close_send,
150 };
151
152 Ok(RawAsyncChannel {
153 in_ring: IncomingRing::new(SingleMappedRingMem(in_ring))?,
154 out_ring: OutgoingRing::new(SingleMappedRingMem(out_ring))?,
155 signal: Box::new(signal),
156 })
157 }
158
159 async fn shutdown(self) {
160 if self.is_open {
161 self.request_send.call(ChannelRequest::Close, ()).await.ok();
162 }
163
164 if self.is_gpadl_created {
165 self.request_send
166 .call(ChannelRequest::TeardownGpadl, self.ring_gpadl_id)
167 .await
168 .ok();
169 }
170
171 ManuallyDrop::into_inner(self.gpadl);
173 }
174
175 async fn run(
176 driver: D,
177 offer_info: OfferInfo,
178 input: OpenParams,
179 gpadl_mem: MemoryBlock,
180 resp: mesh::OneshotSender<anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>>>,
181 ) {
182 let instance_id = offer_info.offer.instance_id;
183 let ring_gpadl_id = GpadlId((1 << 31) | offer_info.offer.channel_id.0);
184
185 let mut worker = ChannelWorker {
186 driver,
187 offer: offer_info.offer,
188 request_send: offer_info.request_send,
189 gpadl: ManuallyDrop::new(Arc::new(gpadl_mem)),
190 ring_gpadl_id,
191 is_gpadl_created: false,
192 is_open: false,
193 };
194
195 let revoked = Arc::new(AtomicBool::new(false));
196 let (close_send, close_recv) = mesh::oneshot();
197 let host_to_guest = pal_event::Event::new();
198 match worker
199 .open(input, close_send, host_to_guest.clone(), revoked.clone())
200 .await
201 {
202 Ok(channel) => {
203 resp.send(Ok(channel));
204 }
205 Err(e) => {
206 resp.send(Err(e));
207 worker.shutdown().await;
208 return;
209 }
210 }
211
212 enum Event<T, U> {
213 Close(T),
214 Revoke(U),
215 }
216
217 let revoke = offer_info.revoke_recv.map(Event::Revoke);
218 let close = close_recv.map(Event::Close);
219 let event = (revoke, close).race().await;
220 match event {
221 Event::Close(_) => {
222 tracing::debug!(%instance_id, "channel close requested");
223 }
224 Event::Revoke(_) => {
225 tracing::debug!(%instance_id, "channel revoked");
226 revoked.store(true, Relaxed);
227 host_to_guest.signal();
228 worker.is_open = false;
229 }
230 }
231
232 worker.shutdown().await;
233 }
234}
235
/// A contiguous range of a shared [`MemoryBlock`], used as the backing memory
/// for one ring.
pub struct MemoryBlockView {
    /// The shared allocation this view points into.
    mem: Arc<MemoryBlock>,
    /// Start of the view within `mem`, in bytes.
    offset: usize,
    /// Length of the view, in bytes.
    len: usize,
}
242
243impl AsRef<[AtomicU8]> for MemoryBlockView {
244 fn as_ref(&self) -> &[AtomicU8] {
245 &self.mem.as_slice()[self.offset..][..self.len]
246 }
247}
248
/// Guest-side signalling for an opened channel.
struct ClientSignaller {
    /// Interrupt delivered to notify the host.
    guest_to_host: Interrupt,
    /// Pollable notification backed by the event passed as the open request's
    /// incoming event.
    host_to_guest: PolledNotify,
    /// Set by the worker when the channel is revoked. Once it is set,
    /// `poll_for_signal` reports `ChannelClosed`.
    revoked: Arc<AtomicBool>,
    /// Never sent on; held only so that it is dropped along with this
    /// signaller. The drop completes the worker's close receiver, which the
    /// worker treats as a close request.
    _close: mesh::OneshotSender<()>,
}
256
257impl SignalVmbusChannel for ClientSignaller {
258 fn signal_remote(&self) {
259 self.guest_to_host.deliver();
260 }
261
262 fn poll_for_signal(&self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), ChannelClosed>> {
263 if self.revoked.load(Relaxed) {
264 return Poll::Ready(Err(ChannelClosed));
265 }
266 self.host_to_guest.poll_wait(cx).map(Ok)
267 }
268}