vmm_core/
input_distributor.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Contains a state unit for distributing keyboard and mouse input to the
5//! appropriate devices.
6
7use async_trait::async_trait;
8use futures::StreamExt;
9use futures_concurrency::stream::Merge;
10use input_core::InputData;
11use input_core::KeyboardData;
12use input_core::MouseData;
13use input_core::MultiplexedInputHandle;
14use input_core::ResolvedInputSource;
15use input_core::mesh_input::MeshInputSink;
16use input_core::mesh_input::MeshInputSource;
17use input_core::mesh_input::input_pair;
18use inspect::Inspect;
19use inspect::InspectMut;
20use mesh::rpc::Rpc;
21use mesh::rpc::RpcSend;
22use state_unit::StateRequest;
23use state_unit::StateUnit;
24use thiserror::Error;
25use vm_resource::AsyncResolveResource;
26use vm_resource::ResourceResolver;
27use vm_resource::kind::KeyboardInputHandleKind;
28use vm_resource::kind::MouseInputHandleKind;
29use vmcore::save_restore::RestoreError;
30use vmcore::save_restore::SaveError;
31use vmcore::save_restore::SavedStateBlob;
32
/// Distributes keyboard and mouse input to the appropriate devices.
///
/// Input is read from a single channel and, while the state unit is running,
/// forwarded to the active keyboard or mouse sink with the highest elevation.
/// Sinks are registered through an [`InputDistributorClient`].
pub struct InputDistributor {
    /// Incoming keyboard and mouse input.
    recv: mesh::Receiver<InputData>,
    /// Receives sink registration requests sent by clients.
    client_recv: mesh::Receiver<DistributorRequest>,
    /// Client whose sender is paired with `client_recv`; handed out by
    /// `client()`.
    client: InputDistributorClient,
    /// Running flag plus the per-device forwarders.
    inner: Inner,
}
40
/// A cloneable handle for registering keyboard and mouse sinks with an
/// [`InputDistributor`].
#[derive(Clone)]
pub struct InputDistributorClient {
    /// Sends registration requests to the distributor's run loop.
    send: mesh::Sender<DistributorRequest>,
}
45
/// A request sent from an [`InputDistributorClient`] to the distributor's run
/// loop.
enum DistributorRequest {
    /// Register a keyboard sink. Replies with an error if another keyboard
    /// sink already uses the same elevation.
    AddKeyboard(Rpc<Sink<KeyboardData>, Result<(), AddSinkError>>),
    /// Register a mouse sink. Replies with an error if another mouse sink
    /// already uses the same elevation.
    AddMouse(Rpc<Sink<MouseData>, Result<(), AddSinkError>>),
}
50
51impl InputDistributor {
52    /// Returns a new distributor for the provided input channel.
53    pub fn new(input: mesh::Receiver<InputData>) -> Self {
54        let (client_send, client_recv) = mesh::channel();
55        Self {
56            inner: Inner {
57                running: false,
58                keyboard: Forwarder::new(),
59                mouse: Forwarder::new(),
60            },
61            recv: input,
62            client: InputDistributorClient { send: client_send },
63            client_recv,
64        }
65    }
66
67    pub fn client(&self) -> &InputDistributorClient {
68        &self.client
69    }
70
71    /// Returns the input channel.
72    pub fn into_inner(self) -> mesh::Receiver<InputData> {
73        self.recv
74    }
75
76    /// Runs the distributor.
77    pub async fn run(&mut self, recv: &mut mesh::Receiver<StateRequest>) {
78        enum Event {
79            State(StateRequest),
80            Request(DistributorRequest),
81            Done,
82            Input(InputData),
83        }
84
85        let mut stream = (
86            recv.map(Event::State)
87                .chain(futures::stream::iter([Event::Done])),
88            (&mut self.recv).map(Event::Input),
89            (&mut self.client_recv).map(Event::Request),
90        )
91            .merge();
92
93        while let Some(event) = stream.next().await {
94            match event {
95                Event::State(req) => {
96                    req.apply(&mut self.inner).await;
97                }
98                Event::Request(req) => match req {
99                    DistributorRequest::AddKeyboard(rpc) => {
100                        rpc.handle_sync(|sink| self.inner.keyboard.add_sink(sink))
101                    }
102                    DistributorRequest::AddMouse(rpc) => {
103                        rpc.handle_sync(|sink| self.inner.mouse.add_sink(sink))
104                    }
105                },
106                Event::Done => break,
107                Event::Input(data) => {
108                    // Drop input while the VM is paused.
109                    if !self.inner.running {
110                        continue;
111                    }
112                    match data {
113                        InputData::Keyboard(input) => {
114                            tracing::trace!(
115                                code = input.code,
116                                make = input.make,
117                                "forwarding keyboard input"
118                            );
119                            self.inner.keyboard.forward(input)
120                        }
121                        InputData::Mouse(input) => {
122                            tracing::trace!(
123                                button_mask = input.button_mask,
124                                x = input.x,
125                                y = input.y,
126                                "forwarding mouse input"
127                            );
128                            self.inner.mouse.forward(input)
129                        }
130                    }
131                }
132            }
133        }
134    }
135}
136
137impl InputDistributorClient {
138    /// Adds a keyboard with the given name.
139    ///
140    /// The device with the highest elevation that is active will receive input.
141    pub async fn add_keyboard(
142        &self,
143        name: impl Into<String>,
144        elevation: usize,
145    ) -> Result<MeshInputSource<KeyboardData>, AddSinkError> {
146        let (source, sink) = input_pair();
147        // Treat a missing distributor as success.
148        self.send
149            .call(
150                DistributorRequest::AddKeyboard,
151                Sink {
152                    name: name.into(),
153                    elevation,
154                    sink,
155                },
156            )
157            .await
158            .unwrap_or(Ok(()))?;
159
160        Ok(source)
161    }
162
163    /// Adds a mouse with the given name. Returns an input channel and a cell
164    /// that can be set to make the device active or not.
165    ///
166    /// The device with the highest elevation that is active will receive input.
167    pub async fn add_mouse(
168        &self,
169        name: impl Into<String>,
170        elevation: usize,
171    ) -> Result<MeshInputSource<MouseData>, AddSinkError> {
172        let (source, sink) = input_pair();
173        // Treat a missing distributor as success.
174        self.send
175            .call(
176                DistributorRequest::AddMouse,
177                Sink {
178                    name: name.into(),
179                    elevation,
180                    sink,
181                },
182            )
183            .await
184            .unwrap_or(Ok(()))?;
185
186        Ok(source)
187    }
188}
189
/// The distributor's state unit: the running flag plus the keyboard and mouse
/// forwarders.
#[derive(InspectMut)]
struct Inner {
    /// Whether input is forwarded. Set by `start`, cleared by `stop`.
    running: bool,
    /// Forwards keyboard input to the registered keyboard sinks.
    keyboard: Forwarder<KeyboardData>,
    /// Forwards mouse input to the registered mouse sinks.
    mouse: Forwarder<MouseData>,
}
196
197impl StateUnit for Inner {
198    async fn start(&mut self) {
199        self.running = true;
200    }
201
202    async fn stop(&mut self) {
203        self.running = false;
204    }
205
206    async fn reset(&mut self) -> anyhow::Result<()> {
207        Ok(())
208    }
209
210    async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
211        Ok(None)
212    }
213
214    async fn restore(&mut self, _buffer: SavedStateBlob) -> Result<(), RestoreError> {
215        Err(RestoreError::SavedStateNotSupported)
216    }
217}
218
/// Forwards input of type `T` to at most one of a set of registered sinks.
struct Forwarder<T> {
    /// Sorted by elevation (ascending); each elevation appears at most once.
    /// Input goes to the active sink with the highest elevation.
    sinks: Vec<Sink<T>>,
}
223
224impl<T: 'static + Send> Inspect for Forwarder<T> {
225    fn inspect(&self, req: inspect::Request<'_>) {
226        let mut resp = req.respond();
227        for sink in &self.sinks {
228            resp.field(&sink.elevation.to_string(), sink);
229        }
230    }
231}
232
/// An input sink registered with a [`Forwarder`].
struct Sink<T> {
    /// Priority among sinks of the same kind; the highest active one wins.
    elevation: usize,
    /// Name used in conflict errors and in inspection output.
    name: String,
    /// Destination for forwarded input.
    sink: MeshInputSink<T>,
}
238
239impl<T: 'static + Send> Inspect for Sink<T> {
240    fn inspect(&self, req: inspect::Request<'_>) {
241        req.respond()
242            .field("name", &self.name)
243            .field("active", self.sink.is_active());
244    }
245}
246
/// Error returned when an input sink is registered at an elevation already
/// used by another sink of the same kind (keyboard or mouse).
#[derive(Debug, Error)]
#[error("new input sink '{name}' at elevation {elevation} conflicts with '{other}'")]
pub struct AddSinkError {
    /// Name of the sink that was being added.
    name: String,
    /// Elevation requested by the new sink.
    elevation: usize,
    /// Name of the already-registered sink at that elevation.
    other: String,
}
254
255impl<T: 'static + Send> Forwarder<T> {
256    fn new() -> Self {
257        Self { sinks: Vec::new() }
258    }
259
260    fn add_sink(&mut self, sink: Sink<T>) -> Result<(), AddSinkError> {
261        // Insert the sink to keep the list ordered.
262        let i = match self
263            .sinks
264            .binary_search_by(|other| other.elevation.cmp(&sink.elevation))
265        {
266            Err(i) => i,
267            Ok(i) => {
268                let other = &self.sinks[i];
269                return Err(AddSinkError {
270                    name: sink.name,
271                    elevation: sink.elevation,
272                    other: other.name.clone(),
273                });
274            }
275        };
276        self.sinks.insert(i, sink);
277        Ok(())
278    }
279
280    fn forward(&mut self, t: T) {
281        for sink in self.sinks.iter_mut().rev() {
282            if sink.sink.is_active() {
283                sink.sink.send(t);
284                break;
285            }
286        }
287    }
288}
289
290#[async_trait]
291impl AsyncResolveResource<KeyboardInputHandleKind, MultiplexedInputHandle>
292    for InputDistributorClient
293{
294    type Output = ResolvedInputSource<KeyboardData>;
295    type Error = AddSinkError;
296
297    async fn resolve(
298        &self,
299        _resolver: &ResourceResolver,
300        resource: MultiplexedInputHandle,
301        input: &str,
302    ) -> Result<Self::Output, Self::Error> {
303        Ok(self.add_keyboard(input, resource.elevation).await?.into())
304    }
305}
306
307#[async_trait]
308impl AsyncResolveResource<MouseInputHandleKind, MultiplexedInputHandle> for InputDistributorClient {
309    type Output = ResolvedInputSource<MouseData>;
310    type Error = AddSinkError;
311
312    async fn resolve(
313        &self,
314        _resolver: &ResourceResolver,
315        resource: MultiplexedInputHandle,
316        input: &str,
317    ) -> Result<Self::Output, Self::Error> {
318        Ok(self.add_mouse(input, resource.elevation).await?.into())
319    }
320}