vmbus_proxy/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![cfg(windows)]
6// UNSAFETY: Calling vmbus proxy ioctls.
7#![expect(unsafe_code)]
8#![expect(clippy::undocumented_unsafe_blocks, clippy::missing_safety_doc)]
9
10use futures::poll;
11use guestmem::GuestMemory;
12use guid::Guid;
13use mesh::CancelContext;
14use mesh::MeshPayload;
15use pal::windows::ObjectAttributes;
16use pal::windows::UnicodeStringRef;
17use pal_async::driver::Driver;
18use pal_async::windows::overlapped::IoBuf;
19use pal_async::windows::overlapped::IoBufMut;
20use pal_async::windows::overlapped::OverlappedFile;
21use pal_event::Event;
22use std::mem::zeroed;
23use std::os::windows::prelude::*;
24use vmbus_core::HvsockConnectRequest;
25use vmbus_core::HvsockConnectResult;
26use vmbusioctl::VMBUS_CHANNEL_OFFER;
27use vmbusioctl::VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS;
28use widestring::Utf16Str;
29use widestring::utf16str;
30use windows::Wdk::Storage::FileSystem::NtOpenFile;
31use windows::Win32::Foundation::ERROR_OPERATION_ABORTED;
32use windows::Win32::Foundation::HANDLE;
33use windows::Win32::Foundation::NTSTATUS;
34use windows::Win32::Storage::FileSystem::FILE_ALL_ACCESS;
35use windows::Win32::Storage::FileSystem::SYNCHRONIZE;
36use windows::Win32::System::IO::DeviceIoControl;
37use zerocopy::IntoBytes;
38
39mod proxyioctl;
40pub mod vmbusioctl;
41
42pub type Error = windows::core::Error;
43pub type Result<T> = windows::core::Result<T>;
44
45/// A VM handle the VMBus proxy driver.
46#[derive(Debug, MeshPayload)]
47pub struct ProxyHandle(std::fs::File);
48
49impl ProxyHandle {
50    /// Creates a new VM handle.
51    pub fn new() -> Result<Self> {
52        const DEVICE_PATH: &Utf16Str = utf16str!("\\Device\\VmbusProxy");
53        let pathu = UnicodeStringRef::try_from(DEVICE_PATH).expect("string fits");
54        let mut oa = ObjectAttributes::new();
55        oa.name(&pathu);
56        // SAFETY: calling API according to docs.
57        unsafe {
58            let mut iosb = zeroed();
59            let mut handle = HANDLE::default();
60            NtOpenFile(
61                &mut handle,
62                (FILE_ALL_ACCESS | SYNCHRONIZE).0,
63                oa.as_ref(),
64                &mut iosb,
65                0,
66                0,
67            )
68            .ok()?;
69            Ok(Self(std::fs::File::from_raw_handle(handle.0 as RawHandle)))
70        }
71    }
72}
73
74impl From<OwnedHandle> for ProxyHandle {
75    /// Create a `ProxyHandle` from an existing VM handle.
76    fn from(value: OwnedHandle) -> Self {
77        Self(value.into())
78    }
79}
80
81pub struct VmbusProxy {
82    file: OverlappedFile,
83    // NOTE: This must come after `file` so that it is not released until `file`
84    // is closed.
85    guest_memory: Option<GuestMemory>,
86    cancel: CancelContext,
87}
88
89#[derive(Debug)]
90pub enum ProxyAction {
91    Offer {
92        id: u64,
93        offer: VMBUS_CHANNEL_OFFER,
94        incoming_event: Event,
95        outgoing_event: Option<Event>,
96    },
97    Revoke {
98        id: u64,
99    },
100    InterruptPolicy {},
101    TlConnectResult {
102        result: HvsockConnectResult,
103        vtl: u8,
104    },
105}
106
107struct StaticIoctlBuffer<T>(T);
108
109// SAFETY: this is not generically safe, so callers must be careful to only use
110// this newtype for values that can be safely passed to overlapped IO.
111unsafe impl<T> IoBuf for StaticIoctlBuffer<T> {
112    fn as_ptr(&self) -> *const u8 {
113        std::ptr::from_ref::<Self>(self).cast()
114    }
115
116    fn len(&self) -> usize {
117        size_of_val(self)
118    }
119}
120
121// SAFETY: this is not generically safe, so callers must be careful to only use
122// this newtype for values that can be safely passed to overlapped IO.
123unsafe impl<T> IoBufMut for StaticIoctlBuffer<T> {
124    fn as_mut_ptr(&mut self) -> *mut u8 {
125        std::ptr::from_mut::<Self>(self).cast()
126    }
127}
128
129impl VmbusProxy {
130    pub fn new(driver: &dyn Driver, handle: ProxyHandle, ctx: CancelContext) -> Result<Self> {
131        Ok(Self {
132            file: OverlappedFile::new(driver, handle.0)?,
133            guest_memory: None,
134            cancel: ctx,
135        })
136    }
137
138    pub fn handle(&self) -> BorrowedHandle<'_> {
139        self.file.get().as_handle()
140    }
141
142    async unsafe fn ioctl<In, Out>(&self, code: u32, input: In, output: Out) -> Result<Out>
143    where
144        In: IoBufMut,
145        Out: IoBufMut,
146    {
147        // SAFETY: guaranteed by caller.
148        let (r, (_, output)) = unsafe { self.file.ioctl(code, input, output).await };
149        r?;
150        Ok(output)
151    }
152
153    async unsafe fn ioctl_cancellable<In, Out>(
154        &self,
155        code: u32,
156        input: In,
157        output: Out,
158    ) -> Result<Out>
159    where
160        In: IoBufMut,
161        Out: IoBufMut,
162    {
163        // Don't issue new IO if the cancel context has already been cancelled.
164        let mut cancel = self.cancel.clone();
165        if cancel.is_cancelled() {
166            tracing::trace!("ioctl cancelled before issued");
167            return Err(ERROR_OPERATION_ABORTED.into());
168        }
169
170        // SAFETY: guaranteed by caller.
171        let mut ioctl = unsafe { self.file.ioctl(code, input, output) };
172        let (r, (_, output)) = match poll!(&mut ioctl) {
173            std::task::Poll::Ready(result) => result,
174            std::task::Poll::Pending => {
175                match cancel.until_cancelled(&mut ioctl).await {
176                    Ok(r) => r,
177                    Err(_) => {
178                        tracing::trace!("ioctl cancelled after issued");
179                        ioctl.cancel();
180                        // Even when cancelled, we must wait to complete the IO so buffers aren't released
181                        // while still in use.
182                        ioctl.await
183                    }
184                }
185            }
186        };
187
188        r?;
189        Ok(output)
190    }
191
192    pub async fn set_memory(&mut self, guest_memory: &GuestMemory) -> Result<()> {
193        assert!(self.guest_memory.is_none());
194        let (base, len) = guest_memory.full_mapping().ok_or_else(|| {
195            std::io::Error::other("vmbusproxy not supported without mapped memory")
196        })?;
197        self.guest_memory = Some(guest_memory.clone());
198        unsafe {
199            self.ioctl(
200                proxyioctl::IOCTL_VMBUS_PROXY_SET_MEMORY,
201                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_MEMORY_INPUT {
202                    BaseAddress: base as usize as u64,
203                    Size: len as u64,
204                }),
205                (),
206            )
207            .await
208        }
209    }
210
211    pub async fn next_action(&self) -> Result<ProxyAction> {
212        let output = unsafe {
213            self.ioctl_cancellable(
214                proxyioctl::IOCTL_VMBUS_PROXY_NEXT_ACTION,
215                (),
216                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_NEXT_ACTION_OUTPUT>()),
217            )
218            .await?
219            .0
220        };
221        match output.Type {
222            proxyioctl::VmbusProxyActionTypeOffer => unsafe {
223                Ok(ProxyAction::Offer {
224                    id: output.ProxyId,
225                    offer: output.u.Offer.Offer,
226                    incoming_event: OwnedHandle::from_raw_handle(
227                        output.u.Offer.DeviceIncomingRingEvent as usize as RawHandle,
228                    )
229                    .into(),
230                    outgoing_event: if output.u.Offer.DeviceOutgoingRingEvent != 0 {
231                        Some(
232                            OwnedHandle::from_raw_handle(
233                                output.u.Offer.DeviceOutgoingRingEvent as usize as RawHandle,
234                            )
235                            .into(),
236                        )
237                    } else {
238                        None
239                    },
240                })
241            },
242            proxyioctl::VmbusProxyActionTypeRevoke => {
243                Ok(ProxyAction::Revoke { id: output.ProxyId })
244            }
245            proxyioctl::VmbusProxyActionTypeInterruptPolicy => Ok(ProxyAction::InterruptPolicy {}),
246            proxyioctl::VmbusProxyActionTypeTlConnectResult => unsafe {
247                Ok(ProxyAction::TlConnectResult {
248                    result: HvsockConnectResult {
249                        endpoint_id: output.u.TlConnectResult.EndpointId.into(),
250                        service_id: output.u.TlConnectResult.ServiceId.into(),
251                        success: output.u.TlConnectResult.Status.is_ok(),
252                    },
253                    vtl: output.u.TlConnectResult.Vtl,
254                })
255            },
256            n => panic!("unexpected action: {}", n),
257        }
258    }
259
260    pub async fn open(
261        &self,
262        id: u64,
263        params: &VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS,
264        event: &Event,
265    ) -> Result<()> {
266        let output = unsafe {
267            let handle = event.as_handle().as_raw_handle() as usize as u64;
268            self.ioctl(
269                proxyioctl::IOCTL_VMBUS_PROXY_OPEN_CHANNEL,
270                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_INPUT {
271                    ProxyId: id,
272                    Padding: 0,
273                    OpenParameters: *params,
274                    VmmSignalEvent: handle,
275                }),
276                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_OPEN_CHANNEL_OUTPUT>()),
277            )
278            .await?
279            .0
280        };
281        NTSTATUS(output.Status).ok()
282    }
283    pub async fn set_interrupt(&self, id: u64, event: &Event) -> Result<()> {
284        unsafe {
285            let handle = event.as_handle().as_raw_handle() as usize as u64;
286            self.ioctl(
287                proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_SET_INTERRUPT,
288                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_SET_INTERRUPT_INPUT {
289                    ProxyId: id,
290                    VmmSignalEvent: handle,
291                }),
292                (),
293            )
294            .await?
295        };
296        Ok(())
297    }
298
299    pub async fn restore(
300        &self,
301        interface_type: Guid,
302        interface_instance: Guid,
303        subchannel_index: u16,
304        target_vtl: u8,
305        open_params: Option<VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS>,
306        gpadls: impl Iterator<Item = Gpadl<'_>>,
307    ) -> Result<u64> {
308        let mut buffer = Vec::new();
309        let mut header = proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT {
310            InterfaceType: interface_type,
311            InterfaceInstance: interface_instance,
312            SubchannelIndex: subchannel_index,
313            TargetVtl: target_vtl,
314            GpadlCount: 0,
315            OpenParameters: open_params.unwrap_or_default(),
316            Open: open_params.is_some().into(),
317        };
318
319        // Leave space for the header.
320        const HEADER_LEN: usize = size_of::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_INPUT>();
321        buffer.resize(HEADER_LEN, 0);
322
323        // Add GPADLs to the buffer and count them.
324        for gpadl in gpadls {
325            header.GpadlCount += 1;
326            Self::add_gpadl(
327                &mut buffer,
328                0, // Not used for restoring.
329                gpadl.gpadl_id,
330                gpadl.range_count,
331                gpadl.range_buffer.as_bytes(),
332            );
333        }
334
335        // Copy the header now that the GPADL count is known.
336        buffer[..HEADER_LEN].copy_from_slice(header.as_bytes());
337        Ok(unsafe {
338            self.ioctl(
339                proxyioctl::IOCTL_VMBUS_PROXY_RESTORE_CHANNEL,
340                buffer,
341                StaticIoctlBuffer(zeroed::<proxyioctl::VMBUS_PROXY_RESTORE_CHANNEL_OUTPUT>()),
342            )
343            .await?
344            .0
345            .ProxyId
346        })
347    }
348
349    pub async fn revoke_unclaimed_channels(&self) -> Result<()> {
350        unsafe {
351            self.ioctl(
352                proxyioctl::IOCTL_VMBUS_PROXY_REVOKE_UNCLAIMED_CHANNELS,
353                (),
354                (),
355            )
356            .await
357        }
358    }
359
360    pub async fn close(&self, id: u64) -> Result<()> {
361        unsafe {
362            self.ioctl(
363                proxyioctl::IOCTL_VMBUS_PROXY_CLOSE_CHANNEL,
364                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_CLOSE_CHANNEL_INPUT { ProxyId: id }),
365                (),
366            )
367            .await
368        }
369    }
370
371    pub async fn release(&self, id: u64) -> Result<()> {
372        unsafe {
373            self.ioctl(
374                proxyioctl::IOCTL_VMBUS_PROXY_RELEASE_CHANNEL,
375                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_RELEASE_CHANNEL_INPUT { ProxyId: id }),
376                (),
377            )
378            .await
379        }
380    }
381
382    pub async fn create_gpadl(
383        &self,
384        id: u64,
385        gpadl_id: u32,
386        range_count: u32,
387        range_buf: &[u8],
388    ) -> Result<()> {
389        let mut buf = Vec::new();
390        Self::add_gpadl(&mut buf, id, gpadl_id, range_count, range_buf);
391        unsafe {
392            self.ioctl(proxyioctl::IOCTL_VMBUS_PROXY_CREATE_GPADL, buf, ())
393                .await
394        }
395    }
396
397    pub async fn delete_gpadl(&self, id: u64, gpadl_id: u32) -> Result<()> {
398        unsafe {
399            self.ioctl(
400                proxyioctl::IOCTL_VMBUS_PROXY_DELETE_GPADL,
401                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_DELETE_GPADL_INPUT {
402                    ProxyId: id,
403                    GpadlId: gpadl_id,
404                    Padding: 0,
405                }),
406                (),
407            )
408            .await
409        }
410    }
411
412    pub async fn tl_connect_request(&self, request: &HvsockConnectRequest, vtl: u8) -> Result<()> {
413        unsafe {
414            self.ioctl(
415                proxyioctl::IOCTL_VMBUS_PROXY_TL_CONNECT_REQUEST,
416                StaticIoctlBuffer(proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_INPUT {
417                    EndpoindId: request.endpoint_id.into(),
418                    ServiceId: request.service_id.into(),
419                    SiloId: request.silo_id.into(),
420                    Flags: proxyioctl::VMBUS_PROXY_TL_CONNECT_REQUEST_FLAGS::new()
421                        .with_hosted_silo_unaware(request.hosted_silo_unaware),
422                    Vtl: vtl,
423                    Padding: [0; 3],
424                }),
425                (),
426            )
427            .await
428        }
429    }
430
431    pub fn run_channel(&self, id: u64) -> Result<()> {
432        unsafe {
433            // This is a synchronous operation, so don't use the async IO infrastructure.
434            let input = proxyioctl::VMBUS_PROXY_RUN_CHANNEL_INPUT { ProxyId: id };
435            let mut bytes = 0;
436            DeviceIoControl(
437                HANDLE(self.file.get().as_raw_handle()),
438                proxyioctl::IOCTL_VMBUS_PROXY_RUN_CHANNEL,
439                Some(std::ptr::from_ref(&input).cast()),
440                size_of_val(&input) as u32,
441                None,
442                0,
443                Some(&mut bytes),
444                None,
445            )?;
446        };
447        Ok(())
448    }
449
450    /// Adds GPADL ioctl data to a buffer.
451    fn add_gpadl(
452        buffer: &mut Vec<u8>,
453        id: u64,
454        gpadl_id: u32,
455        range_count: u32,
456        range_buffer: &[u8],
457    ) {
458        let header = proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT {
459            ProxyId: id,
460            GpadlId: gpadl_id,
461            RangeCount: range_count,
462            RangeBufferOffset: size_of::<proxyioctl::VMBUS_PROXY_CREATE_GPADL_INPUT>() as u32,
463            RangeBufferSize: range_buffer.len() as u32,
464        };
465        buffer.extend_from_slice(header.as_bytes());
466        buffer.extend_from_slice(range_buffer);
467    }
468}
469
/// Represents data to be restored for a GPADL.
///
/// Consumed by [`VmbusProxy::restore`], which sends `range_buffer` to the
/// driver as raw bytes alongside `gpadl_id` and `range_count`.
///
/// The type only holds two integers and a borrowed slice, so it is cheap to
/// copy and can be printed for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct Gpadl<'a> {
    /// The ID of the GPADL being restored.
    pub gpadl_id: u32,
    /// The range count to report to the driver for this GPADL; it is passed
    /// through as-is and not checked against `range_buffer`.
    pub range_count: u32,
    /// The GPADL's range data.
    pub range_buffer: &'a [u64],
}