1#![expect(missing_docs)]
5#![cfg(windows)]
6#![expect(unsafe_code)]
8#![expect(clippy::undocumented_unsafe_blocks, clippy::missing_safety_doc)]
9
10use futures::poll;
11use guestmem::GuestMemory;
12use guid::Guid;
13use mesh::CancelContext;
14use mesh::MeshPayload;
15use pal::windows::ObjectAttributes;
16use pal::windows::UnicodeStringRef;
17use pal_async::driver::Driver;
18use pal_async::windows::overlapped::IoBuf;
19use pal_async::windows::overlapped::IoBufMut;
20use pal_async::windows::overlapped::OverlappedFile;
21use pal_event::Event;
22use std::mem::zeroed;
23use std::os::windows::prelude::*;
24use vmbus_core::HvsockConnectRequest;
25use vmbus_core::HvsockConnectResult;
26use vmbusioctl::VMBUS_CHANNEL_OFFER;
27use vmbusioctl::VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS;
28use widestring::Utf16Str;
29use widestring::utf16str;
30use windows::Wdk::Storage::FileSystem::NtOpenFile;
31use windows::Win32::Foundation::ERROR_OPERATION_ABORTED;
32use windows::Win32::Foundation::HANDLE;
33use windows::Win32::Foundation::NTSTATUS;
34use windows::Win32::Storage::FileSystem::FILE_ALL_ACCESS;
35use windows::Win32::Storage::FileSystem::SYNCHRONIZE;
36use windows::Win32::System::IO::DeviceIoControl;
37use zerocopy::IntoBytes;
38
/// Raw ioctl codes and input/output structures for the vmbus proxy driver.
mod proxyioctl;
/// Vmbus channel structures (such as `VMBUS_CHANNEL_OFFER`) exchanged with the proxy driver.
pub mod vmbusioctl;

