1#![expect(missing_docs)]
5#![cfg(windows)]
6#![expect(unsafe_code)]
8#![expect(clippy::undocumented_unsafe_blocks, clippy::missing_safety_doc)]
9
10use futures::poll;
11use guestmem::GuestMemory;
12use guid::Guid;
13use mesh::CancelContext;
14use mesh::MeshPayload;
15use pal::windows::ObjectAttributes;
16use pal::windows::UnicodeStringRef;
17use pal_async::driver::Driver;
18use pal_async::windows::overlapped::IoBuf;
19use pal_async::windows::overlapped::IoBufMut;
20use pal_async::windows::overlapped::OverlappedFile;
21use pal_event::Event;
22use std::mem::zeroed;
23use std::num::NonZeroU32;
24use std::os::windows::prelude::*;
25use vmbus_core::HvsockConnectRequest;
26use vmbus_core::HvsockConnectResult;
27use vmbusioctl::VMBUS_CHANNEL_OFFER;
28use vmbusioctl::VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS;
29use widestring::Utf16Str;
30use widestring::utf16str;
31use windows::Wdk::Storage::FileSystem::NtOpenFile;
32use windows::Win32::Foundation::ERROR_OPERATION_ABORTED;
33use windows::Win32::Foundation::HANDLE;
34use windows::Win32::Foundation::NTSTATUS;
35use windows::Win32::Storage::FileSystem::FILE_ALL_ACCESS;
36use windows::Win32::Storage::FileSystem::SYNCHRONIZE;
37use windows::Win32::System::IO::DeviceIoControl;
38use zerocopy::IntoBytes;
39
40mod proxyioctl;
41pub mod vmbusioctl;
42
/// Error type returned by this crate's operations (an alias for [`windows::core::Error`]).
pub type Error = windows::core::Error;
/// Result alias whose error type is [`windows::core::Error`].
pub type Result<T> = windows::core::Result<T>;
45
/// An owned handle to the vmbus proxy device (`\Device\VmbusProxy`).
///
/// Obtained from [`ProxyHandle::new`] or from an existing [`OwnedHandle`]. It is turned
/// into an async client with `VmbusProxy::new`. It derives `MeshPayload`, presumably so
/// it can be sent over mesh channels (e.g. to another process) — confirm.
#[derive(Debug, MeshPayload)]
pub struct ProxyHandle(std::fs::File);
50impl ProxyHandle {
51 pub fn new() -> Result<Self> {
53 const DEVICE_PATH: &Utf16Str = utf16str!("\\Device\\VmbusProxy");
54 let pathu = UnicodeStringRef::try_from(DEVICE_PATH).expect("string fits");
55 let mut oa = ObjectAttributes::new();
56 oa.name(&pathu);
57 unsafe {
59 let mut iosb = zeroed();
60 let mut handle = HANDLE::default();
61 NtOpenFile(
62 &mut handle,
63 (FILE_ALL_ACCESS | SYNCHRONIZE).0,
64 oa.as_ref(),
65 &mut iosb,
66 0,
67 0,
68 )
69 .ok()?;
70 Ok(Self(std::fs::File::from_raw_handle(handle.0 as RawHandle)))
71 }
72 }
73}
74
75impl From<OwnedHandle> for ProxyHandle {
76 fn from(value: OwnedHandle) -> Self {
78 Self(value.into())
79 }
80}
81
/// Async client for the vmbus proxy driver, issuing ioctls over an overlapped file handle.
pub struct VmbusProxy {
    /// The proxy device handle, registered with an async driver for overlapped ioctls.
    file: OverlappedFile,
    /// Guest memory registered by `set_memory`; `None` until then.
    /// NOTE(review): presumably retained so the mapping handed to the driver stays alive
    /// for the proxy's lifetime — confirm.
    guest_memory: Option<GuestMemory>,
    /// Cancellation context checked by `next_action` before and while it waits.
    cancel: CancelContext,
}
89
/// An action reported by the proxy driver, as decoded by `VmbusProxy::next_action`.
#[derive(Debug)]
pub enum ProxyAction {
    /// A channel offer (`VmbusProxyActionTypeOffer`).
    Offer {
        /// The driver-reported `ProxyId`. Other `VmbusProxy` methods take it as their `id`.
        id: u64,
        /// The channel offer as reported by the driver.
        offer: VMBUS_CHANNEL_OFFER,
        /// Event built from the driver-supplied `DeviceIncomingRingEvent` handle.
        incoming_event: Event,
        /// Event built from `DeviceOutgoingRingEvent`, or `None` when the driver reports 0.
        outgoing_event: Option<Event>,
        /// The driver-reported `DeviceOrder`, or `None` when it is 0.
        device_order: Option<NonZeroU32>,
    },
    /// A `VmbusProxyActionTypeRevoke` action for the channel with proxy id `id`.
    Revoke {
        id: u64,
    },
    /// A `VmbusProxyActionTypeInterruptPolicy` action; no payload is decoded for it.
    InterruptPolicy {},
    /// A `VmbusProxyActionTypeTlConnectResult` action.
    /// `result.success` is true when the driver-reported status is a success `NTSTATUS`.
    TlConnectResult {
        result: HvsockConnectResult,
        vtl: u8,
    },
}
108
/// Wraps a plain value so its bytes can be passed directly as an ioctl input or output buffer.
///
/// `repr(transparent)` guarantees this wrapper has exactly the size, alignment, and layout
/// of `T`. The pointer/length reported by the `IoBuf`/`IoBufMut` impls therefore cover
/// exactly the bytes of `T`. Without it, the default Rust representation makes no layout
/// guarantee for the wrapper.
#[repr(transparent)]
struct StaticIoctlBuffer<T>(T);
110
111unsafe impl<T> IoBuf for StaticIoctlBuffer<T> {
114 fn as_ptr(&self) -> *const u8 {
115 std::ptr::from_ref::<Self>(self).cast()
116 }
117
118 fn len(&self) -> usize {
119 size_of_val(self)
120 }
121}
122
123unsafe impl<T> IoBufMut for StaticIoctlBuffer<T> {
126 fn as_mut_ptr(&mut self) -> *mut u8 {
127 std::ptr::from_mut::<Self>(self).cast()
128 }
129}
130
131impl VmbusProxy {
132 pub fn new(driver: &dyn Driver, handle: ProxyHandle, ctx: CancelContext) -> Result<Self> {
133 let file = unsafe { OverlappedFile::new(driver, handle.0)? };
136 Ok(Self {
137 file,
138 guest_memory: None,
139 cancel: ctx,
140 })
141 }
142
143 pub fn handle(&self) -> BorrowedHandle<'_> {
144 self.file.get().as_handle()
145 }
146
147 async unsafe fn ioctl<In, Out>(&self, code: u32, input: In, output: Out) -> Result<Out>
148 where
149 In: IoBufMut,
150 Out: IoBufMut,
151 {
152 let (r, (_, output)) = unsafe { self.file.ioctl(code, input, output).await };
154 let size = r?;
155 assert_eq!(size, output.len(), "ioctl returned unexpected size");
156 Ok(output)
157 }
158
159 async unsafe fn ioctl_cancellable<In, Out>(
160 &self,
161 code: u32,
162 input: In,
163 output: Out,
164 ) -> Result<Out>
165 where
166 In: IoBufMut,
167 Out: IoBufMut,
168 {
169 let mut cancel = self.cancel.clone();
171 if cancel.is_cancelled() {
172 tracing::trace!("ioctl cancelled before issued");
173 return Err(ERROR_OPERATION_ABORTED.into());
174 }
175
176 let mut ioctl = unsafe { self.file.ioctl(code, input, output) };
178 let (r, (_, output)) = match poll!(&mut ioctl) {
179 std::task::Poll::Ready(result) => result,
180 std::task::Poll::Pending => {
181 match cancel.until_cancelled(&mut ioctl).await {
182 Ok(r) => r,
183 Err(_) => {
184 tracing::trace!("ioctl cancelled after issued");
185 ioctl.cancel();
186 ioctl.await
189 }
190 }
191 }
192 };
193
194 let size = r?;
195 assert_eq!(size, output.len(), "ioctl returned unexpected size");
196 Ok(output)
197 }
198
199 pub async fn set_memory(&mut self, guest_memory: &GuestMemory) -> Result<()> {
200 assert!(self.guest_memory.is_none());
201 let (base, len) = guest_memory.full_mapping().ok_or_else(|| {
202 std::io::Error::other("vmbusproxy not supported without mapped memory")
203 })?;
204 self.guest_memory = Some(guest_memory.clone());
205 unsafe {
206 self.ioctl(
207 proxyioctl::IOCTL_VMBUS_PROXY_SET_MEMORY,
208 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_MEMORY_INPUT {
209 BaseAddress: base as usize as u64,
210 Size: len as u64,
211 }),
212 (),
213 )
214 .await
215 }
216 }
217
218 pub async fn next_action(&self) -> Result<ProxyAction> {
219 let output = unsafe {
220 self.ioctl_cancellable(
221 proxyioctl::IOCTL_VMBUS_PROXY_NEXT_ACTION,
222 (),
223 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_NEXT_ACTION_OUTPUT>()),
224 )
225 .await?
226 .0
227 };
228 match output.Type {
229 proxyioctl::VmbusProxyActionTypeOffer => unsafe {
230 Ok(ProxyAction::Offer {
231 id: output.ProxyId,
232 offer: output.u.Offer.Offer,
233 incoming_event: OwnedHandle::from_raw_handle(
234 output.u.Offer.DeviceIncomingRingEvent as usize as RawHandle,
235 )
236 .into(),
237 outgoing_event: if output.u.Offer.DeviceOutgoingRingEvent != 0 {
238 Some(
239 OwnedHandle::from_raw_handle(
240 output.u.Offer.DeviceOutgoingRingEvent as usize as RawHandle,
241 )
242 .into(),
243 )
244 } else {
245 None
246 },
247 device_order: NonZeroU32::new(output.u.Offer.DeviceOrder),
248 })
249 },
250 proxyioctl::VmbusProxyActionTypeRevoke => {
251 Ok(ProxyAction::Revoke { id: output.ProxyId })
252 }
253 proxyioctl::VmbusProxyActionTypeInterruptPolicy => Ok(ProxyAction::InterruptPolicy {}),
254 proxyioctl::VmbusProxyActionTypeTlConnectResult => unsafe {
255 Ok(ProxyAction::TlConnectResult {
256 result: HvsockConnectResult {
257 endpoint_id: output.u.TlConnectResult.EndpointId.into(),
258 service_id: output.u.TlConnectResult.ServiceId.into(),
259 success: output.u.TlConnectResult.Status.is_ok(),
260 },
261 vtl: output.u.TlConnectResult.Vtl,
262 })
263 },
264 n => panic!("unexpected action: {}", n),
265 }
266 }
267
268 pub async fn open(
269 &self,
270 id: u64,
271 params: &VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS,
272 event: &Event,
273 ) -> Result<()> {
274 let output = unsafe {
275 let handle = event.as_handle().as_raw_handle() as usize as u64;
276 self.ioctl(
277 proxyioctl::IOCTL_VMBUS_PROXY_OPEN_CHANNEL,
278 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_INPUT {
279 ProxyId: id,
280 Padding: 0,
281 OpenParameters: *params,
282 VmmSignalEvent: handle,
283 }),
284 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_OUTPUT>()),
285 )
286 .await?
287 .0
288 };
289 NTSTATUS(output.Status).ok()
290 }
291 pub async fn set_interrupt(&self, id: u64, event: &Event) -> Result<()> {
292 unsafe {
293 let handle = event.as_handle().as_raw_handle() as usize as u64;
294 self.ioctl(
295 proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_SET_INTERRUPT,
296 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_INTERRUPT_INPUT {
297 ProxyId: id,
298 VmmSignalEvent: handle,
299 }),
300 (),
301 )
302 .await?
303 };
304 Ok(())
305 }
306
307 pub async fn restore(
308 &self,
309 interface_type: Guid,
310 interface_instance: Guid,
311 subchannel_index: u16,
312 target_vtl: u8,
313 open_params: Option<VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS>,
314 gpadls: impl Iterator<Item = Gpadl<'_>>,
315 ) -> Result<u64> {
316 let mut buffer = Vec::new();
317 let mut header = proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT {
318 InterfaceType: interface_type,
319 InterfaceInstance: interface_instance,
320 SubchannelIndex: subchannel_index,
321 TargetVtl: target_vtl,
322 GpadlCount: 0,
323 OpenParameters: open_params.unwrap_or_default(),
324 Open: open_params.is_some().into(),
325 };
326
327 const HEADER_LEN: usize = size_of::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT>();
329 buffer.resize(HEADER_LEN, 0);
330
331 for gpadl in gpadls {
333 header.GpadlCount += 1;
334 Self::add_gpadl(
335 &mut buffer,
336 0, gpadl.gpadl_id,
338 gpadl.range_count,
339 gpadl.range_buffer.as_bytes(),
340 );
341 }
342
343 buffer[..HEADER_LEN].copy_from_slice(header.as_bytes());
345 Ok(unsafe {
346 self.ioctl(
347 proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_CHANNEL,
348 buffer,
349 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_OUTPUT>()),
350 )
351 .await?
352 .0
353 .ProxyId
354 })
355 }
356
357 pub async fn revoke_unclaimed_channels(&self) -> Result<()> {
358 unsafe {
359 self.ioctl(
360 proxyioctl::IOCTL_VMBUS_PROXY_REVOKE_UNCLAIMED_CHANNELS,
361 (),
362 (),
363 )
364 .await
365 }
366 }
367
368 pub async fn close(&self, id: u64) -> Result<()> {
369 unsafe {
370 self.ioctl(
371 proxyioctl::IOCTL_VMBUS_PROXY_CLOSE_CHANNEL,
372 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_CLOSE_CHANNEL_INPUT { ProxyId: id }),
373 (),
374 )
375 .await
376 }
377 }
378
379 pub async fn release(&self, id: u64) -> Result<()> {
380 unsafe {
381 self.ioctl(
382 proxyioctl::IOCTL_VMBUS_PROXY_RELEASE_CHANNEL,
383 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_RELEASE_CHANNEL_INPUT { ProxyId: id }),
384 (),
385 )
386 .await
387 }
388 }
389
390 pub async fn create_gpadl(
391 &self,
392 id: u64,
393 gpadl_id: u32,
394 range_count: u32,
395 range_buf: &[u8],
396 ) -> Result<()> {
397 let mut buf = Vec::new();
398 Self::add_gpadl(&mut buf, id, gpadl_id, range_count, range_buf);
399 unsafe {
400 self.ioctl(proxyioctl::IOCTL_VMBUS_PROXY_CREATE_GPADL, buf, ())
401 .await
402 }
403 }
404
405 pub async fn delete_gpadl(&self, id: u64, gpadl_id: u32) -> Result<()> {
406 unsafe {
407 self.ioctl(
408 proxyioctl::IOCTL_VMBUS_PROXY_DELETE_GPADL,
409 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_DELETE_GPADL_INPUT {
410 ProxyId: id,
411 GpadlId: gpadl_id,
412 Padding: 0,
413 }),
414 (),
415 )
416 .await
417 }
418 }
419
420 pub async fn tl_connect_request(&self, request: &HvsockConnectRequest, vtl: u8) -> Result<()> {
421 unsafe {
422 self.ioctl(
423 proxyioctl::IOCTL_VMBUS_PROXY_TL_CONNECT_REQUEST,
424 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_INPUT {
425 EndpoindId: request.endpoint_id.into(),
426 ServiceId: request.service_id.into(),
427 SiloId: request.silo_id.into(),
428 Flags: proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_FLAGS::new()
429 .with_hosted_silo_unaware(request.hosted_silo_unaware),
430 Vtl: vtl,
431 Padding: [0; 3],
432 }),
433 (),
434 )
435 .await
436 }
437 }
438
439 pub fn run_channel(&self, id: u64) -> Result<()> {
440 unsafe {
441 let input = proxyioctl::VMBUS_PROXY_RUN_CHANNEL_INPUT { ProxyId: id };
443 let mut bytes = 0;
444 DeviceIoControl(
445 HANDLE(self.file.get().as_raw_handle()),
446 proxyioctl::IOCTL_VMBUS_PROXY_RUN_CHANNEL,
447 Some(std::ptr::from_ref(&input).cast()),
448 size_of_val(&input) as u32,
449 None,
450 0,
451 Some(&mut bytes),
452 None,
453 )?;
454 };
455 Ok(())
456 }
457
458 fn add_gpadl(
460 buffer: &mut Vec<u8>,
461 id: u64,
462 gpadl_id: u32,
463 range_count: u32,
464 range_buffer: &[u8],
465 ) {
466 let header = proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT {
467 ProxyId: id,
468 GpadlId: gpadl_id,
469 RangeCount: range_count,
470 RangeBufferOffset: size_of::<proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT>() as u32,
471 RangeBufferSize: range_buffer.len() as u32,
472 };
473 buffer.extend_from_slice(header.as_bytes());
474 buffer.extend_from_slice(range_buffer);
475 }
476}
477
/// A GPADL to re-create on the driver side as part of `VmbusProxy::restore`.
pub struct Gpadl<'a> {
    /// The GPADL id, sent as `GpadlId`.
    pub gpadl_id: u32,
    /// Number of ranges described by `range_buffer`, sent as `RangeCount`.
    pub range_count: u32,
    /// The range data, sent to the driver as its raw bytes.
    /// NOTE(review): presumably in the vmbus GPA-range wire format — confirm.
    pub range_buffer: &'a [u64],
}