1#![expect(missing_docs)]
5#![cfg(windows)]
6#![expect(unsafe_code)]
8#![expect(clippy::undocumented_unsafe_blocks, clippy::missing_safety_doc)]
9
10use futures::poll;
11use guestmem::GuestMemory;
12use guid::Guid;
13use mesh::CancelContext;
14use mesh::MeshPayload;
15use pal::windows::ObjectAttributes;
16use pal::windows::UnicodeStringRef;
17use pal_async::driver::Driver;
18use pal_async::windows::overlapped::IoBuf;
19use pal_async::windows::overlapped::IoBufMut;
20use pal_async::windows::overlapped::OverlappedFile;
21use pal_event::Event;
22use std::mem::zeroed;
23use std::num::NonZeroU32;
24use std::os::windows::prelude::*;
25use vmbus_core::HvsockConnectRequest;
26use vmbus_core::HvsockConnectResult;
27use vmbusioctl::VMBUS_CHANNEL_OFFER;
28use vmbusioctl::VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS;
29use widestring::Utf16Str;
30use widestring::utf16str;
31use windows::Wdk::Storage::FileSystem::NtOpenFile;
32use windows::Win32::Foundation::ERROR_OPERATION_ABORTED;
33use windows::Win32::Foundation::HANDLE;
34use windows::Win32::Foundation::NTSTATUS;
35use windows::Win32::Storage::FileSystem::FILE_ALL_ACCESS;
36use windows::Win32::Storage::FileSystem::SYNCHRONIZE;
37use windows::Win32::System::IO::DeviceIoControl;
38use zerocopy::IntoBytes;
39
40mod proxyioctl;
41pub mod vmbusioctl;
42
43pub type Error = windows::core::Error;
44pub type Result<T> = windows::core::Result<T>;
45
46#[derive(Debug, MeshPayload)]
48pub struct ProxyHandle(std::fs::File);
49
50impl ProxyHandle {
51 pub fn new() -> Result<Self> {
53 const DEVICE_PATH: &Utf16Str = utf16str!("\\Device\\VmbusProxy");
54 let pathu = UnicodeStringRef::try_from(DEVICE_PATH).expect("string fits");
55 let mut oa = ObjectAttributes::new();
56 oa.name(&pathu);
57 unsafe {
59 let mut iosb = zeroed();
60 let mut handle = HANDLE::default();
61 NtOpenFile(
62 &mut handle,
63 (FILE_ALL_ACCESS | SYNCHRONIZE).0,
64 oa.as_ref(),
65 &mut iosb,
66 0,
67 0,
68 )
69 .ok()?;
70 Ok(Self(std::fs::File::from_raw_handle(handle.0 as RawHandle)))
71 }
72 }
73}
74
75impl From<OwnedHandle> for ProxyHandle {
76 fn from(value: OwnedHandle) -> Self {
78 Self(value.into())
79 }
80}
81
82pub struct VmbusProxy {
83 file: OverlappedFile,
84 guest_memory: Option<GuestMemory>,
87 cancel: CancelContext,
88}
89
90#[derive(Debug)]
91pub enum ProxyAction {
92 Offer {
93 id: u64,
94 offer: VMBUS_CHANNEL_OFFER,
95 incoming_event: Event,
96 outgoing_event: Option<Event>,
97 device_order: Option<NonZeroU32>,
98 },
99 Revoke {
100 id: u64,
101 },
102 InterruptPolicy {},
103 TlConnectResult {
104 result: HvsockConnectResult,
105 vtl: u8,
106 },
107}
108
109struct StaticIoctlBuffer<T>(T);
110
111unsafe impl<T> IoBuf for StaticIoctlBuffer<T> {
114 fn as_ptr(&self) -> *const u8 {
115 std::ptr::from_ref::<Self>(self).cast()
116 }
117
118 fn len(&self) -> usize {
119 size_of_val(self)
120 }
121}
122
123unsafe impl<T> IoBufMut for StaticIoctlBuffer<T> {
126 fn as_mut_ptr(&mut self) -> *mut u8 {
127 std::ptr::from_mut::<Self>(self).cast()
128 }
129}
130
131impl VmbusProxy {
132 pub fn new(driver: &dyn Driver, handle: ProxyHandle, ctx: CancelContext) -> Result<Self> {
133 Ok(Self {
134 file: OverlappedFile::new(driver, handle.0)?,
135 guest_memory: None,
136 cancel: ctx,
137 })
138 }
139
140 pub fn handle(&self) -> BorrowedHandle<'_> {
141 self.file.get().as_handle()
142 }
143
144 async unsafe fn ioctl<In, Out>(&self, code: u32, input: In, output: Out) -> Result<Out>
145 where
146 In: IoBufMut,
147 Out: IoBufMut,
148 {
149 let (r, (_, output)) = unsafe { self.file.ioctl(code, input, output).await };
151 let size = r?;
152 assert_eq!(size, output.len(), "ioctl returned unexpected size");
153 Ok(output)
154 }
155
156 async unsafe fn ioctl_cancellable<In, Out>(
157 &self,
158 code: u32,
159 input: In,
160 output: Out,
161 ) -> Result<Out>
162 where
163 In: IoBufMut,
164 Out: IoBufMut,
165 {
166 let mut cancel = self.cancel.clone();
168 if cancel.is_cancelled() {
169 tracing::trace!("ioctl cancelled before issued");
170 return Err(ERROR_OPERATION_ABORTED.into());
171 }
172
173 let mut ioctl = unsafe { self.file.ioctl(code, input, output) };
175 let (r, (_, output)) = match poll!(&mut ioctl) {
176 std::task::Poll::Ready(result) => result,
177 std::task::Poll::Pending => {
178 match cancel.until_cancelled(&mut ioctl).await {
179 Ok(r) => r,
180 Err(_) => {
181 tracing::trace!("ioctl cancelled after issued");
182 ioctl.cancel();
183 ioctl.await
186 }
187 }
188 }
189 };
190
191 let size = r?;
192 assert_eq!(size, output.len(), "ioctl returned unexpected size");
193 Ok(output)
194 }
195
196 pub async fn set_memory(&mut self, guest_memory: &GuestMemory) -> Result<()> {
197 assert!(self.guest_memory.is_none());
198 let (base, len) = guest_memory.full_mapping().ok_or_else(|| {
199 std::io::Error::other("vmbusproxy not supported without mapped memory")
200 })?;
201 self.guest_memory = Some(guest_memory.clone());
202 unsafe {
203 self.ioctl(
204 proxyioctl::IOCTL_VMBUS_PROXY_SET_MEMORY,
205 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_MEMORY_INPUT {
206 BaseAddress: base as usize as u64,
207 Size: len as u64,
208 }),
209 (),
210 )
211 .await
212 }
213 }
214
215 pub async fn next_action(&self) -> Result<ProxyAction> {
216 let output = unsafe {
217 self.ioctl_cancellable(
218 proxyioctl::IOCTL_VMBUS_PROXY_NEXT_ACTION,
219 (),
220 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_NEXT_ACTION_OUTPUT>()),
221 )
222 .await?
223 .0
224 };
225 match output.Type {
226 proxyioctl::VmbusProxyActionTypeOffer => unsafe {
227 Ok(ProxyAction::Offer {
228 id: output.ProxyId,
229 offer: output.u.Offer.Offer,
230 incoming_event: OwnedHandle::from_raw_handle(
231 output.u.Offer.DeviceIncomingRingEvent as usize as RawHandle,
232 )
233 .into(),
234 outgoing_event: if output.u.Offer.DeviceOutgoingRingEvent != 0 {
235 Some(
236 OwnedHandle::from_raw_handle(
237 output.u.Offer.DeviceOutgoingRingEvent as usize as RawHandle,
238 )
239 .into(),
240 )
241 } else {
242 None
243 },
244 device_order: NonZeroU32::new(output.u.Offer.DeviceOrder),
245 })
246 },
247 proxyioctl::VmbusProxyActionTypeRevoke => {
248 Ok(ProxyAction::Revoke { id: output.ProxyId })
249 }
250 proxyioctl::VmbusProxyActionTypeInterruptPolicy => Ok(ProxyAction::InterruptPolicy {}),
251 proxyioctl::VmbusProxyActionTypeTlConnectResult => unsafe {
252 Ok(ProxyAction::TlConnectResult {
253 result: HvsockConnectResult {
254 endpoint_id: output.u.TlConnectResult.EndpointId.into(),
255 service_id: output.u.TlConnectResult.ServiceId.into(),
256 success: output.u.TlConnectResult.Status.is_ok(),
257 },
258 vtl: output.u.TlConnectResult.Vtl,
259 })
260 },
261 n => panic!("unexpected action: {}", n),
262 }
263 }
264
265 pub async fn open(
266 &self,
267 id: u64,
268 params: &VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS,
269 event: &Event,
270 ) -> Result<()> {
271 let output = unsafe {
272 let handle = event.as_handle().as_raw_handle() as usize as u64;
273 self.ioctl(
274 proxyioctl::IOCTL_VMBUS_PROXY_OPEN_CHANNEL,
275 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_INPUT {
276 ProxyId: id,
277 Padding: 0,
278 OpenParameters: *params,
279 VmmSignalEvent: handle,
280 }),
281 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_OUTPUT>()),
282 )
283 .await?
284 .0
285 };
286 NTSTATUS(output.Status).ok()
287 }
288 pub async fn set_interrupt(&self, id: u64, event: &Event) -> Result<()> {
289 unsafe {
290 let handle = event.as_handle().as_raw_handle() as usize as u64;
291 self.ioctl(
292 proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_SET_INTERRUPT,
293 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_INTERRUPT_INPUT {
294 ProxyId: id,
295 VmmSignalEvent: handle,
296 }),
297 (),
298 )
299 .await?
300 };
301 Ok(())
302 }
303
304 pub async fn restore(
305 &self,
306 interface_type: Guid,
307 interface_instance: Guid,
308 subchannel_index: u16,
309 target_vtl: u8,
310 open_params: Option<VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS>,
311 gpadls: impl Iterator<Item = Gpadl<'_>>,
312 ) -> Result<u64> {
313 let mut buffer = Vec::new();
314 let mut header = proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT {
315 InterfaceType: interface_type,
316 InterfaceInstance: interface_instance,
317 SubchannelIndex: subchannel_index,
318 TargetVtl: target_vtl,
319 GpadlCount: 0,
320 OpenParameters: open_params.unwrap_or_default(),
321 Open: open_params.is_some().into(),
322 };
323
324 const HEADER_LEN: usize = size_of::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT>();
326 buffer.resize(HEADER_LEN, 0);
327
328 for gpadl in gpadls {
330 header.GpadlCount += 1;
331 Self::add_gpadl(
332 &mut buffer,
333 0, gpadl.gpadl_id,
335 gpadl.range_count,
336 gpadl.range_buffer.as_bytes(),
337 );
338 }
339
340 buffer[..HEADER_LEN].copy_from_slice(header.as_bytes());
342 Ok(unsafe {
343 self.ioctl(
344 proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_CHANNEL,
345 buffer,
346 StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_OUTPUT>()),
347 )
348 .await?
349 .0
350 .ProxyId
351 })
352 }
353
354 pub async fn revoke_unclaimed_channels(&self) -> Result<()> {
355 unsafe {
356 self.ioctl(
357 proxyioctl::IOCTL_VMBUS_PROXY_REVOKE_UNCLAIMED_CHANNELS,
358 (),
359 (),
360 )
361 .await
362 }
363 }
364
365 pub async fn close(&self, id: u64) -> Result<()> {
366 unsafe {
367 self.ioctl(
368 proxyioctl::IOCTL_VMBUS_PROXY_CLOSE_CHANNEL,
369 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_CLOSE_CHANNEL_INPUT { ProxyId: id }),
370 (),
371 )
372 .await
373 }
374 }
375
376 pub async fn release(&self, id: u64) -> Result<()> {
377 unsafe {
378 self.ioctl(
379 proxyioctl::IOCTL_VMBUS_PROXY_RELEASE_CHANNEL,
380 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_RELEASE_CHANNEL_INPUT { ProxyId: id }),
381 (),
382 )
383 .await
384 }
385 }
386
387 pub async fn create_gpadl(
388 &self,
389 id: u64,
390 gpadl_id: u32,
391 range_count: u32,
392 range_buf: &[u8],
393 ) -> Result<()> {
394 let mut buf = Vec::new();
395 Self::add_gpadl(&mut buf, id, gpadl_id, range_count, range_buf);
396 unsafe {
397 self.ioctl(proxyioctl::IOCTL_VMBUS_PROXY_CREATE_GPADL, buf, ())
398 .await
399 }
400 }
401
402 pub async fn delete_gpadl(&self, id: u64, gpadl_id: u32) -> Result<()> {
403 unsafe {
404 self.ioctl(
405 proxyioctl::IOCTL_VMBUS_PROXY_DELETE_GPADL,
406 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_DELETE_GPADL_INPUT {
407 ProxyId: id,
408 GpadlId: gpadl_id,
409 Padding: 0,
410 }),
411 (),
412 )
413 .await
414 }
415 }
416
417 pub async fn tl_connect_request(&self, request: &HvsockConnectRequest, vtl: u8) -> Result<()> {
418 unsafe {
419 self.ioctl(
420 proxyioctl::IOCTL_VMBUS_PROXY_TL_CONNECT_REQUEST,
421 StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_INPUT {
422 EndpoindId: request.endpoint_id.into(),
423 ServiceId: request.service_id.into(),
424 SiloId: request.silo_id.into(),
425 Flags: proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_FLAGS::new()
426 .with_hosted_silo_unaware(request.hosted_silo_unaware),
427 Vtl: vtl,
428 Padding: [0; 3],
429 }),
430 (),
431 )
432 .await
433 }
434 }
435
436 pub fn run_channel(&self, id: u64) -> Result<()> {
437 unsafe {
438 let input = proxyioctl::VMBUS_PROXY_RUN_CHANNEL_INPUT { ProxyId: id };
440 let mut bytes = 0;
441 DeviceIoControl(
442 HANDLE(self.file.get().as_raw_handle()),
443 proxyioctl::IOCTL_VMBUS_PROXY_RUN_CHANNEL,
444 Some(std::ptr::from_ref(&input).cast()),
445 size_of_val(&input) as u32,
446 None,
447 0,
448 Some(&mut bytes),
449 None,
450 )?;
451 };
452 Ok(())
453 }
454
455 fn add_gpadl(
457 buffer: &mut Vec<u8>,
458 id: u64,
459 gpadl_id: u32,
460 range_count: u32,
461 range_buffer: &[u8],
462 ) {
463 let header = proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT {
464 ProxyId: id,
465 GpadlId: gpadl_id,
466 RangeCount: range_count,
467 RangeBufferOffset: size_of::<proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT>() as u32,
468 RangeBufferSize: range_buffer.len() as u32,
469 };
470 buffer.extend_from_slice(header.as_bytes());
471 buffer.extend_from_slice(range_buffer);
472 }
473}
474
/// Description of one GPADL passed to `VmbusProxy::restore`.
#[derive(Debug, Clone, Copy)]
pub struct Gpadl<'a> {
    /// Identifier of the GPADL.
    pub gpadl_id: u32,
    /// Number of ranges; sent to the driver as the record's range count.
    pub range_count: u32,
    /// Range data; sent to the driver as raw bytes after the record header.
    pub range_buffer: &'a [u64],
}