1#![forbid(unsafe_code)]
5
6use fs_err::PathExt;
10use guid::Guid;
11use std::io::Write;
12use std::path::Path;
13use std::path::PathBuf;
14use std::str::FromStr;
15
/// Length in bytes of a hybrid vsock `CONNECT` request that names a service ID
/// in 36-character hyphenated GUID form, including the trailing newline.
///
/// A request naming a decimal `u32` port is always shorter: the largest one,
/// `"CONNECT 4294967295\n"`, is 19 bytes.
pub const HYBRID_CONNECT_REQUEST_LEN: usize =
    b"CONNECT 00000000-facb-11e6-bd58-64006a7986d3\n".len();
20
/// Template service ID for port-based vsock services.
///
/// A port `p` maps to this GUID with `data1` replaced by `p`
/// (`VsockPortOrId::port_to_id`). Conversely, a GUID whose fields other than
/// `data1` equal this template is treated as a port (`VsockPortOrId::port`).
const VSOCK_TEMPLATE: Guid = guid::guid!("00000000-facb-11e6-bd58-64006a7986d3");
23
24#[derive(Clone, Copy, Debug, PartialEq, Eq)]
27pub enum VsockPortOrId {
28 Port(u32),
30 Id(Guid),
32}
33
34impl VsockPortOrId {
35 pub fn port(&self) -> Option<u32> {
38 match self {
39 VsockPortOrId::Port(port) => Some(*port),
40 VsockPortOrId::Id(service_id) => {
41 let stripped_id = Guid {
42 data1: 0,
43 ..*service_id
44 };
45 (VSOCK_TEMPLATE == stripped_id).then_some(service_id.data1)
46 }
47 }
48 }
49
50 pub fn id(&self) -> Guid {
53 match self {
54 VsockPortOrId::Port(port) => Self::port_to_id(*port),
55 VsockPortOrId::Id(service_id) => *service_id,
56 }
57 }
58
59 pub fn port_to_id(port: u32) -> Guid {
61 Guid {
62 data1: port,
63 ..VSOCK_TEMPLATE
64 }
65 }
66
67 pub fn host_uds_path(&self, base_path: impl AsRef<Path>) -> Result<PathBuf, UdsPathError> {
73 let base_path = base_path.as_ref();
74 let mut path = base_path.as_os_str().to_owned();
75 if let Some(port) = self.port() {
76 path.push(format!("_{port}"));
79 if Path::new(&path).fs_err_try_exists()? {
80 return Ok(path.into());
81 }
82
83 path.clear();
85 path.push(base_path);
86 }
87
88 path.push(format!("_{}", self.id()));
89 if !Path::new(&path).fs_err_try_exists()? {
90 return Err(UdsPathError::NoListener(path.into()));
91 }
92
93 Ok(path.into())
94 }
95
96 pub fn parse_connect_request(buf: &[u8]) -> Result<Self, ParseError> {
99 let rest = strip_ascii_prefix_case_insensitive(buf, b"CONNECT ")
100 .ok_or(ParseError::MissingPrefix)?;
101
102 let rest = std::str::from_utf8(rest).map_err(ParseError::InvalidString)?;
103 if let Ok(port) = u32::from_str(rest) {
104 Ok(VsockPortOrId::Port(port))
105 } else if let Ok(service_id) = Guid::from_str(rest) {
106 Ok(VsockPortOrId::Id(service_id))
107 } else {
108 Err(ParseError::InvalidFormat(rest.to_string()))
109 }
110 }
111
112 pub fn get_ok_response(&self) -> String {
119 match self {
120 VsockPortOrId::Port(port) => format!("OK {}\n", port),
121 VsockPortOrId::Id(service_id) => format!("OK {}\n", service_id),
122 }
123 }
124
125 pub fn write_ok_response(&self, buf: &mut [u8]) -> usize {
132 let mut cursor = std::io::Cursor::new(buf);
133 match self {
134 VsockPortOrId::Port(port) => {
135 writeln!(cursor, "OK {}", port).expect("buffer should be large enough")
136 }
137 VsockPortOrId::Id(service_id) => {
138 writeln!(cursor, "OK {}", service_id).expect("buffer should be large enough")
139 }
140 }
141
142 cursor.position() as usize
143 }
144}
145
/// Removes `prefix` from the start of `s`, comparing ASCII letters
/// case-insensitively.
///
/// Returns the bytes that follow the prefix. Returns `None` if `s` is shorter
/// than `prefix` or does not start with it.
fn strip_ascii_prefix_case_insensitive<'a>(s: &'a [u8], prefix: &[u8]) -> Option<&'a [u8]> {
    if s.len() < prefix.len() {
        return None;
    }
    let (head, tail) = s.split_at(prefix.len());
    head.eq_ignore_ascii_case(prefix).then_some(tail)
}
153
154#[derive(Debug, thiserror::Error)]
156pub enum UdsPathError {
157 #[error("no hybrid vsock listener at {}", _0.display())]
159 NoListener(PathBuf),
160 #[error(transparent)]
162 Io(#[from] std::io::Error),
163}
164
165#[derive(Debug, thiserror::Error)]
167pub enum ParseError {
168 #[error("connect request did not fit")]
170 RequestTooLong,
171 #[error("missing CONNECT prefix")]
173 MissingPrefix,
174 #[error("invalid UTF-8 in connect request")]
176 InvalidString(#[from] std::str::Utf8Error),
177 #[error("invalid port or service ID: {0}")]
179 InvalidFormat(String),
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use guid::guid;

    /// Parses `connect` and checks that it names port 1234 and maps to the
    /// template-derived service ID.
    fn assert_parses_as_port_1234(connect: &[u8]) {
        let request = VsockPortOrId::parse_connect_request(connect).unwrap();
        assert_eq!(request, VsockPortOrId::Port(1234));
        assert_eq!(
            request.id(),
            Guid {
                data1: 1234,
                ..VSOCK_TEMPLATE
            }
        );
    }

    #[test]
    fn test_read_hybrid_vsock_connect_uppercase() {
        assert_parses_as_port_1234(b"CONNECT 1234");
    }

    #[test]
    fn test_read_hybrid_vsock_connect_lowercase() {
        assert_parses_as_port_1234(b"connect 1234");
    }

    #[test]
    fn test_read_hybrid_vsock_connect_mixed_case() {
        assert_parses_as_port_1234(b"CoNnEcT 1234");
    }

    #[test]
    fn test_read_hybrid_vsock_connect_guid() {
        let connect = b"CONNECT 00000123-facb-11e6-bd58-64006a7986d3";
        let request = VsockPortOrId::parse_connect_request(connect).unwrap();
        let expected = guid!("00000123-facb-11e6-bd58-64006a7986d3");
        assert_eq!(request, VsockPortOrId::Id(expected));
        assert_eq!(request.port(), Some(0x123));
        assert_eq!(request.id(), expected);

        // A GUID outside the vsock template has no port equivalent.
        let connect = b"CONNECT EE59B4BF-A573-48D0-9C51-BB0E72C2B139";
        let request = VsockPortOrId::parse_connect_request(connect).unwrap();
        let expected = guid!("ee59b4bf-a573-48d0-9c51-bb0e72c2b139");
        assert_eq!(request, VsockPortOrId::Id(expected));
        assert_eq!(request.port(), None);
        assert_eq!(request.id(), expected);
    }

    #[test]
    fn test_read_hybrid_vsock_connect_errors() {
        assert!(matches!(
            VsockPortOrId::parse_connect_request(b""),
            Err(ParseError::MissingPrefix)
        ));
        assert!(matches!(
            VsockPortOrId::parse_connect_request(b"CONNEC 1234"),
            Err(ParseError::MissingPrefix)
        ));
        assert!(matches!(
            VsockPortOrId::parse_connect_request(b"CONNECT \xff"),
            Err(ParseError::InvalidString(_))
        ));
        match VsockPortOrId::parse_connect_request(b"CONNECT not-a-port") {
            Err(ParseError::InvalidFormat(text)) => assert_eq!(text, "not-a-port"),
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[test]
    fn test_port_id_round_trip() {
        for port in [0, 1, 1234, u32::MAX] {
            let id = VsockPortOrId::port_to_id(port);
            assert_eq!(VsockPortOrId::Port(port).id(), id);
            assert_eq!(VsockPortOrId::Id(id).port(), Some(port));
        }
    }

    #[test]
    fn test_get_ok_response() {
        let port_request = VsockPortOrId::Port(1234);
        assert_eq!(port_request.get_ok_response(), "OK 1234\n");

        let guid = guid!("00000123-facb-11e6-bd58-64006a7986d3");
        let id_request = VsockPortOrId::Id(guid);
        assert_eq!(
            id_request.get_ok_response(),
            "OK 00000123-facb-11e6-bd58-64006a7986d3\n"
        );
    }

    #[test]
    fn test_write_ok_response_matches_get_ok_response() {
        let mut buf = [0u8; HYBRID_CONNECT_REQUEST_LEN];
        let requests = [
            VsockPortOrId::Port(1234),
            VsockPortOrId::Id(guid!("00000123-facb-11e6-bd58-64006a7986d3")),
        ];
        for request in requests {
            let len = request.write_ok_response(&mut buf);
            assert_eq!(&buf[..len], request.get_ok_response().as_bytes());
        }
    }
}