1use crate::ChannelRequest;
8use crate::OfferInfo;
9use crate::OpenRequest;
10use anyhow::Context as _;
11use futures::FutureExt;
12use futures_concurrency::future::Race;
13use inspect::InspectMut;
14use mesh::rpc::RpcSend;
15use pal_async::driver::SpawnDriver;
16use std::mem::ManuallyDrop;
17use std::sync::Arc;
18use std::sync::atomic::AtomicBool;
19use std::sync::atomic::AtomicU8;
20use std::sync::atomic::Ordering::Relaxed;
21use std::task::Poll;
22use user_driver::DmaClient;
23use user_driver::memory::MemoryBlock;
24use vmbus_channel::ChannelClosed;
25use vmbus_channel::RawAsyncChannel;
26use vmbus_channel::SignalVmbusChannel;
27use vmbus_channel::bus::GpadlRequest;
28use vmbus_channel::bus::OpenData;
29use vmbus_core::protocol::GpadlId;
30use vmbus_core::protocol::UserDefinedData;
31use vmbus_ring::IncomingRing;
32use vmbus_ring::OutgoingRing;
33use vmbus_ring::SingleMappedRingMem;
34use vmcore::interrupt::Interrupt;
35use vmcore::notify::Notify;
36use vmcore::notify::PolledNotify;
37
/// Ring buffer layout requested when opening a channel.
///
/// A single allocation of `ring_pages` pages holds both rings: the outgoing
/// ring occupies the pages before `ring_offset_in_pages`, and the incoming
/// ring occupies the remainder.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenParams {
    /// Total number of pages in the ring buffer allocation (both rings).
    pub ring_pages: u16,
    /// Page offset at which the incoming ring starts; also the size in pages
    /// of the outgoing ring.
    pub ring_offset_in_pages: u16,
}
45
46pub type MemoryBlockRingMem = SingleMappedRingMem<MemoryBlockView>;
48
49pub async fn open_channel(
51 driver: impl SpawnDriver + Clone + 'static,
52 offer_info: OfferInfo,
53 params: OpenParams,
54 dma_client: &dyn DmaClient,
55) -> anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>> {
56 let gpadl =
57 dma_client.allocate_dma_buffer(vmbus_ring::PAGE_SIZE * params.ring_pages as usize)?;
58
59 let (resp_send, resp_recv) = mesh::oneshot();
60 driver
62 .clone()
63 .spawn("vmbus_client", async move {
64 ChannelWorker::run(driver, offer_info, params, gpadl, resp_send).await
65 })
66 .detach();
67
68 resp_recv.await.context("no response opening channel")?
69}
70
/// State owned by the background task that opens a channel and later tears
/// it down.
// NOTE: field names and order are visible through the `InspectMut` derive.
#[derive(InspectMut)]
struct ChannelWorker<D> {
    /// Driver used to make the host-to-guest notification pollable.
    #[inspect(skip)]
    driver: D,
    /// The channel offer being opened (exposed via inspect).
    offer: vmbus_core::protocol::OfferChannel,
    /// Sends the GPADL, open, close and GPADL-teardown requests for this
    /// channel.
    #[inspect(skip)]
    request_send: mesh::Sender<ChannelRequest>,
    /// DMA buffer holding both rings. Wrapped in `ManuallyDrop` so that this
    /// reference is only released explicitly by `shutdown`, after the teardown
    /// requests; if the worker is dropped any other way it is leaked.
    #[inspect(skip)]
    gpadl: ManuallyDrop<Arc<MemoryBlock>>,
    /// GPADL ID under which `gpadl` is registered with the host.
    #[inspect(debug)]
    ring_gpadl_id: GpadlId,
    /// Set once the GPADL request succeeds; tells `shutdown` to tear it down.
    is_gpadl_created: bool,
    /// Set once the open request succeeds; tells `shutdown` to send Close.
    is_open: bool,
}
85
86impl<D: SpawnDriver> ChannelWorker<D> {
87 async fn open(
88 &mut self,
89 input: OpenParams,
90 close_send: mesh::OneshotSender<()>,
91 guest_to_host: Interrupt,
92 host_to_guest: pal_event::Event,
93 revoked: Arc<AtomicBool>,
94 ) -> anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>> {
95 let gpadl_buf = [self.gpadl.len() as u64]
96 .into_iter()
97 .chain(self.gpadl.pfns().iter().copied())
98 .collect::<Vec<_>>();
99
100 self.request_send
101 .call_failable(
102 ChannelRequest::Gpadl,
103 GpadlRequest {
104 id: self.ring_gpadl_id,
105 count: 1,
106 buf: gpadl_buf,
107 },
108 )
109 .await?;
110
111 self.is_gpadl_created = true;
112
113 self.request_send
114 .call_failable(
115 ChannelRequest::Open,
116 OpenRequest {
117 open_data: OpenData {
118 target_vp: 0, ring_offset: input.ring_offset_in_pages.into(),
120 ring_gpadl_id: self.ring_gpadl_id,
121 event_flag: !0,
122 connection_id: !0,
123 user_data: UserDefinedData::default(),
124 },
125 incoming_event: Some(host_to_guest.clone()),
126 use_vtl2_connection_id: true,
127 },
128 )
129 .await?;
130
131 self.is_open = true;
132
133 let in_ring = MemoryBlockView {
134 mem: Arc::clone(&self.gpadl),
135 offset: input.ring_offset_in_pages as usize * vmbus_ring::PAGE_SIZE,
136 len: (input.ring_pages - input.ring_offset_in_pages) as usize * vmbus_ring::PAGE_SIZE,
137 };
138
139 let out_ring = MemoryBlockView {
140 mem: Arc::clone(&self.gpadl),
141 offset: 0,
142 len: input.ring_offset_in_pages as usize * vmbus_ring::PAGE_SIZE,
143 };
144
145 let signal = ClientSignaller {
146 guest_to_host,
147 host_to_guest: Notify::from_event(host_to_guest).pollable(&self.driver)?,
148 revoked: revoked.clone(),
149 _close: close_send,
150 };
151
152 Ok(RawAsyncChannel {
153 in_ring: IncomingRing::new(SingleMappedRingMem(in_ring))?,
154 out_ring: OutgoingRing::new(SingleMappedRingMem(out_ring))?,
155 signal: Box::new(signal),
156 })
157 }
158
159 async fn shutdown(self) {
160 if self.is_open {
161 self.request_send.call(ChannelRequest::Close, ()).await.ok();
162 }
163
164 if self.is_gpadl_created {
165 self.request_send
166 .call(ChannelRequest::TeardownGpadl, self.ring_gpadl_id)
167 .await
168 .ok();
169 }
170
171 ManuallyDrop::into_inner(self.gpadl);
173 }
174
175 async fn run(
176 driver: D,
177 offer_info: OfferInfo,
178 input: OpenParams,
179 gpadl_mem: MemoryBlock,
180 resp: mesh::OneshotSender<anyhow::Result<RawAsyncChannel<MemoryBlockRingMem>>>,
181 ) {
182 let instance_id = offer_info.offer.instance_id;
183 let ring_gpadl_id = GpadlId((1 << 31) | offer_info.offer.channel_id.0);
184
185 let mut worker = ChannelWorker {
186 driver,
187 offer: offer_info.offer,
188 request_send: offer_info.request_send,
189 gpadl: ManuallyDrop::new(Arc::new(gpadl_mem)),
190 ring_gpadl_id,
191 is_gpadl_created: false,
192 is_open: false,
193 };
194
195 let revoked = Arc::new(AtomicBool::new(false));
196 let (close_send, close_recv) = mesh::oneshot();
197 let host_to_guest = pal_event::Event::new();
198 match worker
199 .open(
200 input,
201 close_send,
202 offer_info.guest_to_host_interrupt,
203 host_to_guest.clone(),
204 revoked.clone(),
205 )
206 .await
207 {
208 Ok(channel) => {
209 resp.send(Ok(channel));
210 }
211 Err(e) => {
212 resp.send(Err(e));
213 worker.shutdown().await;
214 return;
215 }
216 }
217
218 enum Event<T, U> {
219 Close(T),
220 Revoke(U),
221 }
222
223 let revoke = offer_info.revoke_recv.map(Event::Revoke);
224 let close = close_recv.map(Event::Close);
225 let event = (revoke, close).race().await;
226 match event {
227 Event::Close(_) => {
228 tracing::debug!(%instance_id, "channel close requested");
229 }
230 Event::Revoke(_) => {
231 tracing::debug!(%instance_id, "channel revoked");
232 revoked.store(true, Relaxed);
233 host_to_guest.signal();
234 worker.is_open = false;
235 }
236 }
237
238 worker.shutdown().await;
239 }
240}
241
242pub struct MemoryBlockView {
244 mem: Arc<MemoryBlock>,
245 offset: usize,
246 len: usize,
247}
248
249impl AsRef<[AtomicU8]> for MemoryBlockView {
250 fn as_ref(&self) -> &[AtomicU8] {
251 &self.mem.as_slice()[self.offset..][..self.len]
252 }
253}
254
/// Signalling half of a channel opened by `open_channel`, boxed into the
/// returned [`RawAsyncChannel`].
// NOTE: fields are dropped in declaration order; keep `_close` last so it is
// released after the notification objects.
struct ClientSignaller {
    /// Delivered to notify the host.
    guest_to_host: Interrupt,
    /// Pollable wrapper around the event registered as the channel's
    /// `incoming_event`; the worker also signals it on revocation.
    host_to_guest: PolledNotify,
    /// Set by the worker when the offer is revoked; makes `poll_for_signal`
    /// return `ChannelClosed`.
    revoked: Arc<AtomicBool>,
    /// Never sent on. Its receiver is raced in the worker. Presumably dropping
    /// this (with the signaller) completes that receiver and triggers shutdown.
    _close: mesh::OneshotSender<()>,
}
262
263impl SignalVmbusChannel for ClientSignaller {
264 fn signal_remote(&self) {
265 self.guest_to_host.deliver();
266 }
267
268 fn poll_for_signal(&self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), ChannelClosed>> {
269 if self.revoked.load(Relaxed) {
270 return Poll::Ready(Err(ChannelClosed));
271 }
272 self.host_to_guest.poll_wait(cx).map(Ok)
273 }
274}