serial_socket/
net.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Socket serial backend, usable for both TCP and Unix sockets (even on
5//! Windows).
6
7use futures::AsyncRead;
8use futures::AsyncWrite;
9use inspect::InspectMut;
10use mesh::MeshPayload;
11use pal_async::driver::Driver;
12use pal_async::interest::PollEvents;
13use pal_async::socket::PollReady;
14use pal_async::socket::PolledSocket;
15use serial_core::SerialIo;
16use serial_core::resources::ResolveSerialBackendParams;
17use serial_core::resources::ResolvedSerialBackend;
18use socket2::Socket;
19use std::io;
20use std::net::TcpListener;
21use std::net::TcpStream;
22use std::pin::Pin;
23use std::task::Context;
24use std::task::Poll;
25use std::task::ready;
26use unix_socket::UnixListener;
27use unix_socket::UnixStream;
28use vm_resource::ResolveResource;
29use vm_resource::Resource;
30use vm_resource::ResourceId;
31use vm_resource::declare_static_resolver;
32use vm_resource::kind::SerialBackendHandle;
33
/// Resource configuration for the socket serial backend.
///
/// Carries the raw sockets that [`SocketSerialBackend::new`] wraps into a live
/// backend and that [`SocketSerialBackend::into_config`] hands back. Either
/// field may be `None`.
#[derive(Debug, MeshPayload)]
pub struct OpenSocketSerialConfig {
    /// An already-connected stream socket that serial I/O uses directly.
    pub current: Option<Socket>,
    /// A listening socket from which a connection is accepted whenever there
    /// is no current connection.
    pub listener: Option<Socket>,
}
39
/// Identifies [`OpenSocketSerialConfig`] as the serial backend resource with
/// ID `"socket"`.
impl ResourceId<SerialBackendHandle> for OpenSocketSerialConfig {
    const ID: &'static str = "socket";
}
43
/// Resolver that turns an [`OpenSocketSerialConfig`] resource into a
/// [`SocketSerialBackend`].
pub struct SocketSerialResolver;
// Declares `SocketSerialResolver` as a static resolver for the
// (SerialBackendHandle, OpenSocketSerialConfig) pair below.
declare_static_resolver!(
    SocketSerialResolver,
    (SerialBackendHandle, OpenSocketSerialConfig)
);
49
50impl ResolveResource<SerialBackendHandle, OpenSocketSerialConfig> for SocketSerialResolver {
51    type Output = ResolvedSerialBackend;
52    type Error = io::Error;
53
54    fn resolve(
55        &self,
56        rsrc: OpenSocketSerialConfig,
57        input: ResolveSerialBackendParams<'_>,
58    ) -> Result<Self::Output, Self::Error> {
59        Ok(SocketSerialBackend::new(input.driver, rsrc)?.into())
60    }
61}
62
63impl From<UnixStream> for OpenSocketSerialConfig {
64    fn from(stream: UnixStream) -> Self {
65        Self {
66            current: Some(stream.into()),
67            listener: None,
68        }
69    }
70}
71
72impl From<UnixListener> for OpenSocketSerialConfig {
73    fn from(listener: UnixListener) -> Self {
74        Self {
75            current: None,
76            listener: Some(listener.into()),
77        }
78    }
79}
80
81impl From<TcpStream> for OpenSocketSerialConfig {
82    fn from(stream: TcpStream) -> Self {
83        Self {
84            current: Some(stream.into()),
85            listener: None,
86        }
87    }
88}
89
90impl From<TcpListener> for OpenSocketSerialConfig {
91    fn from(listener: TcpListener) -> Self {
92        Self {
93            current: None,
94            listener: Some(listener.into()),
95        }
96    }
97}
98
/// Serial backend that performs I/O over a connected stream socket and can
/// accept a new connection from an optional listening socket.
pub struct SocketSerialBackend {
    /// Driver passed to `PolledSocket::new` when wrapping accepted
    /// connections.
    driver: Box<dyn Driver>,
    /// The active connection, if any; all reads and writes target it.
    current: Option<PolledSocket<Socket>>,
    /// Listening socket used to accept a connection while `current` is `None`.
    listener: Option<PolledSocket<Socket>>,
}
104
105impl InspectMut for SocketSerialBackend {
106    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
107        req.respond().field_with("state", || {
108            if self.current.is_some() {
109                "connected"
110            } else if self.listener.is_some() {
111                "listening"
112            } else {
113                "done"
114            }
115        });
116    }
117}
118
119impl SocketSerialBackend {
120    pub fn new(driver: Box<dyn Driver>, config: OpenSocketSerialConfig) -> io::Result<Self> {
121        let current = config
122            .current
123            .map(|s| PolledSocket::new(&driver, s))
124            .transpose()?;
125        let listener = config
126            .listener
127            .map(|s| PolledSocket::new(&driver, s))
128            .transpose()?;
129        Ok(Self {
130            driver: Box::new(driver),
131            current,
132            listener,
133        })
134    }
135
136    pub fn into_config(self) -> OpenSocketSerialConfig {
137        OpenSocketSerialConfig {
138            current: self.current.map(PolledSocket::into_inner),
139            listener: self.listener.map(PolledSocket::into_inner),
140        }
141    }
142}
143
144impl From<SocketSerialBackend> for Resource<SerialBackendHandle> {
145    fn from(value: SocketSerialBackend) -> Self {
146        Resource::new(value.into_config())
147    }
148}
149
150impl SerialIo for SocketSerialBackend {
151    fn is_connected(&self) -> bool {
152        self.current.is_some()
153    }
154
155    fn poll_connect(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
156        if self.current.is_some() {
157            Poll::Ready(Ok(()))
158        } else if let Some(listener) = &mut self.listener {
159            let (socket, _) = ready!(listener.poll_accept(cx))?;
160            self.current = Some(PolledSocket::new(&self.driver, socket)?);
161            Poll::Ready(Ok(()))
162        } else {
163            // This will never complete.
164            Poll::Pending
165        }
166    }
167
168    fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
169        if let Some(current) = &mut self.current {
170            ready!(current.poll_ready(cx, PollEvents::RDHUP));
171            self.current = None;
172        }
173        Poll::Ready(Ok(()))
174    }
175}
176
177impl AsyncRead for SocketSerialBackend {
178    fn poll_read(
179        mut self: Pin<&mut Self>,
180        cx: &mut Context<'_>,
181        buf: &mut [u8],
182    ) -> Poll<io::Result<usize>> {
183        let Some(current) = &mut self.current else {
184            return Poll::Ready(Ok(0));
185        };
186        let r = ready!(Pin::new(current).poll_read(cx, buf));
187        if matches!(r, Ok(0)) {
188            self.current = None;
189        }
190        Poll::Ready(r)
191    }
192}
193
194impl AsyncWrite for SocketSerialBackend {
195    fn poll_write(
196        mut self: Pin<&mut Self>,
197        cx: &mut Context<'_>,
198        buf: &[u8],
199    ) -> Poll<io::Result<usize>> {
200        let Some(current) = &mut self.current else {
201            return Poll::Ready(Ok(buf.len()));
202        };
203        let r = ready!(Pin::new(current).poll_write(cx, buf));
204        if matches!(&r, Err(err) if err.kind() == io::ErrorKind::BrokenPipe) {
205            return Poll::Ready(Ok(buf.len()));
206        }
207        Poll::Ready(r)
208    }
209
210    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
211        let Some(current) = &mut self.current else {
212            return Poll::Ready(Ok(()));
213        };
214        let r = ready!(Pin::new(current).poll_flush(cx));
215        if matches!(&r, Err(err) if err.kind() == io::ErrorKind::BrokenPipe) {
216            return Poll::Ready(Ok(()));
217        }
218        Poll::Ready(r)
219    }
220
221    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
222        let Some(current) = &mut self.current else {
223            return Poll::Ready(Ok(()));
224        };
225        let r = ready!(Pin::new(current).poll_close(cx));
226        if matches!(&r, Err(err) if err.kind() == io::ErrorKind::BrokenPipe) {
227            return Poll::Ready(Ok(()));
228        }
229        Poll::Ready(r)
230    }
231}