vmbus_proxy/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![cfg(windows)]
6// UNSAFETY: Calling vmbus proxy ioctls.
7#![expect(unsafe_code)]
8#![expect(clippy::undocumented_unsafe_blocks, clippy::missing_safety_doc)]
9
10use futures::poll;
11use guestmem::GuestMemory;
12use guid::Guid;
13use mesh::CancelContext;
14use mesh::MeshPayload;
15use pal::windows::ObjectAttributes;
16use pal::windows::UnicodeStringRef;
17use pal_async::driver::Driver;
18use pal_async::windows::overlapped::IoBuf;
19use pal_async::windows::overlapped::IoBufMut;
20use pal_async::windows::overlapped::OverlappedFile;
21use pal_event::Event;
22use std::mem::zeroed;
23use std::num::NonZeroU32;
24use std::os::windows::prelude::*;
25use vmbus_core::HvsockConnectRequest;
26use vmbus_core::HvsockConnectResult;
27use vmbusioctl::VMBUS_CHANNEL_OFFER;
28use vmbusioctl::VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS;
29use widestring::Utf16Str;
30use widestring::utf16str;
31use windows::Wdk::Storage::FileSystem::NtOpenFile;
32use windows::Win32::Foundation::ERROR_OPERATION_ABORTED;
33use windows::Win32::Foundation::HANDLE;
34use windows::Win32::Foundation::NTSTATUS;
35use windows::Win32::Storage::FileSystem::FILE_ALL_ACCESS;
36use windows::Win32::Storage::FileSystem::SYNCHRONIZE;
37use windows::Win32::System::IO::DeviceIoControl;
38use zerocopy::IntoBytes;
39
40mod proxyioctl;
41pub mod vmbusioctl;
42
43pub type Error = windows::core::Error;
44pub type Result<T> = windows::core::Result<T>;
45
46/// A VM handle the VMBus proxy driver.
47#[derive(Debug, MeshPayload)]
48pub struct ProxyHandle(std::fs::File);
49
50impl ProxyHandle {
51    /// Creates a new VM handle.
52    pub fn new() -> Result<Self> {
53        const DEVICE_PATH: &Utf16Str = utf16str!("\\Device\\VmbusProxy");
54        let pathu = UnicodeStringRef::try_from(DEVICE_PATH).expect("string fits");
55        let mut oa = ObjectAttributes::new();
56        oa.name(&pathu);
57        // SAFETY: calling API according to docs.
58        unsafe {
59            let mut iosb = zeroed();
60            let mut handle = HANDLE::default();
61            NtOpenFile(
62                &mut handle,
63                (FILE_ALL_ACCESS | SYNCHRONIZE).0,
64                oa.as_ref(),
65                &mut iosb,
66                0,
67                0,
68            )
69            .ok()?;
70            Ok(Self(std::fs::File::from_raw_handle(handle.0 as RawHandle)))
71        }
72    }
73}
74
75impl From<OwnedHandle> for ProxyHandle {
76    /// Create a `ProxyHandle` from an existing VM handle.
77    fn from(value: OwnedHandle) -> Self {
78        Self(value.into())
79    }
80}
81
/// Issues requests to the VMBus proxy driver over an overlapped file handle.
pub struct VmbusProxy {
    /// Driver handle used for every ioctl issued by this proxy.
    file: OverlappedFile,
    // NOTE: This must come after `file` so that it is not released until `file`
    // is closed.
    // (Struct fields are dropped in declaration order, so `file` drops first.)
    /// Guest memory registered with the driver by `set_memory`, if any.
    guest_memory: Option<GuestMemory>,
    /// Cancellation context consulted by cancellable ioctls (used by
    /// `next_action`).
    cancel: CancelContext,
}
89
/// An action reported by the VMBus proxy driver, as returned by
/// `VmbusProxy::next_action`.
#[derive(Debug)]
pub enum ProxyAction {
    /// A channel offer (`VmbusProxyActionTypeOffer`).
    Offer {
        /// Proxy ID reported by the driver for this channel; later requests such
        /// as `VmbusProxy::open` identify the channel by this value.
        id: u64,
        /// The channel offer as reported by the driver.
        offer: VMBUS_CHANNEL_OFFER,
        /// Event built from the driver-supplied `DeviceIncomingRingEvent`
        /// handle; this value owns (and will close) that handle.
        incoming_event: Event,
        /// Event built from the driver-supplied `DeviceOutgoingRingEvent`
        /// handle, or `None` if the driver supplied a zero handle.
        outgoing_event: Option<Event>,
        /// The driver-reported device order, or `None` if it was zero.
        device_order: Option<NonZeroU32>,
    },
    /// A channel revocation (`VmbusProxyActionTypeRevoke`).
    Revoke {
        /// Proxy ID of the revoked channel.
        id: u64,
    },
    /// An interrupt-policy action (`VmbusProxyActionTypeInterruptPolicy`); no
    /// data from the driver output is carried.
    InterruptPolicy {},
    /// The result of a transport-layer connect request
    /// (`VmbusProxyActionTypeTlConnectResult`), presumably one issued through
    /// `VmbusProxy::tl_connect_request` — confirm.
    TlConnectResult {
        /// Endpoint and service IDs from the driver; `success` is set when the
        /// driver-reported status is a success code.
        result: HvsockConnectResult,
        /// The VTL reported by the driver for this result.
        vtl: u8,
    },
}
108
109struct StaticIoctlBuffer<T>(T);
110
111// SAFETY: this is not generically safe, so callers must be careful to only use
112// this newtype for values that can be safely passed to overlapped IO.
113unsafe impl<T> IoBuf for StaticIoctlBuffer<T> {
114    fn as_ptr(&self) -> *const u8 {
115        std::ptr::from_ref::<Self>(self).cast()
116    }
117
118    fn len(&self) -> usize {
119        size_of_val(self)
120    }
121}
122
123// SAFETY: this is not generically safe, so callers must be careful to only use
124// this newtype for values that can be safely passed to overlapped IO.
125unsafe impl<T> IoBufMut for StaticIoctlBuffer<T> {
126    fn as_mut_ptr(&mut self) -> *mut u8 {
127        std::ptr::from_mut::<Self>(self).cast()
128    }
129}
130
131impl VmbusProxy {
132    pub fn new(driver: &dyn Driver, handle: ProxyHandle, ctx: CancelContext) -> Result<Self> {
133        // SAFETY: TODO, analyze whether we are guaranteed to follow the safety
134        // contract.
135        let file = unsafe { OverlappedFile::new(driver, handle.0)? };
136        Ok(Self {
137            file,
138            guest_memory: None,
139            cancel: ctx,
140        })
141    }
142
143    pub fn handle(&self) -> BorrowedHandle<'_> {
144        self.file.get().as_handle()
145    }
146
147    async unsafe fn ioctl<In, Out>(&self, code: u32, input: In, output: Out) -> Result<Out>
148    where
149        In: IoBufMut,
150        Out: IoBufMut,
151    {
152        // SAFETY: guaranteed by caller.
153        let (r, (_, output)) = unsafe { self.file.ioctl(code, input, output).await };
154        let size = r?;
155        assert_eq!(size, output.len(), "ioctl returned unexpected size");
156        Ok(output)
157    }
158
159    async unsafe fn ioctl_cancellable<In, Out>(
160        &self,
161        code: u32,
162        input: In,
163        output: Out,
164    ) -> Result<Out>
165    where
166        In: IoBufMut,
167        Out: IoBufMut,
168    {
169        // Don't issue new IO if the cancel context has already been cancelled.
170        let mut cancel = self.cancel.clone();
171        if cancel.is_cancelled() {
172            tracing::trace!("ioctl cancelled before issued");
173            return Err(ERROR_OPERATION_ABORTED.into());
174        }
175
176        // SAFETY: guaranteed by caller.
177        let mut ioctl = unsafe { self.file.ioctl(code, input, output) };
178        let (r, (_, output)) = match poll!(&mut ioctl) {
179            std::task::Poll::Ready(result) => result,
180            std::task::Poll::Pending => {
181                match cancel.until_cancelled(&mut ioctl).await {
182                    Ok(r) => r,
183                    Err(_) => {
184                        tracing::trace!("ioctl cancelled after issued");
185                        ioctl.cancel();
186                        // Even when cancelled, we must wait to complete the IO so buffers aren't released
187                        // while still in use.
188                        ioctl.await
189                    }
190                }
191            }
192        };
193
194        let size = r?;
195        assert_eq!(size, output.len(), "ioctl returned unexpected size");
196        Ok(output)
197    }
198
199    pub async fn set_memory(&mut self, guest_memory: &GuestMemory) -> Result<()> {
200        assert!(self.guest_memory.is_none());
201        let (base, len) = guest_memory.full_mapping().ok_or_else(|| {
202            std::io::Error::other("vmbusproxy not supported without mapped memory")
203        })?;
204        self.guest_memory = Some(guest_memory.clone());
205        unsafe {
206            self.ioctl(
207                proxyioctl::IOCTL_VMBUS_PROXY_SET_MEMORY,
208                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_MEMORY_INPUT {
209                    BaseAddress: base as usize as u64,
210                    Size: len as u64,
211                }),
212                (),
213            )
214            .await
215        }
216    }
217
218    pub async fn next_action(&self) -> Result<ProxyAction> {
219        let output = unsafe {
220            self.ioctl_cancellable(
221                proxyioctl::IOCTL_VMBUS_PROXY_NEXT_ACTION,
222                (),
223                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_NEXT_ACTION_OUTPUT>()),
224            )
225            .await?
226            .0
227        };
228        match output.Type {
229            proxyioctl::VmbusProxyActionTypeOffer => unsafe {
230                Ok(ProxyAction::Offer {
231                    id: output.ProxyId,
232                    offer: output.u.Offer.Offer,
233                    incoming_event: OwnedHandle::from_raw_handle(
234                        output.u.Offer.DeviceIncomingRingEvent as usize as RawHandle,
235                    )
236                    .into(),
237                    outgoing_event: if output.u.Offer.DeviceOutgoingRingEvent != 0 {
238                        Some(
239                            OwnedHandle::from_raw_handle(
240                                output.u.Offer.DeviceOutgoingRingEvent as usize as RawHandle,
241                            )
242                            .into(),
243                        )
244                    } else {
245                        None
246                    },
247                    device_order: NonZeroU32::new(output.u.Offer.DeviceOrder),
248                })
249            },
250            proxyioctl::VmbusProxyActionTypeRevoke => {
251                Ok(ProxyAction::Revoke { id: output.ProxyId })
252            }
253            proxyioctl::VmbusProxyActionTypeInterruptPolicy => Ok(ProxyAction::InterruptPolicy {}),
254            proxyioctl::VmbusProxyActionTypeTlConnectResult => unsafe {
255                Ok(ProxyAction::TlConnectResult {
256                    result: HvsockConnectResult {
257                        endpoint_id: output.u.TlConnectResult.EndpointId.into(),
258                        service_id: output.u.TlConnectResult.ServiceId.into(),
259                        success: output.u.TlConnectResult.Status.is_ok(),
260                    },
261                    vtl: output.u.TlConnectResult.Vtl,
262                })
263            },
264            n => panic!("unexpected action: {}", n),
265        }
266    }
267
268    pub async fn open(
269        &self,
270        id: u64,
271        params: &VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS,
272        event: &Event,
273    ) -> Result<()> {
274        let output = unsafe {
275            let handle = event.as_handle().as_raw_handle() as usize as u64;
276            self.ioctl(
277                proxyioctl::IOCTL_VMBUS_PROXY_OPEN_CHANNEL,
278                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_INPUT {
279                    ProxyId: id,
280                    Padding: 0,
281                    OpenParameters: *params,
282                    VmmSignalEvent: handle,
283                }),
284                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_OUTPUT>()),
285            )
286            .await?
287            .0
288        };
289        NTSTATUS(output.Status).ok()
290    }
291    pub async fn set_interrupt(&self, id: u64, event: &Event) -> Result<()> {
292        unsafe {
293            let handle = event.as_handle().as_raw_handle() as usize as u64;
294            self.ioctl(
295                proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_SET_INTERRUPT,
296                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_INTERRUPT_INPUT {
297                    ProxyId: id,
298                    VmmSignalEvent: handle,
299                }),
300                (),
301            )
302            .await?
303        };
304        Ok(())
305    }
306
307    pub async fn restore(
308        &self,
309        interface_type: Guid,
310        interface_instance: Guid,
311        subchannel_index: u16,
312        target_vtl: u8,
313        open_params: Option<VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS>,
314        gpadls: impl Iterator<Item = Gpadl<'_>>,
315    ) -> Result<u64> {
316        let mut buffer = Vec::new();
317        let mut header = proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT {
318            InterfaceType: interface_type,
319            InterfaceInstance: interface_instance,
320            SubchannelIndex: subchannel_index,
321            TargetVtl: target_vtl,
322            GpadlCount: 0,
323            OpenParameters: open_params.unwrap_or_default(),
324            Open: open_params.is_some().into(),
325        };
326
327        // Leave space for the header.
328        const HEADER_LEN: usize = size_of::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT>();
329        buffer.resize(HEADER_LEN, 0);
330
331        // Add GPADLs to the buffer and count them.
332        for gpadl in gpadls {
333            header.GpadlCount += 1;
334            Self::add_gpadl(
335                &mut buffer,
336                0, // Not used for restoring.
337                gpadl.gpadl_id,
338                gpadl.range_count,
339                gpadl.range_buffer.as_bytes(),
340            );
341        }
342
343        // Copy the header now that the GPADL count is known.
344        buffer[..HEADER_LEN].copy_from_slice(header.as_bytes());
345        Ok(unsafe {
346            self.ioctl(
347                proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_CHANNEL,
348                buffer,
349                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_OUTPUT>()),
350            )
351            .await?
352            .0
353            .ProxyId
354        })
355    }
356
357    pub async fn revoke_unclaimed_channels(&self) -> Result<()> {
358        unsafe {
359            self.ioctl(
360                proxyioctl::IOCTL_VMBUS_PROXY_REVOKE_UNCLAIMED_CHANNELS,
361                (),
362                (),
363            )
364            .await
365        }
366    }
367
368    pub async fn close(&self, id: u64) -> Result<()> {
369        unsafe {
370            self.ioctl(
371                proxyioctl::IOCTL_VMBUS_PROXY_CLOSE_CHANNEL,
372                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_CLOSE_CHANNEL_INPUT { ProxyId: id }),
373                (),
374            )
375            .await
376        }
377    }
378
379    pub async fn release(&self, id: u64) -> Result<()> {
380        unsafe {
381            self.ioctl(
382                proxyioctl::IOCTL_VMBUS_PROXY_RELEASE_CHANNEL,
383                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_RELEASE_CHANNEL_INPUT { ProxyId: id }),
384                (),
385            )
386            .await
387        }
388    }
389
390    pub async fn create_gpadl(
391        &self,
392        id: u64,
393        gpadl_id: u32,
394        range_count: u32,
395        range_buf: &[u8],
396    ) -> Result<()> {
397        let mut buf = Vec::new();
398        Self::add_gpadl(&mut buf, id, gpadl_id, range_count, range_buf);
399        unsafe {
400            self.ioctl(proxyioctl::IOCTL_VMBUS_PROXY_CREATE_GPADL, buf, ())
401                .await
402        }
403    }
404
405    pub async fn delete_gpadl(&self, id: u64, gpadl_id: u32) -> Result<()> {
406        unsafe {
407            self.ioctl(
408                proxyioctl::IOCTL_VMBUS_PROXY_DELETE_GPADL,
409                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_DELETE_GPADL_INPUT {
410                    ProxyId: id,
411                    GpadlId: gpadl_id,
412                    Padding: 0,
413                }),
414                (),
415            )
416            .await
417        }
418    }
419
420    pub async fn tl_connect_request(&self, request: &HvsockConnectRequest, vtl: u8) -> Result<()> {
421        unsafe {
422            self.ioctl(
423                proxyioctl::IOCTL_VMBUS_PROXY_TL_CONNECT_REQUEST,
424                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_INPUT {
425                    EndpoindId: request.endpoint_id.into(),
426                    ServiceId: request.service_id.into(),
427                    SiloId: request.silo_id.into(),
428                    Flags: proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_FLAGS::new()
429                        .with_hosted_silo_unaware(request.hosted_silo_unaware),
430                    Vtl: vtl,
431                    Padding: [0; 3],
432                }),
433                (),
434            )
435            .await
436        }
437    }
438
439    pub fn run_channel(&self, id: u64) -> Result<()> {
440        unsafe {
441            // This is a synchronous operation, so don't use the async IO infrastructure.
442            let input = proxyioctl::VMBUS_PROXY_RUN_CHANNEL_INPUT { ProxyId: id };
443            let mut bytes = 0;
444            DeviceIoControl(
445                HANDLE(self.file.get().as_raw_handle()),
446                proxyioctl::IOCTL_VMBUS_PROXY_RUN_CHANNEL,
447                Some(std::ptr::from_ref(&input).cast()),
448                size_of_val(&input) as u32,
449                None,
450                0,
451                Some(&mut bytes),
452                None,
453            )?;
454        };
455        Ok(())
456    }
457
458    /// Adds GPADL ioctl data to a buffer.
459    fn add_gpadl(
460        buffer: &mut Vec<u8>,
461        id: u64,
462        gpadl_id: u32,
463        range_count: u32,
464        range_buffer: &[u8],
465    ) {
466        let header = proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT {
467            ProxyId: id,
468            GpadlId: gpadl_id,
469            RangeCount: range_count,
470            RangeBufferOffset: size_of::<proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT>() as u32,
471            RangeBufferSize: range_buffer.len() as u32,
472        };
473        buffer.extend_from_slice(header.as_bytes());
474        buffer.extend_from_slice(range_buffer);
475    }
476}
477
/// Represents data to be restored for a GPADL, as passed to
/// `VmbusProxy::restore`.
pub struct Gpadl<'a> {
    /// GPADL identifier, sent to the driver as `GpadlId`.
    pub gpadl_id: u32,
    /// Sent to the driver as `RangeCount`; presumably the number of ranges
    /// encoded in `range_buffer` — confirm against callers.
    pub range_count: u32,
    /// Range data, sent to the driver verbatim (as bytes) as the GPADL's range
    /// buffer.
    pub range_buffer: &'a [u64],
}