1#![expect(missing_docs)]
5#![cfg(windows)]
6#![expect(unsafe_code)]
8#![expect(clippy::undocumented_unsafe_blocks, clippy::missing_safety_doc)]
9
10use futures::poll;
11use guestmem::GuestMemory;
12use guid::Guid;
13use mesh::CancelContext;
14use mesh::MeshPayload;
15use pal::windows::ObjectAttributes;
16use pal::windows::UnicodeStringRef;
17use pal_async::driver::Driver;
18use pal_async::windows::overlapped::IoBuf;
19use pal_async::windows::overlapped::IoBufMut;
20use pal_async::windows::overlapped::OverlappedFile;
21use pal_event::Event;
22use std::mem::ManuallyDrop;
23use std::mem::zeroed;
24use std::num::NonZeroU32;
25use std::os::windows::prelude::*;
26use vmbus_core::HvsockConnectRequest;
27use vmbus_core::HvsockConnectResult;
28use vmbusioctl::VMBUS_CHANNEL_OFFER;
29use vmbusioctl::VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS;
30use widestring::Utf16Str;
31use widestring::utf16str;
32use windows::Wdk::Storage::FileSystem::NtOpenFile;
33use windows::Win32::Foundation::ERROR_OPERATION_ABORTED;
34use windows::Win32::Foundation::HANDLE;
35use windows::Win32::Foundation::NTSTATUS;
36use windows::Win32::Storage::FileSystem::FILE_ALL_ACCESS;
37use windows::Win32::Storage::FileSystem::SYNCHRONIZE;
38use windows::Win32::System::IO::DeviceIoControl;
39use zerocopy::IntoBytes;
40
41mod proxyioctl;
42pub mod vmbusioctl;
43
/// Error type returned by the proxy APIs in this crate (an alias for
/// [`windows::core::Error`]).
pub type Error = windows::core::Error;
/// Result type returned by the proxy APIs in this crate (an alias for
/// [`windows::core::Result`]).
pub type Result<T> = windows::core::Result<T>;
46
/// An owned handle to the VMBus proxy driver device (`\Device\VmbusProxy`).
///
/// Obtain one with [`ProxyHandle::new`] or from an existing [`OwnedHandle`],
/// then pass it to [`VmbusProxy::new`].
#[derive(Debug, MeshPayload)]
pub struct ProxyHandle(std::fs::File);
50
51impl ProxyHandle {
52 pub fn new() -> Result<Self> {
54 const DEVICE_PATH: &Utf16Str = utf16str!("\\Device\\VmbusProxy");
55 let pathu = UnicodeStringRef::try_from(DEVICE_PATH).expect("string fits");
56 let mut oa = ObjectAttributes::new();
57 oa.name(&pathu);
58 unsafe {
60 let mut iosb = zeroed();
61 let mut handle = HANDLE::default();
62 NtOpenFile(
63 &mut handle,
64 (FILE_ALL_ACCESS | SYNCHRONIZE).0,
65 oa.as_ref(),
66 &mut iosb,
67 0,
68 0,
69 )
70 .ok()?;
71 Ok(Self(std::fs::File::from_raw_handle(handle.0 as RawHandle)))
72 }
73 }
74}
75
76impl From<OwnedHandle> for ProxyHandle {
77 fn from(value: OwnedHandle) -> Self {
79 Self(value.into())
80 }
81}
82
/// Asynchronous client for the VMBus proxy driver.
pub struct VmbusProxy {
    /// The proxy driver handle, wrapped for overlapped I/O. Kept in
    /// `ManuallyDrop` so `Drop` can close it before signaling `drop_send`.
    file: ManuallyDrop<OverlappedFile>,
    /// Guest memory registered with the driver by `set_memory`, if any. While
    /// set, `Drop` asks the driver to detach it before closing the handle.
    guest_memory: Option<GuestMemory>,
    /// Cancellation context used for the cancellable ioctl issued by
    /// `next_action`.
    cancel: CancelContext,
    /// Signaled from `Drop` after the file handle has been closed.
    drop_send: Option<mesh::OneshotSender<()>>,
}
89
impl Drop for VmbusProxy {
    /// Tears the proxy down in a fixed order:
    /// 1. detach guest memory (if `set_memory` recorded any);
    /// 2. close the driver handle;
    /// 3. signal `drop_send`.
    fn drop(&mut self) {
        if self.guest_memory.is_some() {
            // Synchronous ioctl because `drop` cannot await. Failure is only
            // logged since `drop` has no way to report it.
            // NOTE(review): `DetachChannels: false` presumably detaches only
            // the guest memory and leaves channels attached — confirm against
            // the driver.
            if let Err(err) = self.ioctl_sync(
                proxyioctl::IOCTL_VMBUS_PROXY_DETACH,
                &proxyioctl::VMBUS_PROXY_DETACH_INPUT {
                    DetachChannels: false,
                },
            ) {
                tracing::warn!(
                    err = &err as &dyn std::error::Error,
                    "failed to clear proxy driver guest memory"
                );
            }
        }

        // SAFETY: `self.file` is never accessed again after this point. It is
        // taken here, rather than left for automatic field drop, so the handle
        // is closed before `drop_send` is signaled below.
        let file = unsafe { ManuallyDrop::take(&mut self.file) };

        // The value returned by `into_inner` (presumably the underlying
        // `File`) is discarded immediately, which closes the proxy handle.
        file.into_inner();
        // Notify the owner only after the handle is gone.
        if let Some(drop_send) = self.drop_send.take() {
            drop_send.send(());
        }
    }
}
121
/// An action reported by the proxy driver, as returned by
/// [`VmbusProxy::next_action`].
#[derive(Debug)]
pub enum ProxyAction {
    /// A channel offer.
    Offer {
        /// Proxy identifier for the channel (the driver's `ProxyId`), used as
        /// `id` in later requests such as `open` and `close`.
        id: u64,
        /// The channel offer as reported by the driver.
        offer: VMBUS_CHANNEL_OFFER,
        /// Event built from the driver-supplied `DeviceIncomingRingEvent` handle.
        incoming_event: Event,
        /// Event built from `DeviceOutgoingRingEvent`, or `None` when the
        /// driver reported 0 for it.
        outgoing_event: Option<Event>,
        /// The driver-reported `DeviceOrder`, or `None` when it is 0.
        device_order: Option<NonZeroU32>,
    },
    /// A channel was revoked.
    Revoke {
        /// Proxy identifier of the revoked channel.
        id: u64,
    },
    /// An interrupt-policy action; no payload is decoded from the driver output.
    InterruptPolicy {},
    /// The result of a transport-layer (hvsocket) connect request.
    TlConnectResult {
        /// Endpoint and service IDs, and whether the driver reported success.
        result: HvsockConnectResult,
        /// VTL reported by the driver for this result.
        vtl: u8,
    },
}
140
141struct StaticIoctlBuffer<T>(T);
142
143unsafe impl<T> IoBuf for StaticIoctlBuffer<T> {
146 fn as_ptr(&self) -> *const u8 {
147 std::ptr::from_ref::<Self>(self).cast()
148 }
149
150 fn len(&self) -> usize {
151 size_of_val(self)
152 }
153}
154
155unsafe impl<T> IoBufMut for StaticIoctlBuffer<T> {
158 fn as_mut_ptr(&mut self) -> *mut u8 {
159 std::ptr::from_mut::<Self>(self).cast()
160 }
161}
162
163impl VmbusProxy {
164 pub fn new(
168 driver: &dyn Driver,
169 handle: ProxyHandle,
170 ctx: CancelContext,
171 drop_send: mesh::OneshotSender<()>,
172 ) -> Result<Self> {
173 let file = unsafe { OverlappedFile::new(driver, handle.0)? };
176 Ok(Self {
177 file: ManuallyDrop::new(file),
178 guest_memory: None,
179 cancel: ctx,
180 drop_send: Some(drop_send),
181 })
182 }
183
184 pub fn handle(&self) -> BorrowedHandle<'_> {
185 self.file.get().as_handle()
186 }
187
188 async unsafe fn ioctl<In, Out>(&self, code: u32, input: In, output: Out) -> Result<Out>
189 where
190 In: IoBufMut,
191 Out: IoBufMut,
192 {
193 let (r, (_, output)) = unsafe { self.file.ioctl(code, input, output).await };
195 let size = r?;
196 assert_eq!(size, output.len(), "ioctl returned unexpected size");
197 Ok(output)
198 }
199
200 async unsafe fn ioctl_cancellable<In, Out>(
201 &self,
202 code: u32,
203 input: In,
204 output: Out,
205 ) -> Result<Out>
206 where
207 In: IoBufMut,
208 Out: IoBufMut,
209 {
210 let mut cancel = self.cancel.clone();
212 if cancel.is_cancelled() {
213 tracing::trace!("ioctl cancelled before issued");
214 return Err(ERROR_OPERATION_ABORTED.into());
215 }
216
217 let mut ioctl = unsafe { self.file.ioctl(code, input, output) };
219 let (r, (_, output)) = match poll!(&mut ioctl) {
220 std::task::Poll::Ready(result) => result,
221 std::task::Poll::Pending => {
222 match cancel.until_cancelled(&mut ioctl).await {
223 Ok(r) => r,
224 Err(_) => {
225 tracing::trace!("ioctl cancelled after issued");
226 ioctl.cancel();
227 ioctl.await
230 }
231 }
232 }
233 };
234
235 let size = r?;
236 assert_eq!(size, output.len(), "ioctl returned unexpected size");
237 Ok(output)
238 }
239
240 fn ioctl_sync<T>(&self, code: u32, input: &T) -> Result<()>
241 where
242 T: IntoBytes + zerocopy::Immutable,
243 {
244 unsafe {
246 let mut bytes = 0;
247 DeviceIoControl(
248 HANDLE(self.file.get().as_raw_handle()),
249 code,
250 Some(input.as_bytes().as_ptr().cast()),
251 size_of_val(input) as u32,
252 None,
253 0,
254 Some(&mut bytes),
255 None,
256 )
257 }
258 }
259
260 pub async fn set_memory(&mut self, guest_memory: &GuestMemory) -> Result<()> {
261 assert!(self.guest_memory.is_none());
262 let (base, len) = guest_memory.full_mapping().ok_or_else(|| {
263 std::io::Error::other("vmbusproxy not supported without mapped memory")
264 })?;
265 self.guest_memory = Some(guest_memory.clone());
266 unsafe {
267 self.ioctl(
268 proxyioctl::IOCTL_VMBUS_PROXY_SET_MEMORY,
269 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_MEMORY_INPUT {
270 BaseAddress: base as usize as u64,
271 Size: len as u64,
272 }),
273 (),
274 )
275 .await
276 }
277 }
278
279 pub async fn next_action(&self) -> Result<ProxyAction> {
280 let output = unsafe {
281 self.ioctl_cancellable(
282 proxyioctl::IOCTL_VMBUS_PROXY_NEXT_ACTION,
283 (),
284 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_NEXT_ACTION_OUTPUT>()),
285 )
286 .await?
287 .0
288 };
289 match output.Type {
290 proxyioctl::VmbusProxyActionTypeOffer => unsafe {
291 Ok(ProxyAction::Offer {
292 id: output.ProxyId,
293 offer: output.u.Offer.Offer,
294 incoming_event: OwnedHandle::from_raw_handle(
295 output.u.Offer.DeviceIncomingRingEvent as usize as RawHandle,
296 )
297 .into(),
298 outgoing_event: if output.u.Offer.DeviceOutgoingRingEvent != 0 {
299 Some(
300 OwnedHandle::from_raw_handle(
301 output.u.Offer.DeviceOutgoingRingEvent as usize as RawHandle,
302 )
303 .into(),
304 )
305 } else {
306 None
307 },
308 device_order: NonZeroU32::new(output.u.Offer.DeviceOrder),
309 })
310 },
311 proxyioctl::VmbusProxyActionTypeRevoke => {
312 Ok(ProxyAction::Revoke { id: output.ProxyId })
313 }
314 proxyioctl::VmbusProxyActionTypeInterruptPolicy => Ok(ProxyAction::InterruptPolicy {}),
315 proxyioctl::VmbusProxyActionTypeTlConnectResult => unsafe {
316 Ok(ProxyAction::TlConnectResult {
317 result: HvsockConnectResult {
318 endpoint_id: output.u.TlConnectResult.EndpointId.into(),
319 service_id: output.u.TlConnectResult.ServiceId.into(),
320 success: output.u.TlConnectResult.Status.is_ok(),
321 },
322 vtl: output.u.TlConnectResult.Vtl,
323 })
324 },
325 n => panic!("unexpected action: {}", n),
326 }
327 }
328
329 pub async fn open(
330 &self,
331 id: u64,
332 params: &VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS,
333 event: &Event,
334 ) -> Result<()> {
335 let output = unsafe {
336 let handle = event.as_handle().as_raw_handle() as usize as u64;
337 self.ioctl(
338 proxyioctl::IOCTL_VMBUS_PROXY_OPEN_CHANNEL,
339 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_INPUT {
340 ProxyId: id,
341 Padding: 0,
342 OpenParameters: *params,
343 VmmSignalEvent: handle,
344 }),
345 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_OUTPUT>()),
346 )
347 .await?
348 .0
349 };
350 NTSTATUS(output.Status).ok()
351 }
352 pub async fn set_interrupt(&self, id: u64, event: &Event) -> Result<()> {
353 unsafe {
354 let handle = event.as_handle().as_raw_handle() as usize as u64;
355 self.ioctl(
356 proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_SET_INTERRUPT,
357 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_INTERRUPT_INPUT {
358 ProxyId: id,
359 VmmSignalEvent: handle,
360 }),
361 (),
362 )
363 .await?
364 };
365 Ok(())
366 }
367
368 pub async fn restore(
369 &self,
370 interface_type: Guid,
371 interface_instance: Guid,
372 subchannel_index: u16,
373 target_vtl: u8,
374 open_params: Option<VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS>,
375 gpadls: impl Iterator<Item = Gpadl<'_>>,
376 ) -> Result<u64> {
377 let mut buffer = Vec::new();
378 let mut header = proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT {
379 InterfaceType: interface_type,
380 InterfaceInstance: interface_instance,
381 SubchannelIndex: subchannel_index,
382 TargetVtl: target_vtl,
383 GpadlCount: 0,
384 OpenParameters: open_params.unwrap_or_default(),
385 Open: open_params.is_some().into(),
386 };
387
388 const HEADER_LEN: usize = size_of::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT>();
390 buffer.resize(HEADER_LEN, 0);
391
392 for gpadl in gpadls {
394 header.GpadlCount += 1;
395 Self::add_gpadl(
396 &mut buffer,
397 0, gpadl.gpadl_id,
399 gpadl.range_count,
400 gpadl.range_buffer.as_bytes(),
401 );
402 }
403
404 buffer[..HEADER_LEN].copy_from_slice(header.as_bytes());
406 Ok(unsafe {
407 self.ioctl(
408 proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_CHANNEL,
409 buffer,
410 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_OUTPUT>()),
411 )
412 .await?
413 .0
414 .ProxyId
415 })
416 }
417
418 pub async fn revoke_unclaimed_channels(&self) -> Result<()> {
419 unsafe {
420 self.ioctl(
421 proxyioctl::IOCTL_VMBUS_PROXY_REVOKE_UNCLAIMED_CHANNELS,
422 (),
423 (),
424 )
425 .await
426 }
427 }
428
429 pub async fn close(&self, id: u64) -> Result<()> {
430 unsafe {
431 self.ioctl(
432 proxyioctl::IOCTL_VMBUS_PROXY_CLOSE_CHANNEL,
433 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_CLOSE_CHANNEL_INPUT { ProxyId: id }),
434 (),
435 )
436 .await
437 }
438 }
439
440 pub async fn release(&self, id: u64) -> Result<()> {
441 unsafe {
442 self.ioctl(
443 proxyioctl::IOCTL_VMBUS_PROXY_RELEASE_CHANNEL,
444 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_RELEASE_CHANNEL_INPUT { ProxyId: id }),
445 (),
446 )
447 .await
448 }
449 }
450
451 pub async fn create_gpadl(
452 &self,
453 id: u64,
454 gpadl_id: u32,
455 range_count: u32,
456 range_buf: &[u8],
457 ) -> Result<()> {
458 let mut buf = Vec::new();
459 Self::add_gpadl(&mut buf, id, gpadl_id, range_count, range_buf);
460 unsafe {
461 self.ioctl(proxyioctl::IOCTL_VMBUS_PROXY_CREATE_GPADL, buf, ())
462 .await
463 }
464 }
465
466 pub async fn delete_gpadl(&self, id: u64, gpadl_id: u32) -> Result<()> {
467 unsafe {
468 self.ioctl(
469 proxyioctl::IOCTL_VMBUS_PROXY_DELETE_GPADL,
470 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_DELETE_GPADL_INPUT {
471 ProxyId: id,
472 GpadlId: gpadl_id,
473 Padding: 0,
474 }),
475 (),
476 )
477 .await
478 }
479 }
480
481 pub async fn tl_connect_request(&self, request: &HvsockConnectRequest, vtl: u8) -> Result<()> {
482 unsafe {
483 self.ioctl(
484 proxyioctl::IOCTL_VMBUS_PROXY_TL_CONNECT_REQUEST,
485 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_INPUT {
486 EndpoindId: request.endpoint_id.into(),
487 ServiceId: request.service_id.into(),
488 SiloId: request.silo_id.into(),
489 Flags: proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_FLAGS::new()
490 .with_hosted_silo_unaware(request.hosted_silo_unaware),
491 Vtl: vtl,
492 Padding: [0; 3],
493 }),
494 (),
495 )
496 .await
497 }
498 }
499
500 pub fn run_channel(&self, id: u64) -> Result<()> {
501 let input = proxyioctl::VMBUS_PROXY_RUN_CHANNEL_INPUT { ProxyId: id };
502 self.ioctl_sync(proxyioctl::IOCTL_VMBUS_PROXY_RUN_CHANNEL, &input)
503 }
504
505 fn add_gpadl(
507 buffer: &mut Vec<u8>,
508 id: u64,
509 gpadl_id: u32,
510 range_count: u32,
511 range_buffer: &[u8],
512 ) {
513 let header = proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT {
514 ProxyId: id,
515 GpadlId: gpadl_id,
516 RangeCount: range_count,
517 RangeBufferOffset: size_of::<proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT>() as u32,
518 RangeBufferSize: range_buffer.len() as u32,
519 };
520 buffer.extend_from_slice(header.as_bytes());
521 buffer.extend_from_slice(range_buffer);
522 }
523}
524
/// A GPADL to re-create when restoring a channel with
/// [`VmbusProxy::restore`].
pub struct Gpadl<'a> {
    /// GPADL identifier.
    pub gpadl_id: u32,
    /// Number of ranges encoded in `range_buffer` (sent as `RangeCount`).
    pub range_count: u32,
    /// Encoded range data; passed to the driver as its raw bytes.
    pub range_buffer: &'a [u64],
}