console_relay/
unix.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![cfg(unix)]
5
6//! Console relay support using Unix domain sockets.
7
8use futures::AsyncRead;
9use futures::AsyncWrite;
10use pal_async::driver::Driver;
11use pal_async::socket::PolledSocket;
12use std::path::Path;
13use std::pin::Pin;
14use std::task::Context;
15use std::task::ready;
16use unix_socket::UnixListener;
17use unix_socket::UnixStream;
18
/// A console endpoint backed by a Unix domain socket.
///
/// [`UnixSocketConsole::new`] binds a listener at a filesystem path. The first
/// read or write accepts a client connection. From then on, all I/O is relayed
/// through that one connection; the console never goes back to listening.
pub struct UnixSocketConsole {
    /// Driver used to register the accepted client stream for async polling.
    driver: Box<dyn Driver>,
    /// Whether the console is still waiting for a client or already has one.
    state: UnixSocketConsoleState,
}
23
/// Connection state of a [`UnixSocketConsole`].
///
/// The only transition is `Listening` -> `Connected`. The listener is dropped
/// when the state is overwritten with the accepted stream.
enum UnixSocketConsoleState {
    /// No client yet; holds the bound listener that will accept one.
    Listening(PolledSocket<UnixListener>),
    /// A client has been accepted; all reads and writes use this stream.
    Connected(PolledSocket<UnixStream>),
}
28
29impl UnixSocketConsole {
30    pub fn new(driver: Box<dyn Driver>, path: &Path) -> std::io::Result<Self> {
31        let listener = UnixListener::bind(path)?;
32        Ok(Self {
33            state: UnixSocketConsoleState::Listening(PolledSocket::new(&driver, listener)?),
34            driver,
35        })
36    }
37
38    fn poll_connect(
39        &mut self,
40        cx: &mut Context<'_>,
41    ) -> std::task::Poll<std::io::Result<&mut PolledSocket<UnixStream>>> {
42        match &mut self.state {
43            UnixSocketConsoleState::Listening(l) => {
44                let (c, _) = ready!(l.poll_accept(cx))?;
45                let c = PolledSocket::new(&self.driver, c)?;
46                self.state = UnixSocketConsoleState::Connected(c);
47            }
48            UnixSocketConsoleState::Connected(_) => {}
49        }
50        let UnixSocketConsoleState::Connected(c) = &mut self.state else {
51            unreachable!()
52        };
53        Ok(c).into()
54    }
55}
56
57impl AsyncRead for UnixSocketConsole {
58    fn poll_read(
59        mut self: Pin<&mut Self>,
60        cx: &mut Context<'_>,
61        buf: &mut [u8],
62    ) -> std::task::Poll<std::io::Result<usize>> {
63        let c = ready!(self.poll_connect(cx))?;
64        Pin::new(c).poll_read(cx, buf)
65    }
66}
67
68impl AsyncWrite for UnixSocketConsole {
69    fn poll_write(
70        mut self: Pin<&mut Self>,
71        cx: &mut Context<'_>,
72        buf: &[u8],
73    ) -> std::task::Poll<std::io::Result<usize>> {
74        let c = ready!(self.poll_connect(cx))?;
75        Pin::new(c).poll_write(cx, buf)
76    }
77
78    fn poll_flush(
79        mut self: Pin<&mut Self>,
80        cx: &mut Context<'_>,
81    ) -> std::task::Poll<std::io::Result<()>> {
82        match &mut self.state {
83            UnixSocketConsoleState::Listening(_) => Ok(()).into(),
84            UnixSocketConsoleState::Connected(c) => Pin::new(c).poll_flush(cx),
85        }
86    }
87
88    fn poll_close(
89        mut self: Pin<&mut Self>,
90        cx: &mut Context<'_>,
91    ) -> std::task::Poll<std::io::Result<()>> {
92        match &mut self.state {
93            UnixSocketConsoleState::Listening(_) => Ok(()).into(),
94            UnixSocketConsoleState::Connected(c) => Pin::new(c).poll_close(cx),
95        }
96    }
97}