vmsocket/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Support for `AF_HYPERV` (on Windows) and `AF_VSOCK` (on Linux) socket families.
5//!
6//! This crate abstracts over the differences between these and provides unified
7//! [`VmStream`] and [`VmListener`] types.
8
9#![cfg(any(windows, target_os = "linux"))]
10
// Select the platform backend at compile time and alias it as `sys`, so the
// rest of this file can refer to `sys::Address` (and `VmSocket::new_inner`)
// without per-platform paths. The crate-level `#![cfg(...)]` only enables this
// crate's contents on Windows or Linux, so the `unix` branch below is in
// practice the Linux `AF_VSOCK` backend.
cfg_if::cfg_if! {
    if #[cfg(windows)] {
        mod af_hyperv;
        use af_hyperv as sys;
    } else if #[cfg(unix)] {
        mod af_vsock;
        use af_vsock as sys;
    }
}
20
21use socket2::SockAddr;
22use socket2::Socket;
23use std::io;
24use std::io::Read;
25use std::io::Write;
26
27/// A VM socket address.
28#[derive(Debug)]
29pub struct VmAddress(sys::Address);
30
31impl std::fmt::Display for VmAddress {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        #[cfg(windows)]
34        write!(f, "{}:{}", self.vm_id(), self.service_id())?;
35        #[cfg(unix)]
36        write!(f, "{}:{}", self.cid(), self.port())?;
37        Ok(())
38    }
39}
40
41impl VmAddress {
42    /// Creates a new AF_VSOCK address from `cid` and `port`.
43    #[cfg(unix)]
44    pub fn vsock(cid: u32, port: u32) -> Self {
45        Self(sys::Address::new(cid, port))
46    }
47
48    /// Creates a new AF_HYPERV address from `vm_id` and `service_id`.
49    #[cfg(windows)]
50    pub fn hyperv(vm_id: guid::Guid, service_id: guid::Guid) -> Self {
51        Self(sys::Address::new(vm_id, service_id))
52    }
53
54    /// Creates a new AF_HYPERV address from `vm_id` and VSOCK `port`.
55    #[cfg(windows)]
56    pub fn hyperv_vsock(vm_id: guid::Guid, port: u32) -> Self {
57        Self(sys::Address::vsock(vm_id, port))
58    }
59
60    /// Creates a new AF_HYPERV address referring to any VM and the specified
61    /// service ID.
62    #[cfg(windows)]
63    pub fn hyperv_any(service_id: guid::Guid) -> Self {
64        Self(sys::Address::hyperv_any(service_id))
65    }
66
67    /// Creates a new AF_HYPERV address referring to the parent VM and the
68    /// specified service ID.
69    #[cfg(windows)]
70    pub fn hyperv_host(service_id: guid::Guid) -> Self {
71        Self(sys::Address::hyperv_host(service_id))
72    }
73
74    /// Creates a new address referring to any VM, with the specified VSOCK
75    /// port.
76    pub fn vsock_any(port: u32) -> Self {
77        Self(sys::Address::vsock_any(port))
78    }
79
80    /// Creates a new address referring to the host, with the specified VSOCK
81    /// port.
82    pub fn vsock_host(port: u32) -> Self {
83        Self(sys::Address::vsock_host(port))
84    }
85
86    /// Creates a new address from the specified [`SockAddr`] when the address
87    /// is an `AF_HYPERV` (on Windows) or `AF_VSOCK` (on Linux) address.
88    pub fn try_from_sock_addr(addr: &SockAddr) -> Option<Self> {
89        Some(Self(sys::Address::try_from_sock_addr(addr)?))
90    }
91
92    /// Gets the VSOCK CID.
93    #[cfg(unix)]
94    pub fn cid(&self) -> u32 {
95        self.0.cid
96    }
97
98    /// Gets the VSOCK port.
99    #[cfg(unix)]
100    pub fn port(&self) -> u32 {
101        self.0.port
102    }
103
104    /// Gets the VM ID.
105    #[cfg(windows)]
106    pub fn vm_id(&self) -> guid::Guid {
107        self.0.vm_id
108    }
109
110    /// Gets the service ID.
111    #[cfg(windows)]
112    pub fn service_id(&self) -> guid::Guid {
113        self.0.service_id
114    }
115}
116
117impl From<VmAddress> for SockAddr {
118    fn from(address: VmAddress) -> Self {
119        address.0.into_sock_addr()
120    }
121}
122
123/// A VM socket that has not yet been bound.
124pub struct VmSocket(Socket);
125
126impl VmSocket {
127    /// Creates a new stream socket not bound or connected to anything.
128    pub fn new() -> io::Result<Self> {
129        Self::new_inner()
130    }
131
132    /// Binds the socket to `address`.
133    pub fn bind(&mut self, address: VmAddress) -> io::Result<()> {
134        self.0.bind(&address.into())?;
135        Ok(())
136    }
137
138    /// Listens for connections, returning a [`VmListener`].
139    pub fn listen(self, backlog: i32) -> io::Result<VmListener> {
140        self.0.listen(backlog)?;
141        Ok(VmListener(self.0))
142    }
143
144    /// Connects to `address`, returning a [`VmStream`].
145    pub fn connect(self, address: VmAddress) -> io::Result<VmStream> {
146        self.0.connect(&address.into())?;
147        Ok(VmStream(self.0))
148    }
149}
150
151impl From<Socket> for VmSocket {
152    fn from(s: Socket) -> Self {
153        Self(s)
154    }
155}
156
157impl From<VmSocket> for Socket {
158    fn from(s: VmSocket) -> Self {
159        s.0
160    }
161}
162
163/// A VM socket listener.
164#[derive(Debug)]
165pub struct VmListener(Socket);
166
167impl VmListener {
168    /// Creates a new socket bound to the specified address.
169    pub fn bind(address: VmAddress) -> io::Result<Self> {
170        let mut s = VmSocket::new()?;
171        s.bind(address)?;
172        s.listen(4)
173    }
174
175    /// Accepts the next connection.
176    pub fn accept(&self) -> io::Result<(VmStream, VmAddress)> {
177        let (s, addr) = self.0.accept()?;
178        Ok((VmStream(s), VmAddress::try_from_sock_addr(&addr).unwrap()))
179    }
180
181    /// Retrieves the address that the listener is bound to.
182    pub fn local_addr(&self) -> io::Result<VmAddress> {
183        Ok(VmAddress::try_from_sock_addr(&self.0.local_addr()?).unwrap())
184    }
185}
186
187impl pal_async::socket::Listener for VmListener {
188    type Socket = VmStream;
189    type Address = VmAddress;
190
191    fn accept(&self) -> io::Result<(Self::Socket, Self::Address)> {
192        self.accept()
193    }
194
195    fn local_addr(&self) -> io::Result<Self::Address> {
196        self.local_addr()
197    }
198}
199
200impl From<Socket> for VmListener {
201    fn from(s: Socket) -> Self {
202        Self(s)
203    }
204}
205
206impl From<VmListener> for Socket {
207    fn from(s: VmListener) -> Self {
208        s.0
209    }
210}
211
212/// A VM stream socket.
213#[derive(Debug)]
214pub struct VmStream(Socket);
215
216impl VmStream {
217    /// Connects to the specified address.
218    pub fn connect(addr: VmAddress) -> io::Result<Self> {
219        VmSocket::new()?.connect(addr)
220    }
221
222    /// Attempts to clone the underlying socket.
223    pub fn try_clone(&self) -> io::Result<Self> {
224        Ok(Self(self.0.try_clone()?))
225    }
226}
227
228impl From<Socket> for VmStream {
229    fn from(s: Socket) -> Self {
230        Self(s)
231    }
232}
233
234impl From<VmStream> for Socket {
235    fn from(s: VmStream) -> Self {
236        s.0
237    }
238}
239
240impl Read for VmStream {
241    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
242        self.0.read(buf)
243    }
244}
245
246impl Write for VmStream {
247    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
248        self.0.write(buf)
249    }
250
251    fn flush(&mut self) -> io::Result<()> {
252        self.0.flush()
253    }
254}
255
256impl Read for &'_ VmStream {
257    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
258        (&self.0).read(buf)
259    }
260}
261
262impl Write for &'_ VmStream {
263    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
264        (&self.0).write(buf)
265    }
266
267    fn flush(&mut self) -> io::Result<()> {
268        (&self.0).flush()
269    }
270}