vmbus_proxy/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![cfg(windows)]
6// UNSAFETY: Calling vmbus proxy ioctls.
7#![expect(unsafe_code)]
8#![expect(clippy::undocumented_unsafe_blocks, clippy::missing_safety_doc)]
9
10use futures::poll;
11use guestmem::GuestMemory;
12use guid::Guid;
13use mesh::CancelContext;
14use mesh::MeshPayload;
15use pal::windows::ObjectAttributes;
16use pal::windows::UnicodeStringRef;
17use pal_async::driver::Driver;
18use pal_async::windows::overlapped::IoBuf;
19use pal_async::windows::overlapped::IoBufMut;
20use pal_async::windows::overlapped::OverlappedFile;
21use pal_event::Event;
22use std::mem::ManuallyDrop;
23use std::mem::zeroed;
24use std::num::NonZeroU32;
25use std::os::windows::prelude::*;
26use vmbus_core::HvsockConnectRequest;
27use vmbus_core::HvsockConnectResult;
28use vmbusioctl::VMBUS_CHANNEL_OFFER;
29use vmbusioctl::VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS;
30use widestring::Utf16Str;
31use widestring::utf16str;
32use windows::Wdk::Storage::FileSystem::NtOpenFile;
33use windows::Win32::Foundation::ERROR_OPERATION_ABORTED;
34use windows::Win32::Foundation::HANDLE;
35use windows::Win32::Foundation::NTSTATUS;
36use windows::Win32::Storage::FileSystem::FILE_ALL_ACCESS;
37use windows::Win32::Storage::FileSystem::SYNCHRONIZE;
38use windows::Win32::System::IO::DeviceIoControl;
39use zerocopy::IntoBytes;
40
41mod proxyioctl;
42pub mod vmbusioctl;
43
/// Error type used by this crate's fallible operations; an alias for
/// `windows::core::Error`.
pub type Error = windows::core::Error;
/// Result type used by this crate's fallible operations; an alias for
/// `windows::core::Result`.
pub type Result<T> = windows::core::Result<T>;
46
/// A VM handle for the VMBus proxy driver.
///
/// Wraps the [`std::fs::File`] that owns the underlying handle. A value is obtained either by
/// opening the proxy device with [`ProxyHandle::new`] or by converting an existing
/// `OwnedHandle`.
#[derive(Debug, MeshPayload)]
pub struct ProxyHandle(std::fs::File);
50
51impl ProxyHandle {
52    /// Creates a new VM handle.
53    pub fn new() -> Result<Self> {
54        const DEVICE_PATH: &Utf16Str = utf16str!("\\Device\\VmbusProxy");
55        let pathu = UnicodeStringRef::try_from(DEVICE_PATH).expect("string fits");
56        let mut oa = ObjectAttributes::new();
57        oa.name(&pathu);
58        // SAFETY: calling API according to docs.
59        unsafe {
60            let mut iosb = zeroed();
61            let mut handle = HANDLE::default();
62            NtOpenFile(
63                &mut handle,
64                (FILE_ALL_ACCESS | SYNCHRONIZE).0,
65                oa.as_ref(),
66                &mut iosb,
67                0,
68                0,
69            )
70            .ok()?;
71            Ok(Self(std::fs::File::from_raw_handle(handle.0 as RawHandle)))
72        }
73    }
74}
75
76impl From<OwnedHandle> for ProxyHandle {
77    /// Create a `ProxyHandle` from an existing VM handle.
78    fn from(value: OwnedHandle) -> Self {
79        Self(value.into())
80    }
81}
82
/// Client for the VMBus proxy driver that issues ioctls on a proxy handle.
///
/// When dropped, it asks the driver to clear guest memory if `set_memory` was called,
/// extracts the inner file from the overlapped wrapper, and then signals `drop_send`.
pub struct VmbusProxy {
    /// Overlapped file used for all ioctls. Kept in `ManuallyDrop` so that `Drop` can take
    /// ownership of it and call `into_inner`.
    file: ManuallyDrop<OverlappedFile>,
    /// Clone of the guest memory handed to `set_memory`; `None` until that method is called.
    guest_memory: Option<GuestMemory>,
    /// Cancel context consulted by `ioctl_cancellable`.
    cancel: CancelContext,
    /// Sender signaled at the end of `Drop`; taken (left as `None`) when it is used.
    drop_send: Option<mesh::OneshotSender<()>>,
}
89
90impl Drop for VmbusProxy {
91    fn drop(&mut self) {
92        // If a GuestMemory was used, clear it from the kernel before it gets destructed.
93        // N.B. Not all versions of the proxy driver support this operation, in which case this
94        //      will fail with ERROR_INVALID_FUNCTION or ERROR_NOT_SUPPORTED.
95        if self.guest_memory.is_some() {
96            // Since we are not in an async method, issue this ioctl synchronously.
97            if let Err(err) = self.ioctl_sync(
98                proxyioctl::IOCTL_VMBUS_PROXY_DETACH,
99                &proxyioctl::VMBUS_PROXY_DETACH_INPUT {
100                    DetachChannels: false,
101                },
102            ) {
103                tracing::warn!(
104                    err = &err as &dyn std::error::Error,
105                    "failed to clear proxy driver guest memory"
106                );
107            }
108        }
109
110        // SAFETY: VmbusProxy is being dropped so can no longer be used.
111        let file = unsafe { ManuallyDrop::take(&mut self.file) };
112
113        // Extract the inner file to dissociate the I/O completion port. This is required so the
114        // file object can be reused in case of handle brokering.
115        file.into_inner();
116        if let Some(drop_send) = self.drop_send.take() {
117            drop_send.send(());
118        }
119    }
120}
121
/// An action reported by the proxy driver, as decoded by `VmbusProxy::next_action`.
#[derive(Debug)]
pub enum ProxyAction {
    /// Decoded from an action of type `VmbusProxyActionTypeOffer`.
    Offer {
        /// The `ProxyId` reported with the action.
        id: u64,
        /// The channel offer structure copied from the driver output.
        offer: VMBUS_CHANNEL_OFFER,
        /// Event built from the `DeviceIncomingRingEvent` handle in the driver output.
        incoming_event: Event,
        /// Event built from the `DeviceOutgoingRingEvent` handle; `None` when the driver
        /// reported a zero handle value.
        outgoing_event: Option<Event>,
        /// The `DeviceOrder` value from the driver output; `None` when it is zero.
        device_order: Option<NonZeroU32>,
    },
    /// Decoded from an action of type `VmbusProxyActionTypeRevoke`.
    Revoke {
        /// The `ProxyId` reported with the action.
        id: u64,
    },
    /// Decoded from an action of type `VmbusProxyActionTypeInterruptPolicy`; carries no data.
    InterruptPolicy {},
    /// Decoded from an action of type `VmbusProxyActionTypeTlConnectResult`.
    TlConnectResult {
        /// Endpoint ID, service ID and success flag taken from the driver output.
        result: HvsockConnectResult,
        /// The `Vtl` value from the driver output.
        vtl: u8,
    },
}
140
/// Newtype that exposes a value's own memory as an ioctl buffer.
///
/// The `IoBuf`/`IoBufMut` implementations hand out a pointer to the wrapper itself, so it must
/// only wrap values that can be safely passed to overlapped IO.
struct StaticIoctlBuffer<T>(T);
142
143// SAFETY: this is not generically safe, so callers must be careful to only use
144// this newtype for values that can be safely passed to overlapped IO.
145unsafe impl<T> IoBuf for StaticIoctlBuffer<T> {
146    fn as_ptr(&self) -> *const u8 {
147        std::ptr::from_ref::<Self>(self).cast()
148    }
149
150    fn len(&self) -> usize {
151        size_of_val(self)
152    }
153}
154
155// SAFETY: this is not generically safe, so callers must be careful to only use
156// this newtype for values that can be safely passed to overlapped IO.
157unsafe impl<T> IoBufMut for StaticIoctlBuffer<T> {
158    fn as_mut_ptr(&mut self) -> *mut u8 {
159        std::ptr::from_mut::<Self>(self).cast()
160    }
161}
162
163impl VmbusProxy {
164    /// Creates a new `VmbusProxy` from a [`ProxyHandle`]. When the `VmbusProxy` instance is
165    /// dropped, `drop_send` is signaled. This allows users to wait until all IO is guaranteed to
166    /// be finished and the IO completion port is disassociated.
167    pub fn new(
168        driver: &dyn Driver,
169        handle: ProxyHandle,
170        ctx: CancelContext,
171        drop_send: mesh::OneshotSender<()>,
172    ) -> Result<Self> {
173        // SAFETY: This handle is duplicated and can be shared with other devices, so safety depends
174        // on this being the only user of the handle for overlapped IO.
175        let file = unsafe { OverlappedFile::new(driver, handle.0)? };
176        Ok(Self {
177            file: ManuallyDrop::new(file),
178            guest_memory: None,
179            cancel: ctx,
180            drop_send: Some(drop_send),
181        })
182    }
183
184    pub fn handle(&self) -> BorrowedHandle<'_> {
185        self.file.get().as_handle()
186    }
187
188    async unsafe fn ioctl<In, Out>(&self, code: u32, input: In, output: Out) -> Result<Out>
189    where
190        In: IoBufMut,
191        Out: IoBufMut,
192    {
193        // SAFETY: guaranteed by caller.
194        let (r, (_, output)) = unsafe { self.file.ioctl(code, input, output).await };
195        let size = r?;
196        assert_eq!(size, output.len(), "ioctl returned unexpected size");
197        Ok(output)
198    }
199
    /// Issues the ioctl `code` like `ioctl`, but honors this proxy's cancel context.
    ///
    /// If the context is already cancelled, no IO is issued. If it is cancelled while the
    /// request is pending, the request is cancelled and then awaited to completion, and the
    /// completion result of that request is what gets returned.
    ///
    /// # Safety
    ///
    /// The caller must uphold the safety requirements of the underlying overlapped ioctl call
    /// for the given `code`, `input` and `output`.
    ///
    /// # Errors
    ///
    /// Returns `ERROR_OPERATION_ABORTED` if the context was cancelled before the request was
    /// issued; otherwise returns the error reported by the ioctl.
    ///
    /// # Panics
    ///
    /// Panics if the size reported by the ioctl differs from `output.len()`.
    async unsafe fn ioctl_cancellable<In, Out>(
        &self,
        code: u32,
        input: In,
        output: Out,
    ) -> Result<Out>
    where
        In: IoBufMut,
        Out: IoBufMut,
    {
        // Don't issue new IO if the cancel context has already been cancelled.
        let mut cancel = self.cancel.clone();
        if cancel.is_cancelled() {
            tracing::trace!("ioctl cancelled before issued");
            return Err(ERROR_OPERATION_ABORTED.into());
        }

        // SAFETY: guaranteed by caller.
        let mut ioctl = unsafe { self.file.ioctl(code, input, output) };
        // Poll once up front; only wait on the cancel context if the request is still pending.
        let (r, (_, output)) = match poll!(&mut ioctl) {
            std::task::Poll::Ready(result) => result,
            std::task::Poll::Pending => {
                match cancel.until_cancelled(&mut ioctl).await {
                    Ok(r) => r,
                    Err(_) => {
                        tracing::trace!("ioctl cancelled after issued");
                        ioctl.cancel();
                        // Even when cancelled, we must wait to complete the IO so buffers aren't released
                        // while still in use.
                        ioctl.await
                    }
                }
            }
        };

        let size = r?;
        assert_eq!(size, output.len(), "ioctl returned unexpected size");
        Ok(output)
    }