/// Error type returned by proxy operations.
pub type Error = windows::core::Error;
/// Result type returned by proxy operations.
pub type Result<T> = windows::core::Result<T>;
44
/// An owned handle to the `\Device\VmbusProxy` device.
///
/// Derives `MeshPayload`, so the handle can be sent in mesh messages before being wrapped in
/// a [`VmbusProxy`].
#[derive(Debug, MeshPayload)]
pub struct ProxyHandle(std::fs::File);
48
49impl ProxyHandle {
50 pub fn new() -> Result<Self> {
52 const DEVICE_PATH: &Utf16Str = utf16str!("\\Device\\VmbusProxy");
53 let pathu = UnicodeStringRef::try_from(DEVICE_PATH).expect("string fits");
54 let mut oa = ObjectAttributes::new();
55 oa.name(&pathu);
56 unsafe {
58 let mut iosb = zeroed();
59 let mut handle = HANDLE::default();
60 NtOpenFile(
61 &mut handle,
62 (FILE_ALL_ACCESS | SYNCHRONIZE).0,
63 oa.as_ref(),
64 &mut iosb,
65 0,
66 0,
67 )
68 .ok()?;
69 Ok(Self(std::fs::File::from_raw_handle(handle.0 as RawHandle)))
70 }
71 }
72}
73
74impl From<OwnedHandle> for ProxyHandle {
75 fn from(value: OwnedHandle) -> Self {
77 Self(value.into())
78 }
79}
80
/// An async client for the vmbus proxy driver, issuing overlapped ioctls on a [`ProxyHandle`].
pub struct VmbusProxy {
    /// The proxy device handle, wrapped for overlapped I/O.
    file: OverlappedFile,
    /// Guest memory registered through `set_memory`; `None` until then.
    /// NOTE(review): it is never read after being stored, so it is presumably held to keep the
    /// mapping handed to the driver alive for the proxy's lifetime — confirm.
    guest_memory: Option<GuestMemory>,
    /// Cancellation context checked by `next_action` to abort a pending wait.
    cancel: CancelContext,
}
88
/// An action reported by the proxy driver, as returned by [`VmbusProxy::next_action`].
#[derive(Debug)]
pub enum ProxyAction {
    /// A new channel offer.
    Offer {
        /// Driver-assigned proxy ID, passed back to the other `VmbusProxy` methods.
        id: u64,
        /// The channel offer as reported by the driver.
        offer: VMBUS_CHANNEL_OFFER,
        /// Event built from the driver's `DeviceIncomingRingEvent` handle.
        incoming_event: Event,
        /// Event built from `DeviceOutgoingRingEvent`, or `None` if the driver reported 0.
        outgoing_event: Option<Event>,
    },
    /// Revocation of a previously offered channel.
    Revoke {
        /// Proxy ID of the revoked channel.
        id: u64,
    },
    /// An interrupt-policy action; no payload is decoded for it here.
    InterruptPolicy {},
    /// Completion of a transport-layer (hvsocket) connect request.
    TlConnectResult {
        /// Endpoint, service, and success status of the connection attempt.
        result: HvsockConnectResult,
        /// VTL reported by the driver for this result.
        vtl: u8,
    },
}
106
/// A fixed-size value passed to, or filled in by, an ioctl as raw bytes.
///
/// `#[repr(transparent)]` guarantees the wrapper has exactly the layout of `T`. The `IoBuf` and
/// `IoBufMut` impls hand a pointer to `Self` to the driver, so that guarantee is what makes the
/// bytes the driver sees the `T` structure it expects.
#[repr(transparent)]
struct StaticIoctlBuffer<T>(T);
108
109unsafe impl<T> IoBuf for StaticIoctlBuffer<T> {
112 fn as_ptr(&self) -> *const u8 {
113 std::ptr::from_ref::<Self>(self).cast()
114 }
115
116 fn len(&self) -> usize {
117 size_of_val(self)
118 }
119}
120
121unsafe impl<T> IoBufMut for StaticIoctlBuffer<T> {
124 fn as_mut_ptr(&mut self) -> *mut u8 {
125 std::ptr::from_mut::<Self>(self).cast()
126 }
127}
128
129impl VmbusProxy {
130 pub fn new(driver: &dyn Driver, handle: ProxyHandle, ctx: CancelContext) -> Result<Self> {
131 Ok(Self {
132 file: OverlappedFile::new(driver, handle.0)?,
133 guest_memory: None,
134 cancel: ctx,
135 })
136 }
137
138 pub fn handle(&self) -> BorrowedHandle<'_> {
139 self.file.get().as_handle()
140 }
141
142 async unsafe fn ioctl<In, Out>(&self, code: u32, input: In, output: Out) -> Result<Out>
143 where
144 In: IoBufMut,
145 Out: IoBufMut,
146 {
147 let (r, (_, output)) = unsafe { self.file.ioctl(code, input, output).await };
149 r?;
150 Ok(output)
151 }
152
153 async unsafe fn ioctl_cancellable<In, Out>(
154 &self,
155 code: u32,
156 input: In,
157 output: Out,
158 ) -> Result<Out>
159 where
160 In: IoBufMut,
161 Out: IoBufMut,
162 {
163 let mut cancel = self.cancel.clone();
165 if cancel.is_cancelled() {
166 tracing::trace!("ioctl cancelled before issued");
167 return Err(ERROR_OPERATION_ABORTED.into());
168 }
169
170 let mut ioctl = unsafe { self.file.ioctl(code, input, output) };
172 let (r, (_, output)) = match poll!(&mut ioctl) {
173 std::task::Poll::Ready(result) => result,
174 std::task::Poll::Pending => {
175 match cancel.until_cancelled(&mut ioctl).await {
176 Ok(r) => r,
177 Err(_) => {
178 tracing::trace!("ioctl cancelled after issued");
179 ioctl.cancel();
180 ioctl.await
183 }
184 }
185 }
186 };
187
188 r?;
189 Ok(output)
190 }
191
192 pub async fn set_memory(&mut self, guest_memory: &GuestMemory) -> Result<()> {
193 assert!(self.guest_memory.is_none());
194 let (base, len) = guest_memory.full_mapping().ok_or_else(|| {
195 std::io::Error::other("vmbusproxy not supported without mapped memory")
196 })?;
197 self.guest_memory = Some(guest_memory.clone());
198 unsafe {
199 self.ioctl(
200 proxyioctl::IOCTL_VMBUS_PROXY_SET_MEMORY,
201 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_MEMORY_INPUT {
202 BaseAddress: base as usize as u64,
203 Size: len as u64,
204 }),
205 (),
206 )
207 .await
208 }
209 }
210
211 pub async fn next_action(&self) -> Result<ProxyAction> {
212 let output = unsafe {
213 self.ioctl_cancellable(
214 proxyioctl::IOCTL_VMBUS_PROXY_NEXT_ACTION,
215 (),
216 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_NEXT_ACTION_OUTPUT>()),
217 )
218 .await?
219 .0
220 };
221 match output.Type {
222 proxyioctl::VmbusProxyActionTypeOffer => unsafe {
223 Ok(ProxyAction::Offer {
224 id: output.ProxyId,
225 offer: output.u.Offer.Offer,
226 incoming_event: OwnedHandle::from_raw_handle(
227 output.u.Offer.DeviceIncomingRingEvent as usize as RawHandle,
228 )
229 .into(),
230 outgoing_event: if output.u.Offer.DeviceOutgoingRingEvent != 0 {
231 Some(
232 OwnedHandle::from_raw_handle(
233 output.u.Offer.DeviceOutgoingRingEvent as usize as RawHandle,
234 )
235 .into(),
236 )
237 } else {
238 None
239 },
240 })
241 },
242 proxyioctl::VmbusProxyActionTypeRevoke => {
243 Ok(ProxyAction::Revoke { id: output.ProxyId })
244 }
245 proxyioctl::VmbusProxyActionTypeInterruptPolicy => Ok(ProxyAction::InterruptPolicy {}),
246 proxyioctl::VmbusProxyActionTypeTlConnectResult => unsafe {
247 Ok(ProxyAction::TlConnectResult {
248 result: HvsockConnectResult {
249 endpoint_id: output.u.TlConnectResult.EndpointId.into(),
250 service_id: output.u.TlConnectResult.ServiceId.into(),
251 success: output.u.TlConnectResult.Status.is_ok(),
252 },
253 vtl: output.u.TlConnectResult.Vtl,
254 })
255 },
256 n => panic!("unexpected action: {}", n),
257 }
258 }
259
260 pub async fn open(
261 &self,
262 id: u64,
263 params: &VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS,
264 event: &Event,
265 ) -> Result<()> {
266 let output = unsafe {
267 let handle = event.as_handle().as_raw_handle() as usize as u64;
268 self.ioctl(
269 proxyioctl::IOCTL_VMBUS_PROXY_OPEN_CHANNEL,
270 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_INPUT {
271 ProxyId: id,
272 Padding: 0,
273 OpenParameters: *params,
274 VmmSignalEvent: handle,
275 }),
276 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_OUTPUT>()),
277 )
278 .await?
279 .0
280 };
281 NTSTATUS(output.Status).ok()
282 }
283 pub async fn set_interrupt(&self, id: u64, event: &Event) -> Result<()> {
284 unsafe {
285 let handle = event.as_handle().as_raw_handle() as usize as u64;
286 self.ioctl(
287 proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_SET_INTERRUPT,
288 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_INTERRUPT_INPUT {
289 ProxyId: id,
290 VmmSignalEvent: handle,
291 }),
292 (),
293 )
294 .await?
295 };
296 Ok(())
297 }
298
299 pub async fn restore(
300 &self,
301 interface_type: Guid,
302 interface_instance: Guid,
303 subchannel_index: u16,
304 target_vtl: u8,
305 open_params: Option<VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS>,
306 gpadls: impl Iterator<Item = Gpadl<'_>>,
307 ) -> Result<u64> {
308 let mut buffer = Vec::new();
309 let mut header = proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT {
310 InterfaceType: interface_type,
311 InterfaceInstance: interface_instance,
312 SubchannelIndex: subchannel_index,
313 TargetVtl: target_vtl,
314 GpadlCount: 0,
315 OpenParameters: open_params.unwrap_or_default(),
316 Open: open_params.is_some().into(),
317 };
318
319 const HEADER_LEN: usize = size_of::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT>();
321 buffer.resize(HEADER_LEN, 0);
322
323 for gpadl in gpadls {
325 header.GpadlCount += 1;
326 Self::add_gpadl(
327 &mut buffer,
328 0, gpadl.gpadl_id,
330 gpadl.range_count,
331 gpadl.range_buffer.as_bytes(),
332 );
333 }
334
335 buffer[..HEADER_LEN].copy_from_slice(header.as_bytes());
337 Ok(unsafe {
338 self.ioctl(
339 proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_CHANNEL,
340 buffer,
341 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_OUTPUT>()),
342 )
343 .await?
344 .0
345 .ProxyId
346 })
347 }
348
349 pub async fn revoke_unclaimed_channels(&self) -> Result<()> {
350 unsafe {
351 self.ioctl(
352 proxyioctl::IOCTL_VMBUS_PROXY_REVOKE_UNCLAIMED_CHANNELS,
353 (),
354 (),
355 )
356 .await
357 }
358 }
359
360 pub async fn close(&self, id: u64) -> Result<()> {
361 unsafe {
362 self.ioctl(
363 proxyioctl::IOCTL_VMBUS_PROXY_CLOSE_CHANNEL,
364 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_CLOSE_CHANNEL_INPUT { ProxyId: id }),
365 (),
366 )
367 .await
368 }
369 }
370
371 pub async fn release(&self, id: u64) -> Result<()> {
372 unsafe {
373 self.ioctl(
374 proxyioctl::IOCTL_VMBUS_PROXY_RELEASE_CHANNEL,
375 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_RELEASE_CHANNEL_INPUT { ProxyId: id }),
376 (),
377 )
378 .await
379 }
380 }
381
382 pub async fn create_gpadl(
383 &self,
384 id: u64,
385 gpadl_id: u32,
386 range_count: u32,
387 range_buf: &[u8],
388 ) -> Result<()> {
389 let mut buf = Vec::new();
390 Self::add_gpadl(&mut buf, id, gpadl_id, range_count, range_buf);
391 unsafe {
392 self.ioctl(proxyioctl::IOCTL_VMBUS_PROXY_CREATE_GPADL, buf, ())
393 .await
394 }
395 }
396
397 pub async fn delete_gpadl(&self, id: u64, gpadl_id: u32) -> Result<()> {
398 unsafe {
399 self.ioctl(
400 proxyioctl::IOCTL_VMBUS_PROXY_DELETE_GPADL,
401 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_DELETE_GPADL_INPUT {
402 ProxyId: id,
403 GpadlId: gpadl_id,
404 Padding: 0,
405 }),
406 (),
407 )
408 .await
409 }
410 }
411
412 pub async fn tl_connect_request(&self, request: &HvsockConnectRequest, vtl: u8) -> Result<()> {
413 unsafe {
414 self.ioctl(
415 proxyioctl::IOCTL_VMBUS_PROXY_TL_CONNECT_REQUEST,
416 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_INPUT {
417 EndpoindId: request.endpoint_id.into(),
418 ServiceId: request.service_id.into(),
419 SiloId: request.silo_id.into(),
420 Flags: proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_FLAGS::new()
421 .with_hosted_silo_unaware(request.hosted_silo_unaware),
422 Vtl: vtl,
423 Padding: [0; 3],
424 }),
425 (),
426 )
427 .await
428 }
429 }
430
431 pub fn run_channel(&self, id: u64) -> Result<()> {
432 unsafe {
433 let input = proxyioctl::VMBUS_PROXY_RUN_CHANNEL_INPUT { ProxyId: id };
435 let mut bytes = 0;
436 DeviceIoControl(
437 HANDLE(self.file.get().as_raw_handle()),
438 proxyioctl::IOCTL_VMBUS_PROXY_RUN_CHANNEL,
439 Some(std::ptr::from_ref(&input).cast()),
440 size_of_val(&input) as u32,
441 None,
442 0,
443 Some(&mut bytes),
444 None,
445 )?;
446 };
447 Ok(())
448 }
449
450 fn add_gpadl(
452 buffer: &mut Vec<u8>,
453 id: u64,
454 gpadl_id: u32,
455 range_count: u32,
456 range_buffer: &[u8],
457 ) {
458 let header = proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT {
459 ProxyId: id,
460 GpadlId: gpadl_id,
461 RangeCount: range_count,
462 RangeBufferOffset: size_of::<proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT>() as u32,
463 RangeBufferSize: range_buffer.len() as u32,
464 };
465 buffer.extend_from_slice(header.as_bytes());
466 buffer.extend_from_slice(range_buffer);
467 }
468}
469
/// A GPADL to re-create on a channel restored through [`VmbusProxy::restore`].
pub struct Gpadl<'a> {
    /// The GPADL ID.
    pub gpadl_id: u32,
    /// Number of ranges encoded in `range_buffer`; sent to the driver as `RangeCount`.
    pub range_count: u32,
    /// Encoded range data, sent to the driver as raw bytes right after the GPADL header.
    pub range_buffer: &'a [u64],
}