vmbus_proxy/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![cfg(windows)]
6// UNSAFETY: Calling vmbus proxy ioctls.
7#![expect(unsafe_code)]
8#![expect(clippy::undocumented_unsafe_blocks, clippy::missing_safety_doc)]
9
10use futures::poll;
11use guestmem::GuestMemory;
12use guid::Guid;
13use mesh::CancelContext;
14use mesh::MeshPayload;
15use pal::windows::ObjectAttributes;
16use pal::windows::UnicodeStringRef;
17use pal_async::driver::Driver;
18use pal_async::windows::overlapped::IoBuf;
19use pal_async::windows::overlapped::IoBufMut;
20use pal_async::windows::overlapped::OverlappedFile;
21use pal_event::Event;
22use std::mem::ManuallyDrop;
23use std::mem::zeroed;
24use std::num::NonZeroU32;
25use std::os::windows::prelude::*;
26use vmbus_core::HvsockConnectRequest;
27use vmbus_core::HvsockConnectResult;
28use vmbusioctl::VMBUS_CHANNEL_OFFER;
29use vmbusioctl::VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS;
30use widestring::Utf16Str;
31use widestring::utf16str;
32use windows::Wdk::Storage::FileSystem::NtOpenFile;
33use windows::Win32::Foundation::ERROR_OPERATION_ABORTED;
34use windows::Win32::Foundation::HANDLE;
35use windows::Win32::Foundation::NTSTATUS;
36use windows::Win32::Storage::FileSystem::FILE_ALL_ACCESS;
37use windows::Win32::Storage::FileSystem::SYNCHRONIZE;
38use windows::Win32::System::IO::DeviceIoControl;
39use zerocopy::IntoBytes;
40
41mod proxyioctl;
42pub mod vmbusioctl;
43
44pub type Error = windows::core::Error;
45pub type Result<T> = windows::core::Result<T>;
46
47/// A VM handle the VMBus proxy driver.
48#[derive(Debug, MeshPayload)]
49pub struct ProxyHandle(std::fs::File);
50
51impl ProxyHandle {
52    /// Creates a new VM handle.
53    pub fn new() -> Result<Self> {
54        const DEVICE_PATH: &Utf16Str = utf16str!("\\Device\\VmbusProxy");
55        let pathu = UnicodeStringRef::try_from(DEVICE_PATH).expect("string fits");
56        let mut oa = ObjectAttributes::new();
57        oa.name(&pathu);
58        // SAFETY: calling API according to docs.
59        unsafe {
60            let mut iosb = zeroed();
61            let mut handle = HANDLE::default();
62            NtOpenFile(
63                &mut handle,
64                (FILE_ALL_ACCESS | SYNCHRONIZE).0,
65                oa.as_ref(),
66                &mut iosb,
67                0,
68                0,
69            )
70            .ok()?;
71            Ok(Self(std::fs::File::from_raw_handle(handle.0 as RawHandle)))
72        }
73    }
74}
75
76impl From<OwnedHandle> for ProxyHandle {
77    /// Create a `ProxyHandle` from an existing VM handle.
78    fn from(value: OwnedHandle) -> Self {
79        Self(value.into())
80    }
81}
82
/// Async client for the VMBus proxy driver's ioctl interface, issuing requests over an
/// overlapped handle.
pub struct VmbusProxy {
    // Wrapped in `ManuallyDrop` so the `Drop` impl can move it out and unwrap it with
    // `into_inner` rather than dropping the `OverlappedFile` directly.
    file: ManuallyDrop<OverlappedFile>,
    // NOTE: This must come after `file` so that it is not released until `file`
    // is closed. (The `Drop` impl also closes `file` explicitly before any field is dropped.)
    // Set once by `set_memory`, which registers this memory's mapping with the driver.
    guest_memory: Option<GuestMemory>,
    // Cancellation context checked by `ioctl_cancellable` (used by `next_action`).
    cancel: CancelContext,
}
90
91impl Drop for VmbusProxy {
92    fn drop(&mut self) {
93        // SAFETY: VmbusProxy is being dropped so can no longer be used.
94        let file = unsafe { ManuallyDrop::take(&mut self.file) };
95
96        // Extract the inner file to dissociate the I/O completion port. This is required so the
97        // file object can be reused in case of handle brokering.
98        file.into_inner();
99    }
100}
101
/// An action reported by the VMBus proxy driver, as decoded by `VmbusProxy::next_action`.
#[derive(Debug)]
pub enum ProxyAction {
    /// A channel offer (driver action type `VmbusProxyActionTypeOffer`).
    Offer {
        /// The driver's `ProxyId` for this channel; the same kind of ID is passed back in
        /// requests such as `open`, `close` and `release`.
        id: u64,
        /// The channel offer exactly as reported by the driver.
        offer: VMBUS_CHANNEL_OFFER,
        /// Event taking ownership of the driver-supplied `DeviceIncomingRingEvent` handle.
        incoming_event: Event,
        /// Event owning `DeviceOutgoingRingEvent`, or `None` when the driver passes a zero handle.
        outgoing_event: Option<Event>,
        /// The driver's `DeviceOrder`, or `None` when it is zero.
        device_order: Option<NonZeroU32>,
    },
    /// A revocation (driver action type `VmbusProxyActionTypeRevoke`).
    Revoke {
        /// The driver's `ProxyId` for the affected channel.
        id: u64,
    },
    /// An interrupt-policy action; no payload is decoded for it.
    InterruptPolicy {},
    /// The result of a transport-layer connect request — presumably one sent with
    /// `VmbusProxy::tl_connect_request`.
    TlConnectResult {
        /// Endpoint/service IDs, plus whether the driver reported a success status.
        result: HvsockConnectResult,
        /// The VTL the driver reported with the result.
        vtl: u8,
    },
}
120
/// Passes a plain value directly as an ioctl input or output buffer.
///
/// `#[repr(transparent)]` guarantees the wrapper has exactly the layout of `T`. The `IoBuf` and
/// `IoBufMut` impls rely on this when they reinterpret `&self` as a pointer to the wrapped
/// value's bytes and report `size_of_val(self)` as its length. Without it, a single-field struct
/// has no guaranteed layout.
#[repr(transparent)]
struct StaticIoctlBuffer<T>(T);
122
123// SAFETY: this is not generically safe, so callers must be careful to only use
124// this newtype for values that can be safely passed to overlapped IO.
125unsafe impl<T> IoBuf for StaticIoctlBuffer<T> {
126    fn as_ptr(&self) -> *const u8 {
127        std::ptr::from_ref::<Self>(self).cast()
128    }
129
130    fn len(&self) -> usize {
131        size_of_val(self)
132    }
133}
134
135// SAFETY: this is not generically safe, so callers must be careful to only use
136// this newtype for values that can be safely passed to overlapped IO.
137unsafe impl<T> IoBufMut for StaticIoctlBuffer<T> {
138    fn as_mut_ptr(&mut self) -> *mut u8 {
139        std::ptr::from_mut::<Self>(self).cast()
140    }
141}
142
143impl VmbusProxy {
144    pub fn new(driver: &dyn Driver, handle: ProxyHandle, ctx: CancelContext) -> Result<Self> {
145        // SAFETY: This handle is duplicated and can be shared with other devices, so safety depends
146        // on this being the only user of the handle for overlapped IO.
147        let file = unsafe { OverlappedFile::new(driver, handle.0)? };
148        Ok(Self {
149            file: ManuallyDrop::new(file),
150            guest_memory: None,
151            cancel: ctx,
152        })
153    }
154
155    pub fn handle(&self) -> BorrowedHandle<'_> {
156        self.file.get().as_handle()
157    }
158
159    async unsafe fn ioctl<In, Out>(&self, code: u32, input: In, output: Out) -> Result<Out>
160    where
161        In: IoBufMut,
162        Out: IoBufMut,
163    {
164        // SAFETY: guaranteed by caller.
165        let (r, (_, output)) = unsafe { self.file.ioctl(code, input, output).await };
166        let size = r?;
167        assert_eq!(size, output.len(), "ioctl returned unexpected size");
168        Ok(output)
169    }
170
171    async unsafe fn ioctl_cancellable<In, Out>(
172        &self,
173        code: u32,
174        input: In,
175        output: Out,
176    ) -> Result<Out>
177    where
178        In: IoBufMut,
179        Out: IoBufMut,
180    {
181        // Don't issue new IO if the cancel context has already been cancelled.
182        let mut cancel = self.cancel.clone();
183        if cancel.is_cancelled() {
184            tracing::trace!("ioctl cancelled before issued");
185            return Err(ERROR_OPERATION_ABORTED.into());
186        }
187
188        // SAFETY: guaranteed by caller.
189        let mut ioctl = unsafe { self.file.ioctl(code, input, output) };
190        let (r, (_, output)) = match poll!(&mut ioctl) {
191            std::task::Poll::Ready(result) => result,
192            std::task::Poll::Pending => {
193                match cancel.until_cancelled(&mut ioctl).await {
194                    Ok(r) => r,
195                    Err(_) => {
196                        tracing::trace!("ioctl cancelled after issued");
197                        ioctl.cancel();
198                        // Even when cancelled, we must wait to complete the IO so buffers aren't released
199                        // while still in use.
200                        ioctl.await
201                    }
202                }
203            }
204        };
205
206        let size = r?;
207        assert_eq!(size, output.len(), "ioctl returned unexpected size");
208        Ok(output)
209    }
210
211    pub async fn set_memory(&mut self, guest_memory: &GuestMemory) -> Result<()> {
212        assert!(self.guest_memory.is_none());
213        let (base, len) = guest_memory.full_mapping().ok_or_else(|| {
214            std::io::Error::other("vmbusproxy not supported without mapped memory")
215        })?;
216        self.guest_memory = Some(guest_memory.clone());
217        unsafe {
218            self.ioctl(
219                proxyioctl::IOCTL_VMBUS_PROXY_SET_MEMORY,
220                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_MEMORY_INPUT {
221                    BaseAddress: base as usize as u64,
222                    Size: len as u64,
223                }),
224                (),
225            )
226            .await
227        }
228    }
229
230    pub async fn next_action(&self) -> Result<ProxyAction> {
231        let output = unsafe {
232            self.ioctl_cancellable(
233                proxyioctl::IOCTL_VMBUS_PROXY_NEXT_ACTION,
234                (),
235                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_NEXT_ACTION_OUTPUT>()),
236            )
237            .await?
238            .0
239        };
240        match output.Type {
241            proxyioctl::VmbusProxyActionTypeOffer => unsafe {
242                Ok(ProxyAction::Offer {
243                    id: output.ProxyId,
244                    offer: output.u.Offer.Offer,
245                    incoming_event: OwnedHandle::from_raw_handle(
246                        output.u.Offer.DeviceIncomingRingEvent as usize as RawHandle,
247                    )
248                    .into(),
249                    outgoing_event: if output.u.Offer.DeviceOutgoingRingEvent != 0 {
250                        Some(
251                            OwnedHandle::from_raw_handle(
252                                output.u.Offer.DeviceOutgoingRingEvent as usize as RawHandle,
253                            )
254                            .into(),
255                        )
256                    } else {
257                        None
258                    },
259                    device_order: NonZeroU32::new(output.u.Offer.DeviceOrder),
260                })
261            },
262            proxyioctl::VmbusProxyActionTypeRevoke => {
263                Ok(ProxyAction::Revoke { id: output.ProxyId })
264            }
265            proxyioctl::VmbusProxyActionTypeInterruptPolicy => Ok(ProxyAction::InterruptPolicy {}),
266            proxyioctl::VmbusProxyActionTypeTlConnectResult => unsafe {
267                Ok(ProxyAction::TlConnectResult {
268                    result: HvsockConnectResult {
269                        endpoint_id: output.u.TlConnectResult.EndpointId.into(),
270                        service_id: output.u.TlConnectResult.ServiceId.into(),
271                        success: output.u.TlConnectResult.Status.is_ok(),
272                    },
273                    vtl: output.u.TlConnectResult.Vtl,
274                })
275            },
276            n => panic!("unexpected action: {}", n),
277        }
278    }
279
280    pub async fn open(
281        &self,
282        id: u64,
283        params: &VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS,
284        event: &Event,
285    ) -> Result<()> {
286        let output = unsafe {
287            let handle = event.as_handle().as_raw_handle() as usize as u64;
288            self.ioctl(
289                proxyioctl::IOCTL_VMBUS_PROXY_OPEN_CHANNEL,
290                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_INPUT {
291                    ProxyId: id,
292                    Padding: 0,
293                    OpenParameters: *params,
294                    VmmSignalEvent: handle,
295                }),
296                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_OUTPUT>()),
297            )
298            .await?
299            .0
300        };
301        NTSTATUS(output.Status).ok()
302    }
303    pub async fn set_interrupt(&self, id: u64, event: &Event) -> Result<()> {
304        unsafe {
305            let handle = event.as_handle().as_raw_handle() as usize as u64;
306            self.ioctl(
307                proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_SET_INTERRUPT,
308                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_INTERRUPT_INPUT {
309                    ProxyId: id,
310                    VmmSignalEvent: handle,
311                }),
312                (),
313            )
314            .await?
315        };
316        Ok(())
317    }
318
319    pub async fn restore(
320        &self,
321        interface_type: Guid,
322        interface_instance: Guid,
323        subchannel_index: u16,
324        target_vtl: u8,
325        open_params: Option<VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS>,
326        gpadls: impl Iterator<Item = Gpadl<'_>>,
327    ) -> Result<u64> {
328        let mut buffer = Vec::new();
329        let mut header = proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT {
330            InterfaceType: interface_type,
331            InterfaceInstance: interface_instance,
332            SubchannelIndex: subchannel_index,
333            TargetVtl: target_vtl,
334            GpadlCount: 0,
335            OpenParameters: open_params.unwrap_or_default(),
336            Open: open_params.is_some().into(),
337        };
338
339        // Leave space for the header.
340        const HEADER_LEN: usize = size_of::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT>();
341        buffer.resize(HEADER_LEN, 0);
342
343        // Add GPADLs to the buffer and count them.
344        for gpadl in gpadls {
345            header.GpadlCount += 1;
346            Self::add_gpadl(
347                &mut buffer,
348                0, // Not used for restoring.
349                gpadl.gpadl_id,
350                gpadl.range_count,
351                gpadl.range_buffer.as_bytes(),
352            );
353        }
354
355        // Copy the header now that the GPADL count is known.
356        buffer[..HEADER_LEN].copy_from_slice(header.as_bytes());
357        Ok(unsafe {
358            self.ioctl(
359                proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_CHANNEL,
360                buffer,
361                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_OUTPUT>()),
362            )
363            .await?
364            .0
365            .ProxyId
366        })
367    }
368
369    pub async fn revoke_unclaimed_channels(&self) -> Result<()> {
370        unsafe {
371            self.ioctl(
372                proxyioctl::IOCTL_VMBUS_PROXY_REVOKE_UNCLAIMED_CHANNELS,
373                (),
374                (),
375            )
376            .await
377        }
378    }
379
380    pub async fn close(&self, id: u64) -> Result<()> {
381        unsafe {
382            self.ioctl(
383                proxyioctl::IOCTL_VMBUS_PROXY_CLOSE_CHANNEL,
384                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_CLOSE_CHANNEL_INPUT { ProxyId: id }),
385                (),
386            )
387            .await
388        }
389    }
390
391    pub async fn release(&self, id: u64) -> Result<()> {
392        unsafe {
393            self.ioctl(
394                proxyioctl::IOCTL_VMBUS_PROXY_RELEASE_CHANNEL,
395                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_RELEASE_CHANNEL_INPUT { ProxyId: id }),
396                (),
397            )
398            .await
399        }
400    }
401
402    pub async fn create_gpadl(
403        &self,
404        id: u64,
405        gpadl_id: u32,
406        range_count: u32,
407        range_buf: &[u8],
408    ) -> Result<()> {
409        let mut buf = Vec::new();
410        Self::add_gpadl(&mut buf, id, gpadl_id, range_count, range_buf);
411        unsafe {
412            self.ioctl(proxyioctl::IOCTL_VMBUS_PROXY_CREATE_GPADL, buf, ())
413                .await
414        }
415    }
416
417    pub async fn delete_gpadl(&self, id: u64, gpadl_id: u32) -> Result<()> {
418        unsafe {
419            self.ioctl(
420                proxyioctl::IOCTL_VMBUS_PROXY_DELETE_GPADL,
421                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_DELETE_GPADL_INPUT {
422                    ProxyId: id,
423                    GpadlId: gpadl_id,
424                    Padding: 0,
425                }),
426                (),
427            )
428            .await
429        }
430    }
431
432    pub async fn tl_connect_request(&self, request: &HvsockConnectRequest, vtl: u8) -> Result<()> {
433        unsafe {
434            self.ioctl(
435                proxyioctl::IOCTL_VMBUS_PROXY_TL_CONNECT_REQUEST,
436                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_INPUT {
437                    EndpoindId: request.endpoint_id.into(),
438                    ServiceId: request.service_id.into(),
439                    SiloId: request.silo_id.into(),
440                    Flags: proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_FLAGS::new()
441                        .with_hosted_silo_unaware(request.hosted_silo_unaware),
442                    Vtl: vtl,
443                    Padding: [0; 3],
444                }),
445                (),
446            )
447            .await
448        }
449    }
450
451    pub fn run_channel(&self, id: u64) -> Result<()> {
452        unsafe {
453            // This is a synchronous operation, so don't use the async IO infrastructure.
454            let input = proxyioctl::VMBUS_PROXY_RUN_CHANNEL_INPUT { ProxyId: id };
455            let mut bytes = 0;
456            DeviceIoControl(
457                HANDLE(self.file.get().as_raw_handle()),
458                proxyioctl::IOCTL_VMBUS_PROXY_RUN_CHANNEL,
459                Some(std::ptr::from_ref(&input).cast()),
460                size_of_val(&input) as u32,
461                None,
462                0,
463                Some(&mut bytes),
464                None,
465            )?;
466        };
467        Ok(())
468    }
469
470    /// Adds GPADL ioctl data to a buffer.
471    fn add_gpadl(
472        buffer: &mut Vec<u8>,
473        id: u64,
474        gpadl_id: u32,
475        range_count: u32,
476        range_buffer: &[u8],
477    ) {
478        let header = proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT {
479            ProxyId: id,
480            GpadlId: gpadl_id,
481            RangeCount: range_count,
482            RangeBufferOffset: size_of::<proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT>() as u32,
483            RangeBufferSize: range_buffer.len() as u32,
484        };
485        buffer.extend_from_slice(header.as_bytes());
486        buffer.extend_from_slice(range_buffer);
487    }
488}
489
/// Represents data to be restored for a GPADL.
///
/// Consumed by `VmbusProxy::restore`, which serializes each value into one GPADL entry of the
/// restore request.
pub struct Gpadl<'a> {
    /// GPADL identifier, sent to the driver as the entry's `GpadlId`.
    pub gpadl_id: u32,
    /// Sent to the driver as `RangeCount`; presumably the number of ranges encoded in
    /// `range_buffer` — confirm against the driver contract.
    pub range_count: u32,
    /// Raw range data, copied verbatim (as bytes) right after the entry header.
    pub range_buffer: &'a [u64],
}