vnc_worker/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A worker for running a VNC server.
5
6#![forbid(unsafe_code)]
7
8use anyhow::Context;
9use anyhow::anyhow;
10use futures::FutureExt;
11use input_core::InputData;
12use input_core::KeyboardData;
13use input_core::MouseData;
14use mesh::message::MeshField;
15use mesh_worker::Worker;
16use mesh_worker::WorkerId;
17use mesh_worker::WorkerRpc;
18use pal_async::local::LocalDriver;
19use pal_async::local::block_with_io;
20use pal_async::socket::Listener;
21use pal_async::socket::PolledSocket;
22use pal_async::timer::PolledTimer;
23use std::future::Future;
24use std::net::TcpListener;
25use std::pin::Pin;
26use std::time::Duration;
27use tracing_helpers::AnyhowValueExt;
28use vnc_worker_defs::VncParameters;
29
30/// A worker for running a VNC server.
31pub struct VncWorker<T: Listener> {
32    listener: T,
33    state: State<T>,
34}
35
36/// The current server state.
37enum State<T: Listener> {
38    Listening {
39        view: ViewWrapper,
40        input: VncInput,
41    },
42    Connected {
43        remote_addr: T::Address,
44        task: Pin<Box<dyn Future<Output = (ViewWrapper, VncInput)>>>,
45        abort: mesh::OneshotSender<()>,
46    },
47    Invalid,
48}
49
50impl Worker for VncWorker<TcpListener> {
51    type Parameters = VncParameters<TcpListener>;
52    type State = VncParameters<TcpListener>;
53    const ID: WorkerId<Self::Parameters> = vnc_worker_defs::VNC_WORKER_TCP;
54
55    fn new(params: Self::Parameters) -> anyhow::Result<Self> {
56        Self::new_inner(params)
57    }
58
59    fn restart(state: Self::State) -> anyhow::Result<Self> {
60        Self::new(state)
61    }
62
63    fn run(self, rpc_recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
64        self.run_inner(rpc_recv)
65    }
66}
67
68#[cfg(any(windows, target_os = "linux"))]
69impl Worker for VncWorker<vmsocket::VmListener> {
70    type Parameters = VncParameters<vmsocket::VmListener>;
71    type State = VncParameters<vmsocket::VmListener>;
72    const ID: WorkerId<Self::Parameters> = vnc_worker_defs::VNC_WORKER_VMSOCKET;
73
74    fn new(params: Self::Parameters) -> anyhow::Result<Self> {
75        Self::new_inner(params)
76    }
77
78    fn restart(state: Self::State) -> anyhow::Result<Self> {
79        Self::new(state)
80    }
81
82    fn run(self, rpc_recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
83        self.run_inner(rpc_recv)
84    }
85}
86
87impl<T: 'static + Listener + MeshField + Send> VncWorker<T> {
88    fn new_inner(params: VncParameters<T>) -> anyhow::Result<Self> {
89        Ok(Self {
90            listener: params.listener,
91            state: State::Listening {
92                view: ViewWrapper(
93                    params
94                        .framebuffer
95                        .view()
96                        .context("failed to map framebuffer")?,
97                ),
98                input: VncInput {
99                    send: params.input_send,
100                },
101            },
102        })
103    }
104
105    fn run_inner(
106        self,
107        mut rpc_recv: mesh::Receiver<WorkerRpc<VncParameters<T>>>,
108    ) -> anyhow::Result<()> {
109        block_with_io(async |driver| {
110            tracing::info!(
111                address = ?self.listener.local_addr().unwrap(),
112                "VNC server listening",
113            );
114
115            let listener = PolledSocket::new(&driver, self.listener)?;
116            let mut server = Server {
117                listener,
118                state: self.state,
119            };
120
121            let rpc = loop {
122                let r = futures::select! { // merge semantics
123                    r = rpc_recv.recv().fuse() => r,
124                    r = server.process(&driver).fuse() => break r.map(|_| None)?,
125                };
126                match r {
127                    Ok(message) => match message {
128                        WorkerRpc::Stop => break None,
129                        WorkerRpc::Inspect(deferred) => deferred.inspect(&server),
130                        WorkerRpc::Restart(response) => break Some(response),
131                    },
132                    Err(_) => break None,
133                }
134            };
135            if let Some(rpc) = rpc {
136                let (view, input) = match server.state {
137                    State::Listening { view, input } => (view, input),
138                    State::Connected { task, abort, .. } => {
139                        drop(abort);
140                        task.await
141                    }
142                    State::Invalid => unreachable!(),
143                };
144                let state = VncParameters {
145                    listener: server.listener.into_inner(),
146                    framebuffer: view.0.access(),
147                    input_send: input.send,
148                };
149                rpc.complete(Ok(state));
150            }
151            Ok(())
152        })
153    }
154}
155
156struct Server<T: Listener> {
157    listener: PolledSocket<T>,
158    state: State<T>,
159}
160
161impl<T: Listener> Server<T> {
162    /// Runs the state machine forward, either advancing the current connection
163    /// task or waiting for a new connection.
164    ///
165    /// This function's future can be dropped safely at any time without losing
166    /// any data or connections.
167    async fn process(&mut self, driver: &LocalDriver) -> anyhow::Result<()> {
168        loop {
169            match &mut self.state {
170                State::Listening { .. } => {
171                    // Accept the connection if one is really ready.
172                    let (socket, remote_addr) = self.listener.accept().await?;
173                    let socket = PolledSocket::new(driver, socket.into())?;
174
175                    tracing::info!(address = ?remote_addr, "VNC client connected");
176
177                    let (view, input) = if let State::Listening { view, input } =
178                        std::mem::replace(&mut self.state, State::Invalid)
179                    {
180                        (view, input)
181                    } else {
182                        unreachable!()
183                    };
184
185                    let mut vncserver = vnc::Server::new("HvLite VM".into(), socket, view, input);
186                    let mut timer = PolledTimer::new(driver);
187
188                    let (abort_send, abort_recv) = mesh::oneshot();
189                    let connection = Box::pin(async move {
190                        let updater = vncserver.updater();
191                        let update_task = async {
192                            // For now, just mark the framebuffer as updated
193                            // every 30ms (about 30 frames per second).
194                            loop {
195                                timer.sleep(Duration::from_millis(30)).await;
196                                updater.update();
197                            }
198                        };
199                        let r = futures::select! { // race semantics
200                            r = vncserver.run().fuse() => r.context("VNC error"),
201                            _ = abort_recv.fuse() => Err(anyhow!("VNC connection aborted")),
202                            _ = update_task.fuse() => unreachable!(),
203                        };
204                        match r {
205                            Ok(_) => {
206                                tracing::info!("VNC client disconnected");
207                            }
208                            Err(err) => {
209                                tracing::error!(error = err.as_error(), "VNC client error");
210                            }
211                        }
212                        vncserver.done()
213                    });
214                    self.state = State::Connected {
215                        remote_addr,
216                        task: connection,
217                        abort: abort_send,
218                    };
219                }
220                State::Connected { task, .. } => {
221                    let (view, input) = task.await;
222                    self.state = State::Listening { view, input };
223                }
224                State::Invalid => unreachable!(),
225            }
226        }
227    }
228}
229
230impl<T: Listener> inspect::Inspect for Server<T> {
231    fn inspect(&self, req: inspect::Request<'_>) {
232        let mut resp = req.respond();
233        resp.display_debug("local_addr", &self.listener.get().local_addr().unwrap());
234        let state = match &self.state {
235            State::Listening { .. } => "listening",
236            State::Connected { remote_addr, .. } => {
237                resp.display_debug("remote_addr", &remote_addr);
238                "connected"
239            }
240            State::Invalid => unreachable!(),
241        };
242        resp.field("state", state);
243    }
244}
245
246struct VncInput {
247    send: mesh::Sender<InputData>,
248}
249
250impl vnc::Input for VncInput {
251    fn key(&mut self, scancode: u16, is_down: bool) {
252        // TODO: need some kind of backpressure
253        self.send.send(InputData::Keyboard(KeyboardData {
254            code: scancode,
255            make: is_down,
256        }));
257    }
258
259    fn mouse(&mut self, button_mask: u8, x: u16, y: u16) {
260        self.send
261            .send(InputData::Mouse(MouseData { button_mask, x, y }));
262    }
263}
264
265struct ViewWrapper(framebuffer::View);
266
267impl vnc::Framebuffer for ViewWrapper {
268    fn read_line(&mut self, line: u16, data: &mut [u8]) {
269        self.0.read_line(line, data)
270    }
271
272    fn resolution(&mut self) -> (u16, u16) {
273        self.0.resolution()
274    }
275}