1use futures::AsyncRead;
8use futures::AsyncWrite;
9use inspect::InspectMut;
10use mesh::MeshPayload;
11use pal_async::driver::Driver;
12use pal_async::interest::PollEvents;
13use pal_async::socket::PollReady;
14use pal_async::socket::PolledSocket;
15use serial_core::SerialIo;
16use serial_core::resources::ResolveSerialBackendParams;
17use serial_core::resources::ResolvedSerialBackend;
18use socket2::Socket;
19use std::io;
20use std::net::TcpListener;
21use std::net::TcpStream;
22use std::pin::Pin;
23use std::task::Context;
24use std::task::Poll;
25use std::task::ready;
26use unix_socket::UnixListener;
27use unix_socket::UnixStream;
28use vm_resource::ResolveResource;
29use vm_resource::Resource;
30use vm_resource::ResourceId;
31use vm_resource::declare_static_resolver;
32use vm_resource::kind::SerialBackendHandle;
33
34#[derive(Debug, MeshPayload)]
35pub struct OpenSocketSerialConfig {
36 pub current: Option<Socket>,
37 pub listener: Option<Socket>,
38}
39
40impl ResourceId<SerialBackendHandle> for OpenSocketSerialConfig {
41 const ID: &'static str = "socket";
42}
43
44pub struct SocketSerialResolver;
45declare_static_resolver!(
46 SocketSerialResolver,
47 (SerialBackendHandle, OpenSocketSerialConfig)
48);
49
50impl ResolveResource<SerialBackendHandle, OpenSocketSerialConfig> for SocketSerialResolver {
51 type Output = ResolvedSerialBackend;
52 type Error = io::Error;
53
54 fn resolve(
55 &self,
56 rsrc: OpenSocketSerialConfig,
57 input: ResolveSerialBackendParams<'_>,
58 ) -> Result<Self::Output, Self::Error> {
59 Ok(SocketSerialBackend::new(input.driver, rsrc)?.into())
60 }
61}
62
63impl From<UnixStream> for OpenSocketSerialConfig {
64 fn from(stream: UnixStream) -> Self {
65 Self {
66 current: Some(stream.into()),
67 listener: None,
68 }
69 }
70}
71
72impl From<UnixListener> for OpenSocketSerialConfig {
73 fn from(listener: UnixListener) -> Self {
74 Self {
75 current: None,
76 listener: Some(listener.into()),
77 }
78 }
79}
80
81impl From<TcpStream> for OpenSocketSerialConfig {
82 fn from(stream: TcpStream) -> Self {
83 Self {
84 current: Some(stream.into()),
85 listener: None,
86 }
87 }
88}
89
90impl From<TcpListener> for OpenSocketSerialConfig {
91 fn from(listener: TcpListener) -> Self {
92 Self {
93 current: None,
94 listener: Some(listener.into()),
95 }
96 }
97}
98
99pub struct SocketSerialBackend {
100 driver: Box<dyn Driver>,
101 current: Option<PolledSocket<Socket>>,
102 listener: Option<PolledSocket<Socket>>,
103}
104
105impl InspectMut for SocketSerialBackend {
106 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
107 req.respond().field_with("state", || {
108 if self.current.is_some() {
109 "connected"
110 } else if self.listener.is_some() {
111 "listening"
112 } else {
113 "done"
114 }
115 });
116 }
117}
118
119impl SocketSerialBackend {
120 pub fn new(driver: Box<dyn Driver>, config: OpenSocketSerialConfig) -> io::Result<Self> {
121 let current = config
122 .current
123 .map(|s| PolledSocket::new(&driver, s))
124 .transpose()?;
125 let listener = config
126 .listener
127 .map(|s| PolledSocket::new(&driver, s))
128 .transpose()?;
129 Ok(Self {
130 driver: Box::new(driver),
131 current,
132 listener,
133 })
134 }
135
136 pub fn into_config(self) -> OpenSocketSerialConfig {
137 OpenSocketSerialConfig {
138 current: self.current.map(PolledSocket::into_inner),
139 listener: self.listener.map(PolledSocket::into_inner),
140 }
141 }
142}
143
144impl From<SocketSerialBackend> for Resource<SerialBackendHandle> {
145 fn from(value: SocketSerialBackend) -> Self {
146 Resource::new(value.into_config())
147 }
148}
149
150impl SerialIo for SocketSerialBackend {
151 fn is_connected(&self) -> bool {
152 self.current.is_some()
153 }
154
155 fn poll_connect(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
156 if self.current.is_some() {
157 Poll::Ready(Ok(()))
158 } else if let Some(listener) = &mut self.listener {
159 let (socket, _) = ready!(listener.poll_accept(cx))?;
160 self.current = Some(PolledSocket::new(&self.driver, socket)?);
161 Poll::Ready(Ok(()))
162 } else {
163 Poll::Pending
165 }
166 }
167
168 fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
169 if let Some(current) = &mut self.current {
170 ready!(current.poll_ready(cx, PollEvents::RDHUP));
171 self.current = None;
172 }
173 Poll::Ready(Ok(()))
174 }
175}
176
177impl AsyncRead for SocketSerialBackend {
178 fn poll_read(
179 mut self: Pin<&mut Self>,
180 cx: &mut Context<'_>,
181 buf: &mut [u8],
182 ) -> Poll<io::Result<usize>> {
183 let Some(current) = &mut self.current else {
184 return Poll::Ready(Ok(0));
185 };
186 let r = ready!(Pin::new(current).poll_read(cx, buf));
187 if matches!(r, Ok(0)) {
188 self.current = None;
189 }
190 Poll::Ready(r)
191 }
192}
193
194impl AsyncWrite for SocketSerialBackend {
195 fn poll_write(
196 mut self: Pin<&mut Self>,
197 cx: &mut Context<'_>,
198 buf: &[u8],
199 ) -> Poll<io::Result<usize>> {
200 let Some(current) = &mut self.current else {
201 return Poll::Ready(Ok(buf.len()));
202 };
203 let r = ready!(Pin::new(current).poll_write(cx, buf));
204 if matches!(&r, Err(err) if err.kind() == io::ErrorKind::BrokenPipe) {
205 return Poll::Ready(Ok(buf.len()));
206 }
207 Poll::Ready(r)
208 }
209
210 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
211 let Some(current) = &mut self.current else {
212 return Poll::Ready(Ok(()));
213 };
214 let r = ready!(Pin::new(current).poll_flush(cx));
215 if matches!(&r, Err(err) if err.kind() == io::ErrorKind::BrokenPipe) {
216 return Poll::Ready(Ok(()));
217 }
218 Poll::Ready(r)
219 }
220
221 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
222 let Some(current) = &mut self.current else {
223 return Poll::Ready(Ok(()));
224 };
225 let r = ready!(Pin::new(current).poll_close(cx));
226 if matches!(&r, Err(err) if err.kind() == io::ErrorKind::BrokenPipe) {
227 return Poll::Ready(Ok(()));
228 }
229 Poll::Ready(r)
230 }
231}