vmsocket/
af_vsock.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
//! `AF_VSOCK` (virtual socket) support for [`VmSocket`], [`VmListener`], and [`VmStream`].
5
6// UNSAFETY: Calling libc functions on a raw socket fd.
7#![expect(unsafe_code)]
8
9use crate::VmListener;
10use crate::VmSocket;
11use crate::VmStream;
12use mesh::payload::os_resource;
13use socket2::Domain;
14use socket2::SockAddr;
15use socket2::Socket;
16use socket2::Type;
17use std::io;
18use std::os::unix::prelude::*;
19use std::time::Duration;
20
/// An `AF_VSOCK` socket address: a context ID (CID) plus a port.
///
/// This is a plain pair of integers, so it is cheap to copy and can be
/// compared or hashed (e.g. to key a table of listeners by address).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Address {
    /// Context ID selecting the endpoint (for example, the host is CID 2).
    pub(crate) cid: u32,
    /// Port number within that context.
    pub(crate) port: u32,
}
26
27impl Address {
28    pub fn new(cid: u32, port: u32) -> Self {
29        Self { cid, port }
30    }
31
32    pub fn vsock_any(port: u32) -> Self {
33        Self::new(!0, port)
34    }
35
36    pub fn vsock_host(port: u32) -> Self {
37        Self::new(2, port)
38    }
39
40    pub fn into_sock_addr(self) -> SockAddr {
41        SockAddr::vsock(self.cid, self.port)
42    }
43
44    pub fn try_from_sock_addr(addr: &SockAddr) -> Option<Self> {
45        let (cid, port) = addr.as_vsock_address()?;
46        Some(Self::new(cid, port))
47    }
48}
49
50impl VmSocket {
51    pub(crate) fn new_inner() -> io::Result<Self> {
52        Ok(Self(Socket::new(Domain::VSOCK, Type::STREAM, None)?))
53    }
54
55    /// Sets the connection timeout for this socket.
56    pub fn set_connect_timeout(&self, duration: Duration) -> io::Result<()> {
57        let timeout = libc::timeval {
58            tv_sec: duration
59                .as_secs()
60                .try_into()
61                .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?,
62            tv_usec: duration.subsec_micros().into(),
63        };
64
65        // SAFETY: Calling a VSOCK-specific option on a VSOCK socket,
66        // and passing a valid pointer to a timeval struct.
67        unsafe {
68            if libc::setsockopt(
69                self.as_fd().as_raw_fd(),
70                libc::AF_VSOCK,
71                6, // SO_VM_SOCKETS_CONNECT_TIMEOUT
72                std::ptr::from_ref(&timeout).cast(),
73                size_of_val(&timeout) as u32,
74            ) != 0
75            {
76                return Err(io::Error::last_os_error());
77            }
78        }
79        Ok(())
80    }
81}
82
83impl AsFd for VmSocket {
84    fn as_fd(&self) -> BorrowedFd<'_> {
85        self.0.as_fd()
86    }
87}
88
89impl From<VmSocket> for OwnedFd {
90    fn from(fd: VmSocket) -> Self {
91        fd.0.into()
92    }
93}
94
95impl From<OwnedFd> for VmSocket {
96    fn from(fd: OwnedFd) -> Self {
97        Self(fd.into())
98    }
99}
100
101impl AsFd for VmListener {
102    fn as_fd(&self) -> BorrowedFd<'_> {
103        self.0.as_fd()
104    }
105}
106
107impl From<VmListener> for OwnedFd {
108    fn from(fd: VmListener) -> Self {
109        fd.0.into()
110    }
111}
112
113impl From<OwnedFd> for VmListener {
114    fn from(fd: OwnedFd) -> Self {
115        Self(fd.into())
116    }
117}
118
119impl AsFd for VmStream {
120    fn as_fd(&self) -> BorrowedFd<'_> {
121        self.0.as_fd()
122    }
123}
124
125impl From<VmStream> for OwnedFd {
126    fn from(fd: VmStream) -> Self {
127        fd.0.into()
128    }
129}
130
131impl From<OwnedFd> for VmStream {
132    fn from(fd: OwnedFd) -> Self {
133        Self(fd.into())
134    }
135}
136
137os_resource!(VmSocket, OwnedFd);
138os_resource!(VmStream, OwnedFd);
139os_resource!(VmListener, OwnedFd);