vmsocket/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Support for `AF_HYPERV` (on Windows) and `AF_VSOCK` (on Linux) socket families.
//!
//! This crate abstracts over the differences between these and provides unified
//! [`VmStream`] and [`VmListener`] types.

#![cfg(any(windows, target_os = "linux"))]

cfg_if::cfg_if! {
    if #[cfg(windows)] {
        mod af_hyperv;
        use af_hyperv as sys;
    } else if #[cfg(unix)] {
        mod af_vsock;
        use af_vsock as sys;
    }
}

use socket2::SockAddr;
use socket2::Socket;
use std::io;
use std::io::Read;
use std::io::Write;

/// A VM socket address: an `AF_HYPERV` address on Windows or an `AF_VSOCK`
/// address on Linux.
#[derive(Debug)]
pub struct VmAddress(
    /// The platform-specific address representation from the `sys` module.
    sys::Address,
);

impl std::fmt::Display for VmAddress {
    /// Formats the address as `<vm_id>:<service_id>` on Windows and as
    /// `<cid>:<port>` on Linux.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        #[cfg(windows)]
        let (first, second) = (self.vm_id(), self.service_id());
        #[cfg(unix)]
        let (first, second) = (self.cid(), self.port());
        write!(f, "{}:{}", first, second)
    }
}

impl VmAddress {
    /// Creates a new AF_VSOCK address from `cid` and `port`.
    #[cfg(unix)]
    pub fn vsock(cid: u32, port: u32) -> Self {
        Self(sys::Address::new(cid, port))
    }

    /// Creates a new AF_HYPERV address from `vm_id` and `service_id`.
    #[cfg(windows)]
    pub fn hyperv(vm_id: guid::Guid, service_id: guid::Guid) -> Self {
        Self(sys::Address::new(vm_id, service_id))
    }

    /// Creates a new AF_HYPERV address from `vm_id` and VSOCK `port`.
    #[cfg(windows)]
    pub fn hyperv_vsock(vm_id: guid::Guid, port: u32) -> Self {
        Self(sys::Address::vsock(vm_id, port))
    }

    /// Creates a new AF_HYPERV address referring to any VM and the specified
    /// service ID.
    #[cfg(windows)]
    pub fn hyperv_any(service_id: guid::Guid) -> Self {
        Self(sys::Address::hyperv_any(service_id))
    }

    /// Creates a new AF_HYPERV address referring to the parent VM and the
    /// specified service ID.
    #[cfg(windows)]
    pub fn hyperv_host(service_id: guid::Guid) -> Self {
        Self(sys::Address::hyperv_host(service_id))
    }

    /// Creates a new address referring to any VM, with the specified VSOCK
    /// port.
    pub fn vsock_any(port: u32) -> Self {
        Self(sys::Address::vsock_any(port))
    }

    /// Creates a new address referring to the host, with the specified VSOCK
    /// port.
    pub fn vsock_host(port: u32) -> Self {
        Self(sys::Address::vsock_host(port))
    }

    /// Creates a new address from the specified [`SockAddr`] when the address
    /// is an `AF_HYPERV` (on Windows) or `AF_VSOCK` (on Linux) address.
    pub fn try_from_sock_addr(addr: &SockAddr) -> Option<Self> {
        Some(Self(sys::Address::try_from_sock_addr(addr)?))
    }

    /// Gets the VSOCK CID.
    #[cfg(unix)]
    pub fn cid(&self) -> u32 {
        self.0.cid
    }

    /// Gets the VSOCK port.
    #[cfg(unix)]
    pub fn port(&self) -> u32 {
        self.0.port
    }

    /// Gets the VM ID.
    #[cfg(windows)]
    pub fn vm_id(&self) -> guid::Guid {
        self.0.vm_id
    }

    /// Gets the service ID.
    #[cfg(windows)]
    pub fn service_id(&self) -> guid::Guid {
        self.0.service_id
    }
}

impl From<VmAddress> for SockAddr {
    fn from(address: VmAddress) -> Self {
        address.0.into_sock_addr()
    }
}

/// A VM socket that has not yet been bound, listened on, or connected.
///
/// Call [`VmSocket::listen`] or [`VmSocket::connect`] to turn it into a
/// [`VmListener`] or [`VmStream`].
// `Debug` is derived to match `VmListener` and `VmStream`, which wrap the
// same `Socket` type.
#[derive(Debug)]
pub struct VmSocket(Socket);

impl VmSocket {
    /// Creates a new stream socket not bound or connected to anything.
    pub fn new() -> io::Result<Self> {
        Self::new_inner()
    }

