Skip to main content

hybrid_vsock/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![forbid(unsafe_code)]
5
//! Provides helpers for bridging between vsock/hvsocket connections and Unix domain sockets, used
//! by both the VMBus-based hvsocket and the virtio-vsock implementations.
8
9use fs_err::PathExt;
10use guid::Guid;
11use std::io::Write;
12use std::path::Path;
13use std::path::PathBuf;
14use std::str::FromStr;
15
/// Upper bound, in bytes, on a well-formed connect request, including its trailing newline.
///
/// The longest possible request names a full service ID; a request that carries a port number
/// instead is shorter.
pub const HYBRID_CONNECT_REQUEST_LEN: usize = {
    const LONGEST_REQUEST: &str = "CONNECT 00000000-facb-11e6-bd58-64006a7986d3\n";
    LONGEST_REQUEST.len()
};
20
21/// This GUID is an embedding of the AF_VSOCK port into an AF_HYPERV service ID.
22const VSOCK_TEMPLATE: Guid = guid::guid!("00000000-facb-11e6-bd58-64006a7986d3");
23
24/// Represents the local or remote port number for a vsock connection, or the service ID or instance
25/// ID for an hvsocket connection.
26#[derive(Clone, Copy, Debug, PartialEq, Eq)]
27pub enum VsockPortOrId {
28    /// The vsock port number.
29    Port(u32),
30    /// The hvsocket service ID or instance ID, represented as a GUID.
31    Id(Guid),
32}
33
34impl VsockPortOrId {
35    /// Gets the vsock port number. This will return `Some` if the instance either directly uses a
36    /// port, or uses a service ID that matches the hvsocket vsock template.
37    pub fn port(&self) -> Option<u32> {
38        match self {
39            VsockPortOrId::Port(port) => Some(*port),
40            VsockPortOrId::Id(service_id) => {
41                let stripped_id = Guid {
42                    data1: 0,
43                    ..*service_id
44                };
45                (VSOCK_TEMPLATE == stripped_id).then_some(service_id.data1)
46            }
47        }
48    }
49
50    /// Gets the vsock service ID. If this instance is a port, it will use the hvsocket vsock
51    /// template to construct a service ID.
52    pub fn id(&self) -> Guid {
53        match self {
54            VsockPortOrId::Port(port) => Self::port_to_id(*port),
55            VsockPortOrId::Id(service_id) => *service_id,
56        }
57    }
58
59    /// Converts a vsock port number into a GUID using the hvsocket vsock template.
60    pub fn port_to_id(port: u32) -> Guid {
61        Guid {
62            data1: port,
63            ..VSOCK_TEMPLATE
64        }
65    }
66
67    /// Gets the path of a Unix domain socket listener on the host using this port or id.
68    ///
69    /// If this instance is a port, or uses a GUID that matches the hvsocket vsock template, this
70    /// function will first use a path with that port number appended. If that path doesn't exist,
71    /// or if this instance uses a non-vsock GUID, it will use a path with the full ID.
72    pub fn host_uds_path(&self, base_path: impl AsRef<Path>) -> Result<PathBuf, UdsPathError> {
73        let base_path = base_path.as_ref();
74        let mut path = base_path.as_os_str().to_owned();
75        if let Some(port) = self.port() {
76            // This is a vsock connection, so first try connecting after appending the
77            // port to the path.
78            path.push(format!("_{port}"));
79            if Path::new(&path).fs_err_try_exists()? {
80                return Ok(path.into());
81            }
82
83            // If the port didn't exist, try again with the service ID.
84            path.clear();
85            path.push(base_path);
86        }
87
88        path.push(format!("_{}", self.id()));
89        if !Path::new(&path).fs_err_try_exists()? {
90            return Err(UdsPathError::NoListener(path.into()));
91        }
92
93        Ok(path.into())
94    }
95
96    /// Parses a connection request from a buffer containing a UTF-8 string of the format "CONNECT
97    /// \<port or service ID>\n".
98    pub fn parse_connect_request(buf: &[u8]) -> Result<Self, ParseError> {
99        let rest = strip_ascii_prefix_case_insensitive(buf, b"CONNECT ")
100            .ok_or(ParseError::MissingPrefix)?;
101
102        let rest = std::str::from_utf8(rest).map_err(ParseError::InvalidString)?;
103        if let Ok(port) = u32::from_str(rest) {
104            Ok(VsockPortOrId::Port(port))
105        } else if let Ok(service_id) = Guid::from_str(rest) {
106            Ok(VsockPortOrId::Id(service_id))
107        } else {
108            Err(ParseError::InvalidFormat(rest.to_string()))
109        }
110    }
111
112    /// Gets the response string that should be sent back to the guest on a successful connection,
113    /// of the format "OK \<port or service ID>\n".
114    ///
115    /// In this case, any instance using a GUID will be formatted using the full service ID, even if
116    /// it matches the hvsocket vsock template. The format returned should always match the format
117    /// that was used in the "CONNECT" request.
118    pub fn get_ok_response(&self) -> String {
119        match self {
120            VsockPortOrId::Port(port) => format!("OK {}\n", port),
121            VsockPortOrId::Id(service_id) => format!("OK {}\n", service_id),
122        }
123    }
124
125    /// Writes the response string that should be sent back to the guest on a successful connection
126    /// into the provided buffer, and returns the number of bytes written.
127    ///
128    /// # Panics
129    ///
130    /// This function will panic if the buffer is too small to hold the response.
131    pub fn write_ok_response(&self, buf: &mut [u8]) -> usize {
132        let mut cursor = std::io::Cursor::new(buf);
133        match self {
134            VsockPortOrId::Port(port) => {
135                writeln!(cursor, "OK {}", port).expect("buffer should be large enough")
136            }
137            VsockPortOrId::Id(service_id) => {
138                writeln!(cursor, "OK {}", service_id).expect("buffer should be large enough")
139            }
140        }
141
142        cursor.position() as usize
143    }
144}
145
/// Returns the part of `s` that follows `prefix`, or `None` if `s` does not begin with `prefix`.
///
/// Only ASCII letters are compared case-insensitively; every other byte must match exactly.
fn strip_ascii_prefix_case_insensitive<'a>(s: &'a [u8], prefix: &[u8]) -> Option<&'a [u8]> {
    let split = prefix.len();
    // `get` yields `None` when `s` is shorter than the prefix.
    let head = s.get(..split)?;
    head.eq_ignore_ascii_case(prefix).then(|| &s[split..])
}
153
154/// Error returned by [`VsockPortOrId::host_uds_path`].
155#[derive(Debug, thiserror::Error)]
156pub enum UdsPathError {
157    /// No hybrid vsock listener was found at the specified path.
158    #[error("no hybrid vsock listener at {}", _0.display())]
159    NoListener(PathBuf),
160    /// An I/O error occurred while checking for the listener.
161    #[error(transparent)]
162    Io(#[from] std::io::Error),
163}
164
165/// Error returned by [`VsockPortOrId::parse_connect_request`].
166#[derive(Debug, thiserror::Error)]
167pub enum ParseError {
168    /// The connect request did not contain a newline within the maximum expected length.
169    #[error("connect request did not fit")]
170    RequestTooLong,
171    /// The connect request did not start with the expected "CONNECT " prefix.
172    #[error("missing CONNECT prefix")]
173    MissingPrefix,
174    /// The connect request contained invalid UTF-8.
175    #[error("invalid UTF-8 in connect request")]
176    InvalidString(#[from] std::str::Utf8Error),
177    /// The connect request did not contain a valid port number or service ID.
178    #[error("invalid port or service ID: {0}")]
179    InvalidFormat(String),
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use guid::guid;

    #[test]
    fn test_read_hybrid_vsock_connect_uppercase() {
        let connect = b"CONNECT 1234";
        let request = VsockPortOrId::parse_connect_request(connect).unwrap();
        assert_eq!(request, VsockPortOrId::Port(1234));
        assert_eq!(
            request.id(),
            Guid {
                data1: 1234,
                ..VSOCK_TEMPLATE
            }
        );
    }

    #[test]
    fn test_read_hybrid_vsock_connect_lowercase() {
        let connect = b"connect 1234";
        let request = VsockPortOrId::parse_connect_request(connect).unwrap();
        assert_eq!(request, VsockPortOrId::Port(1234));
        assert_eq!(
            request.id(),
            Guid {
                data1: 1234,
                ..VSOCK_TEMPLATE
            }
        );
    }

    #[test]
    fn test_read_hybrid_vsock_connect_guid() {
        let connect = b"CONNECT 00000123-facb-11e6-bd58-64006a7986d3";
        let request = VsockPortOrId::parse_connect_request(connect).unwrap();
        let expected = guid!("00000123-facb-11e6-bd58-64006a7986d3");
        assert_eq!(request, VsockPortOrId::Id(expected));
        assert_eq!(request.port(), Some(0x123));
        assert_eq!(request.id(), expected);

        let connect = b"CONNECT EE59B4BF-A573-48D0-9C51-BB0E72C2B139";
        let request = VsockPortOrId::parse_connect_request(connect).unwrap();
        let expected = guid!("ee59b4bf-a573-48d0-9c51-bb0e72c2b139");
        assert_eq!(request, VsockPortOrId::Id(expected));
        assert_eq!(request.port(), None);
        assert_eq!(request.id(), expected);
    }

    #[test]
    fn test_read_hybrid_vsock_connect_errors() {
        assert!(matches!(
            VsockPortOrId::parse_connect_request(b""),
            Err(ParseError::MissingPrefix)
        ));
        assert!(matches!(
            VsockPortOrId::parse_connect_request(b"CONNEC 1234"),
            Err(ParseError::MissingPrefix)
        ));
        assert!(matches!(
            VsockPortOrId::parse_connect_request(b"CONNECT \xff"),
            Err(ParseError::InvalidString(_))
        ));
        match VsockPortOrId::parse_connect_request(b"CONNECT not-a-port") {
            Err(ParseError::InvalidFormat(text)) => assert_eq!(text, "not-a-port"),
            other => panic!("unexpected result: {other:?}"),
        }
        // One past u32::MAX, and not a GUID either.
        assert!(matches!(
            VsockPortOrId::parse_connect_request(b"CONNECT 4294967296"),
            Err(ParseError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_port_id_round_trip() {
        for port in [0, 1, 1234, u32::MAX] {
            let id = VsockPortOrId::port_to_id(port);
            assert_eq!(VsockPortOrId::Port(port).id(), id);
            assert_eq!(VsockPortOrId::Id(id).port(), Some(port));
        }
    }

    #[test]
    fn test_get_ok_response() {
        let port_request = VsockPortOrId::Port(1234);
        assert_eq!(port_request.get_ok_response(), "OK 1234\n");

        let guid = guid!("00000123-facb-11e6-bd58-64006a7986d3");
        let id_request = VsockPortOrId::Id(guid);
        assert_eq!(
            id_request.get_ok_response(),
            "OK 00000123-facb-11e6-bd58-64006a7986d3\n"
        );
    }

    #[test]
    fn test_write_ok_response_matches_get_ok_response() {
        let guid = guid!("00000123-facb-11e6-bd58-64006a7986d3");
        for request in [VsockPortOrId::Port(u32::MAX), VsockPortOrId::Id(guid)] {
            let mut buf = [0u8; HYBRID_CONNECT_REQUEST_LEN];
            let len = request.write_ok_response(&mut buf);
            assert_eq!(&buf[..len], request.get_ok_response().as_bytes());
        }
    }

    #[test]
    #[should_panic(expected = "buffer should be large enough")]
    fn test_write_ok_response_buffer_too_small() {
        let mut buf = [0u8; 4];
        VsockPortOrId::Port(1234).write_ok_response(&mut buf);
    }

    #[test]
    fn test_host_uds_path_lookup_order() {
        let dir = std::env::temp_dir().join(format!("hybrid_vsock_test_{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        let base = dir.join("vm.sock");
        let request = VsockPortOrId::Port(1234);
        let port_path = dir.join("vm.sock_1234");
        let id_path = dir.join(format!("vm.sock_{}", request.id()));

        // With no listener present, the error reports the service-ID path.
        match request.host_uds_path(&base) {
            Err(UdsPathError::NoListener(path)) => assert_eq!(path, id_path),
            other => panic!("unexpected result: {other:?}"),
        }

        // Only the service-ID path exists: it is used as the fallback.
        std::fs::write(&id_path, b"").unwrap();
        assert_eq!(request.host_uds_path(&base).unwrap(), id_path);

        // Once the port path exists, it takes precedence.
        std::fs::write(&port_path, b"").unwrap();
        assert_eq!(request.host_uds_path(&base).unwrap(), port_path);

        std::fs::remove_dir_all(&dir).unwrap();
    }
}