239
240    fn ioctl_sync<T>(&self, code: u32, input: &T) -> Result<()>
241    where
242        T: IntoBytes + zerocopy::Immutable,
243    {
244        // SAFETY: Calling API as documented
245        unsafe {
246            let mut bytes = 0;
247            DeviceIoControl(
248                HANDLE(self.file.get().as_raw_handle()),
249                code,
250                Some(input.as_bytes().as_ptr().cast()),
251                size_of_val(input) as u32,
252                None,
253                0,
254                Some(&mut bytes),
255                None,
256            )
257        }
258    }
259
260    pub async fn set_memory(&mut self, guest_memory: &GuestMemory) -> Result<()> {
261        assert!(self.guest_memory.is_none());
262        let (base, len) = guest_memory.full_mapping().ok_or_else(|| {
263            std::io::Error::other("vmbusproxy not supported without mapped memory")
264        })?;
265        self.guest_memory = Some(guest_memory.clone());
266        unsafe {
267            self.ioctl(
268                proxyioctl::IOCTL_VMBUS_PROXY_SET_MEMORY,
269                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_MEMORY_INPUT {
270                    BaseAddress: base as usize as u64,
271                    Size: len as u64,
272                }),
273                (),
274            )
275            .await
276        }
277    }
278
    /// Waits for the proxy driver to report its next action and decodes it into a
    /// [`ProxyAction`].
    ///
    /// The request is issued through `ioctl_cancellable`, so it honors this proxy's cancel
    /// context.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `ioctl_cancellable`, including `ERROR_OPERATION_ABORTED`
    /// when the cancel context was already cancelled before the request was issued.
    ///
    /// # Panics
    ///
    /// Panics if the driver reports an action type that is not handled below, or if the ioctl
    /// reports an unexpected output size.
    pub async fn next_action(&self) -> Result<ProxyAction> {
        // The output buffer starts out zeroed and is filled in by the driver.
        let output = unsafe {
            self.ioctl_cancellable(
                proxyioctl::IOCTL_VMBUS_PROXY_NEXT_ACTION,
                (),
                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_NEXT_ACTION_OUTPUT>()),
            )
            .await?
            .0
        };
        // `Type` selects which member of the `u` union is read in each arm.
        // NOTE(review): presumably the driver fills in the member matching `Type` — confirm
        // against the driver contract.
        match output.Type {
            // NOTE(review): the event handle values are wrapped in `OwnedHandle`, so this
            // assumes ownership of them passes to this process — confirm.
            proxyioctl::VmbusProxyActionTypeOffer => unsafe {
                Ok(ProxyAction::Offer {
                    id: output.ProxyId,
                    offer: output.u.Offer.Offer,
                    incoming_event: OwnedHandle::from_raw_handle(
                        output.u.Offer.DeviceIncomingRingEvent as usize as RawHandle,
                    )
                    .into(),
                    // A zero handle value means no outgoing ring event was provided.
                    outgoing_event: if output.u.Offer.DeviceOutgoingRingEvent != 0 {
                        Some(
                            OwnedHandle::from_raw_handle(
                                output.u.Offer.DeviceOutgoingRingEvent as usize as RawHandle,
                            )
                            .into(),
                        )
                    } else {
                        None
                    },
                    // A zero `DeviceOrder` maps to `None`.
                    device_order: NonZeroU32::new(output.u.Offer.DeviceOrder),
                })
            },
            proxyioctl::VmbusProxyActionTypeRevoke => {
                Ok(ProxyAction::Revoke { id: output.ProxyId })
            }
            proxyioctl::VmbusProxyActionTypeInterruptPolicy => Ok(ProxyAction::InterruptPolicy {}),
            proxyioctl::VmbusProxyActionTypeTlConnectResult => unsafe {
                Ok(ProxyAction::TlConnectResult {
                    result: HvsockConnectResult {
                        endpoint_id: output.u.TlConnectResult.EndpointId.into(),
                        service_id: output.u.TlConnectResult.ServiceId.into(),
                        success: output.u.TlConnectResult.Status.is_ok(),
                    },
                    vtl: output.u.TlConnectResult.Vtl,
                })
            },
            // NOTE(review): an unrecognized action type aborts the process via panic rather
            // than returning an error — confirm this is the intended policy.
            n => panic!("unexpected action: {}", n),
        }
    }