    /// Binds the socket to `address`.
    pub fn bind(&mut self, address: VmAddress) -> io::Result<()> {
        self.0.bind(&address.into())?;
        Ok(())
    }

    /// Listens for connections, returning a [`VmListener`].
    pub fn listen(self, backlog: i32) -> io::Result<VmListener> {
        self.0.listen(backlog)?;
        Ok(VmListener(self.0))
    }

    /// Connects to `address`, returning a [`VmStream`].
    pub fn connect(self, address: VmAddress) -> io::Result<VmStream> {
        self.0.connect(&address.into())?;
        Ok(VmStream(self.0))
    }
}

impl From<Socket> for VmSocket {
    fn from(s: Socket) -> Self {
        Self(s)
    }
}

impl From<VmSocket> for Socket {
    fn from(s: VmSocket) -> Self {
        s.0
    }
}

/// A listening VM socket that accepts incoming [`VmStream`] connections.
#[derive(Debug)]
pub struct VmListener(
    /// The underlying socket, already in the listening state.
    Socket,
);

impl VmListener {
    /// Creates a new socket bound to the specified address.
    pub fn bind(address: VmAddress) -> io::Result<Self> {
        let mut s = VmSocket::new()?;
        s.bind(address)?;
        s.listen(4)
    }

    /// Accepts the next connection.
    pub fn accept(&self) -> io::Result<(VmStream, VmAddress)> {
        let (s, addr) = self.0.accept()?;
        Ok((VmStream(s), VmAddress::try_from_sock_addr(&addr).unwrap()))
    }

    /// Retrieves the address that the listener is bound to.
    pub fn local_addr(&self) -> io::Result<VmAddress> {
        Ok(VmAddress::try_from_sock_addr(&self.0.local_addr()?).unwrap())
    }
}

impl pal_async::socket::Listener for VmListener {
    type Socket = VmStream;
    type Address = VmAddress;

    fn accept(&self) -> io::Result<(Self::Socket, Self::Address)> {
        self.accept()
    }

    fn local_addr(&self) -> io::Result<Self::Address> {
        self.local_addr()
    }
}

impl From<Socket> for VmListener {
    fn from(s: Socket) -> Self {
        Self(s)
    }
}

impl From<VmListener> for Socket {
    fn from(s: VmListener) -> Self {
        s.0
    }
}

/// A connected VM stream socket that can be read from and written to.
#[derive(Debug)]
pub struct VmStream(
    /// The underlying connected socket.
    Socket,
);

impl VmStream {
    /// Creates a new socket and connects it to `addr`.
    pub fn connect(addr: VmAddress) -> io::Result<Self> {
        let socket = VmSocket::new()?;
        socket.connect(addr)
    }

    /// Creates a new handle that refers to the same underlying socket.
    pub fn try_clone(&self) -> io::Result<Self> {
        self.0.try_clone().map(Self)
    }
}

impl From<Socket> for VmStream {
    fn from(s: Socket) -> Self {
        Self(s)
    }
}

impl From<VmStream> for Socket {
    fn from(s: VmStream) -> Self {
        s.0
    }
}

impl Read for VmStream {
    /// Reads from the underlying socket.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.0.read(buf)
    }

    /// Forwards vectored reads to the underlying socket. Without this, the
    /// `Read::read_vectored` default would fill only the first non-empty
    /// buffer.
    fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
        self.0.read_vectored(bufs)
    }
}

impl Write for VmStream {
    /// Writes to the underlying socket.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.0.write(buf)
    }

    /// Forwards vectored writes to the underlying socket. Without this, the
    /// `Write::write_vectored` default would write only the first non-empty
    /// buffer per call.
    fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
        self.0.write_vectored(bufs)
    }

    /// Flushes the underlying socket.
    fn flush(&mut self) -> io::Result<()> {
        self.0.flush()
    }
}

impl Read for &'_ VmStream {
    /// Reads through a shared reference to the underlying socket.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        (&self.0).read(buf)
    }

    /// Forwards vectored reads through a shared reference. Without this, the
    /// `Read::read_vectored` default would fill only the first non-empty
    /// buffer.
    fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
        (&self.0).read_vectored(bufs)
    }
}

impl Write for &'_ VmStream {
    /// Writes through a shared reference to the underlying socket.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        (&self.0).write(buf)
    }

    /// Forwards vectored writes through a shared reference. Without this,
    /// the `Write::write_vectored` default would write only the first
    /// non-empty buffer per call.
    fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
        (&self.0).write_vectored(bufs)
    }

    /// Flushes the underlying socket through a shared reference.
    fn flush(&mut self) -> io::Result<()> {
        (&self.0).flush()
    }
}