328
329    pub async fn open(
330        &self,
331        id: u64,
332        params: &VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS,
333        event: &Event,
334    ) -> Result<()> {
335        let output = unsafe {
336            let handle = event.as_handle().as_raw_handle() as usize as u64;
337            self.ioctl(
338                proxyioctl::IOCTL_VMBUS_PROXY_OPEN_CHANNEL,
339                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_INPUT {
340                    ProxyId: id,
341                    Padding: 0,
342                    OpenParameters: *params,
343                    VmmSignalEvent: handle,
344                }),
345                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_OUTPUT>()),
346            )
347            .await?
348            .0
349        };
350        NTSTATUS(output.Status).ok()
351    }
352    pub async fn set_interrupt(&self, id: u64, event: &Event) -> Result<()> {
353        unsafe {
354            let handle = event.as_handle().as_raw_handle() as usize as u64;
355            self.ioctl(
356                proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_SET_INTERRUPT,
357                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_INTERRUPT_INPUT {
358                    ProxyId: id,
359                    VmmSignalEvent: handle,
360                }),
361                (),
362            )
363            .await?
364        };
365        Ok(())
366    }
367
368    pub async fn restore(
369        &self,
370        interface_type: Guid,
371        interface_instance: Guid,
372        subchannel_index: u16,
373        target_vtl: u8,
374        open_params: Option<VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS>,
375        gpadls: impl Iterator<Item = Gpadl<'_>>,
376    ) -> Result<u64> {
377        let mut buffer = Vec::new();
378        let mut header = proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT {
379            InterfaceType: interface_type,
380            InterfaceInstance: interface_instance,
381            SubchannelIndex: subchannel_index,
382            TargetVtl: target_vtl,
383            GpadlCount: 0,
384            OpenParameters: open_params.unwrap_or_default(),
385            Open: open_params.is_some().into(),
386        };
387
388        // Leave space for the header.
389        const HEADER_LEN: usize = size_of::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT>();
390        buffer.resize(HEADER_LEN, 0);
391
392        // Add GPADLs to the buffer and count them.
393        for gpadl in gpadls {
394            header.GpadlCount += 1;
395            Self::add_gpadl(
396                &mut buffer,
397                0, // Not used for restoring.
398                gpadl.gpadl_id,
399                gpadl.range_count,
400                gpadl.range_buffer.as_bytes(),
401            );
402        }
403
404        // Copy the header now that the GPADL count is known.
405        buffer[..HEADER_LEN].copy_from_slice(header.as_bytes());
406        Ok(unsafe {
407            self.ioctl(
408                proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_CHANNEL,
409                buffer,
410                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_OUTPUT>()),
411            )
412            .await?
413            .0
414            .ProxyId
415        })
416    }
417
418    pub async fn revoke_unclaimed_channels(&self) -> Result<()> {
419        unsafe {
420            self.ioctl(
421                proxyioctl::IOCTL_VMBUS_PROXY_REVOKE_UNCLAIMED_CHANNELS,
422                (),
423                (),
424            )
425            .await
426        }
427    }
428
429    pub async fn close(&self, id: u64) -> Result<()> {
430        unsafe {
431            self.ioctl(
432                proxyioctl::IOCTL_VMBUS_PROXY_CLOSE_CHANNEL,
433                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_CLOSE_CHANNEL_INPUT { ProxyId: id }),
434                (),
435            )
436            .await
437        }
438    }
439
440    pub async fn release(&self, id: u64) -> Result<()> {
441        unsafe {
442            self.ioctl(
443                proxyioctl::IOCTL_VMBUS_PROXY_RELEASE_CHANNEL,
444                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_RELEASE_CHANNEL_INPUT { ProxyId: id }),
445                (),
446            )
447            .await
448        }
449    }
450
451    pub async fn create_gpadl(
452        &self,
453        id: u64,
454        gpadl_id: u32,
455        range_count: u32,
456        range_buf: &[u8],
457    ) -> Result<()> {
458        let mut buf = Vec::new();
459        Self::add_gpadl(&mut buf, id, gpadl_id, range_count, range_buf);
460        unsafe {
461            self.ioctl(proxyioctl::IOCTL_VMBUS_PROXY_CREATE_GPADL, buf, ())
462                .await
463        }
464    }
465
466    pub async fn delete_gpadl(&self, id: u64, gpadl_id: u32) -> Result<()> {
467        unsafe {
468            self.ioctl(
469                proxyioctl::IOCTL_VMBUS_PROXY_DELETE_GPADL,
470                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_DELETE_GPADL_INPUT {
471                    ProxyId: id,
472                    GpadlId: gpadl_id,
473                    Padding: 0,
474                }),
475                (),
476            )
477            .await
478        }
479    }
480
481    pub async fn tl_connect_request(&self, request: &HvsockConnectRequest, vtl: u8) -> Result<()> {
482        unsafe {
483            self.ioctl(
484                proxyioctl::IOCTL_VMBUS_PROXY_TL_CONNECT_REQUEST,
485                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_INPUT {
486                    EndpoindId: request.endpoint_id.into(),
487                    ServiceId: request.service_id.into(),
488                    SiloId: request.silo_id.into(),
489                    Flags: proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_FLAGS::new()
490                        .with_hosted_silo_unaware(request.hosted_silo_unaware),
491                    Vtl: vtl,
492                    Padding: [0; 3],
493                }),
494                (),
495            )
496            .await
497        }
498    }
499
500    pub fn run_channel(&self, id: u64) -> Result<()> {
501        let input = proxyioctl::VMBUS_PROXY_RUN_CHANNEL_INPUT { ProxyId: id };
502        self.ioctl_sync(proxyioctl::IOCTL_VMBUS_PROXY_RUN_CHANNEL, &input)
503    }
504
505    /// Adds GPADL ioctl data to a buffer.
506    fn add_gpadl(
507        buffer: &mut Vec<u8>,
508        id: u64,
509        gpadl_id: u32,
510        range_count: u32,
511        range_buffer: &[u8],
512    ) {
513        let header = proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT {
514            ProxyId: id,
515            GpadlId: gpadl_id,
516            RangeCount: range_count,
517            RangeBufferOffset: size_of::<proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT>() as u32,
518            RangeBufferSize: range_buffer.len() as u32,
519        };
520        buffer.extend_from_slice(header.as_bytes());
521        buffer.extend_from_slice(range_buffer);
522    }
523}
524
/// Represents data to be restored for a GPADL.
///
/// Each value passed to `VmbusProxy::restore` becomes one GPADL entry in the restore request.
pub struct Gpadl<'a> {
    /// ID of the GPADL, written to the entry's `GpadlId` field.
    pub gpadl_id: u32,
    /// Written to the entry's `RangeCount` field; presumably the number of ranges described
    /// by `range_buffer` — confirm against the caller.
    pub range_count: u32,
    /// Range data for the GPADL; its raw bytes are appended after the entry header.
    pub range_buffer: &'a [u64],
}