uidevices/keyboard/
mod.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Vmbus synthetic keyboard device.
5
6mod protocol;
7
8use async_trait::async_trait;
9use futures::StreamExt;
10use input_core::InputSource;
11use input_core::KeyboardData;
12use mesh::payload::Protobuf;
13use std::io::IoSlice;
14use std::pin::pin;
15use task_control::StopTask;
16use thiserror::Error;
17use vmbus_async::async_dgram::AsyncRecv;
18use vmbus_async::async_dgram::AsyncRecvExt;
19use vmbus_async::async_dgram::AsyncSend;
20use vmbus_async::async_dgram::AsyncSendExt;
21use vmbus_async::pipe::MessagePipe;
22use vmbus_channel::RawAsyncChannel;
23use vmbus_channel::bus::OfferParams;
24use vmbus_channel::channel::ChannelOpenError;
25use vmbus_channel::gpadl_ring::GpadlRingMem;
26use vmbus_channel::simple::SaveRestoreSimpleVmbusDevice;
27use vmbus_channel::simple::SimpleVmbusDevice;
28use vmbus_ring::RingMem;
29use vmcore::save_restore::SavedStateRoot;
30use zerocopy::FromBytes;
31use zerocopy::Immutable;
32use zerocopy::IntoBytes;
33use zerocopy::KnownLayout;
34
/// A guest-to-host message on the keyboard channel, as decoded by `recv_packet`.
#[derive(Debug)]
enum Request {
    /// The guest asked to use the given protocol version.
    ProtocolRequest(u32),
    /// The guest sent an LED indicator update. Its payload is checked to be
    /// long enough and then discarded, since there are no LEDs to drive.
    SetLedIndicators,
}
40
41#[derive(Debug, Error)]
42enum Error {
43    #[error("channel i/o error")]
44    Io(#[source] std::io::Error),
45    #[error("received out of order packet")]
46    UnexpectedPacketOrder,
47    #[error("bad packet")]
48    BadPacket,
49    #[error("unknown message type")]
50    UnknownMessageType(u32),
51    #[error("accepting vmbus channel")]
52    Accept(#[from] vmbus_channel::offer::Error),
53}
54
55async fn recv_packet(reader: &mut impl AsyncRecv) -> Result<Request, Error> {
56    let mut buf = [0; 64];
57    let n = reader.recv(&mut buf).await.map_err(Error::Io)?;
58    let buf = &buf[..n];
59    let (header, buf) =
60        protocol::MessageHeader::read_from_prefix(buf).map_err(|_| Error::BadPacket)?; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
61    let request = match header.message_type {
62        protocol::MESSAGE_PROTOCOL_REQUEST => {
63            let message = protocol::MessageProtocolRequest::read_from_prefix(buf)
64                .map_err(|_| Error::BadPacket)?
65                .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
66            Request::ProtocolRequest(message.version)
67        }
68        protocol::MESSAGE_SET_LED_INDICATORS => {
69            // We don't have any actual LEDs to set, so check the message for validity but ignore its contents.
70            let _message = protocol::MessageLedIndicatorsState::read_from_prefix(buf)
71                .map_err(|_| Error::BadPacket)?
72                .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
73            Request::SetLedIndicators
74        }
75        typ => return Err(Error::UnknownMessageType(typ)),
76    };
77    Ok(request)
78}
79
80async fn send_packet<T: IntoBytes + Immutable + KnownLayout>(
81    writer: &mut impl AsyncSend,
82    typ: u32,
83    packet: &T,
84) -> Result<(), Error> {
85    writer
86        .send_vectored(&[
87            IoSlice::new(protocol::MessageHeader { message_type: typ }.as_bytes()),
88            IoSlice::new(packet.as_bytes()),
89        ])
90        .await
91        .map_err(Error::Io)?;
92    Ok(())
93}
94
95/// A vmbus synthetic keyboard.
96pub struct Keyboard {
97    source: Box<dyn InputSource<KeyboardData>>,
98}
99
100impl Keyboard {
101    /// Creates a new keyboard.
102    pub fn new(source: Box<dyn InputSource<KeyboardData>>) -> Self {
103        Self { source }
104    }
105
106    /// Extracts the keyboard input source.
107    pub fn into_source(self) -> Box<dyn InputSource<KeyboardData>> {
108        self.source
109    }
110}
111
112#[async_trait]
113impl SimpleVmbusDevice for Keyboard {
114    type Runner = KeyboardChannel<GpadlRingMem>;
115    type SavedState = SavedState;
116
117    fn offer(&self) -> OfferParams {
118        OfferParams {
119            interface_name: "keyboard".to_owned(),
120            interface_id: protocol::INTERFACE_GUID,
121            instance_id: protocol::INSTANCE_GUID,
122            ..Default::default()
123        }
124    }
125
126    fn open(
127        &mut self,
128        channel: RawAsyncChannel<GpadlRingMem>,
129        _guest_memory: guestmem::GuestMemory,
130    ) -> Result<Self::Runner, ChannelOpenError> {
131        let pipe = MessagePipe::new_raw(channel)?;
132        Ok(KeyboardChannel::new(pipe, ChannelState::default()))
133    }
134
135    async fn close(&mut self) {
136        self.source.set_active(false).await;
137    }
138
139    fn inspect(&mut self, req: inspect::Request<'_>, channel: Option<&mut Self::Runner>) {
140        let mut resp = req.respond();
141        if let Some(channel) = channel {
142            let (version, state) = match &channel.state {
143                ChannelState::ReadVersion => (None, "read_version"),
144                ChannelState::WriteVersion { version } => (Some(*version), "write_version"),
145                ChannelState::Active { version } => (Some(*version), "active"),
146            };
147            resp.field("state", state).field("version", version);
148        }
149    }
150
151    async fn run(
152        &mut self,
153        stop: &mut StopTask<'_>,
154        channel: &mut KeyboardChannel,
155    ) -> Result<(), task_control::Cancelled> {
156        stop.until_stopped(async {
157            match channel.process(self).await {
158                Ok(()) => {}
159                Err(err) => {
160                    tracing::error!(error = &err as &dyn std::error::Error, "keyboard error")
161                }
162            }
163        })
164        .await
165    }
166
167    fn supports_save_restore(
168        &mut self,
169    ) -> Option<
170        &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
171    > {
172        Some(self)
173    }
174}
175
176impl SaveRestoreSimpleVmbusDevice for Keyboard {
177    fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState {
178        SavedState(runner.state.clone())
179    }
180
181    fn restore_open(
182        &mut self,
183        state: Self::SavedState,
184        channel: RawAsyncChannel<GpadlRingMem>,
185    ) -> Result<Self::Runner, ChannelOpenError> {
186        let pipe = MessagePipe::new_raw(channel)?;
187        Ok(KeyboardChannel::new(pipe, state.0))
188    }
189}
190
/// Keyboard saved state.
///
/// Wraps the runner's `ChannelState`. It is produced by `save_open` and
/// consumed by `restore_open`.
// NOTE(review): the mesh package name and field number presumably define the
// serialized saved-state format; keep them stable across versions.
#[derive(Protobuf, SavedStateRoot)]
#[mesh(package = "ui.synthkbd")]
pub struct SavedState(#[mesh(1)] ChannelState);
195
/// The keyboard task: runs the keyboard protocol over a vmbus message pipe.
///
/// This is the `Runner` type of `Keyboard`'s `SimpleVmbusDevice` implementation.
pub struct KeyboardChannel<T: RingMem = GpadlRingMem> {
    /// Message pipe connected to the guest.
    channel: MessagePipe<T>,
    /// Current negotiation state; copied out by `save_open` and supplied back
    /// by `restore_open`.
    state: ChannelState,
}
201
/// Negotiation state of the keyboard channel.
///
/// This is also the payload of `SavedState`, which lets a restored runner pick
/// up where the saved one left off.
// NOTE(review): the mesh variant/field numbers presumably form the saved-state
// encoding; do not renumber them.
#[derive(Debug, Clone, Protobuf, Default)]
#[mesh(package = "ui.synthkbd")]
enum ChannelState {
    /// Waiting for the guest's protocol request. This is the initial state.
    #[mesh(1)]
    #[default]
    ReadVersion,
    /// The guest requested `version`; a protocol response has not yet been sent.
    #[mesh(2)]
    WriteVersion {
        #[mesh(1)]
        version: u32,
    },
    /// `version` was accepted; keystrokes are forwarded to the guest.
    #[mesh(3)]
    Active {
        #[mesh(1)]
        version: u32,
    },
}
219
impl<T: RingMem + Unpin> KeyboardChannel<T> {
    /// Creates a runner over `channel`, starting in `state`.
    ///
    /// `open` passes `ChannelState::default()`; `restore_open` passes the saved state.
    fn new(channel: MessagePipe<T>, state: ChannelState) -> Self {
        Self { channel, state }
    }

    /// Runs the keyboard protocol state machine.
    ///
    /// The loop has no `break`, and every `return` yields an error, so this only
    /// returns on failure (or is dropped by the caller).
    ///
    /// # Errors
    ///
    /// - [`Error::Io`], [`Error::BadPacket`] or [`Error::UnknownMessageType`] from
    ///   sending or receiving messages.
    /// - [`Error::UnexpectedPacketOrder`] if the guest sends a message that is not
    ///   valid in the current state.
    async fn process(&mut self, keyboard: &mut Keyboard) -> Result<(), Error> {
        let (mut recv, mut send) = MessagePipe::split(&mut self.channel);
        loop {
            match self.state {
                ChannelState::ReadVersion => {
                    // The first message from the guest must be a protocol request.
                    if let Request::ProtocolRequest(version) = recv_packet(&mut recv).await? {
                        self.state = ChannelState::WriteVersion { version };
                    } else {
                        return Err(Error::UnexpectedPacketOrder);
                    }
                }
                ChannelState::WriteVersion { version } => {
                    // Only the Win8 protocol version is supported.
                    let accepted = version == protocol::VERSION_WIN8;
                    send_packet(
                        &mut send,
                        protocol::MESSAGE_PROTOCOL_RESPONSE,
                        &protocol::MessageProtocolResponse {
                            accepted: accepted.into(),
                        },
                    )
                    .await?;
                    // The state only advances after the response has been sent.
                    if accepted {
                        tracelimit::info_ratelimited!(version, "keyboard negotiated, version");
                        self.state = ChannelState::Active { version };
                    } else {
                        // On rejection, go back to waiting so the guest can retry
                        // with another version.
                        tracelimit::warn_ratelimited!(version, "unknown keyboard version");
                        self.state = ChannelState::ReadVersion;
                    }
                }
                // NOTE(review): `recv_fut` below never completes with `Ok`, so
                // `try_join` can only finish with an error and this `loop` never
                // runs a second iteration.
                ChannelState::Active { version: _ } => loop {
                    keyboard.source.set_active(true).await;
                    // Forward each input event to the guest as a keystroke message.
                    // If the source stream ends, this future completes with `Ok`
                    // and `try_join` keeps waiting on `recv_fut`.
                    let send_fut = pin!(async {
                        while let Some(input) = keyboard.source.next().await {
                            let mut flags = 0;
                            // The high byte of `code` carries an optional 0xE0/0xE1
                            // prefix, which is reported to the guest as a flag.
                            match input.code >> 8 {
                                0xe0 => {
                                    flags |= protocol::KEYSTROKE_IS_E0;
                                }
                                0xe1 => {
                                    flags |= protocol::KEYSTROKE_IS_E1;
                                }
                                _ => (),
                            }
                            if !input.make {
                                flags |= protocol::KEYSTROKE_IS_BREAK;
                            }
                            // Only the low 7 bits go out as the make code; make vs.
                            // break is conveyed by KEYSTROKE_IS_BREAK above.
                            send_packet(
                                &mut send,
                                protocol::MESSAGE_EVENT,
                                &protocol::MessageKeystroke {
                                    make_code: input.code & 0x7f,
                                    padding: 0,
                                    flags,
                                },
                            )
                            .await?;
                        }
                        Ok(())
                    });

                    // While active, the guest may only send LED indicator updates,
                    // which are ignored. Anything else ends the channel with an error.
                    let recv_fut = pin!(async {
                        loop {
                            match recv_packet(&mut recv).await? {
                                Request::SetLedIndicators => (),
                                _ => return Err(Error::UnexpectedPacketOrder),
                            }
                        }
                        // Unreachable; only present to fix the async block's output
                        // type to `Result<(), Error>`.
                        #[expect(unreachable_code)]
                        Ok(())
                    });

                    futures::future::try_join(send_fut, recv_fut).await?;
                },
            }
        }
    }
}
302
#[cfg(test)]
mod tests {
    use super::*;
    use input_core::mesh_input::input_pair;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use std::io::ErrorKind;
    use test_with_tracing::test;
    use tracing_helpers::ErrorValueExt;
    use vmbus_async::pipe::connected_raw_message_pipes;

    /// Host-to-guest messages the tests expect to observe on the guest end.
    #[derive(Debug)]
    enum Packet {
        ProtocolResponse(protocol::MessageProtocolResponse),
        Event(protocol::MessageKeystroke),
    }

    /// Guest-side receive helper.
    ///
    /// This local definition shadows the host-side `recv_packet` brought in by
    /// `use super::*`. Returns `None` if the read yields zero bytes. Panics on read
    /// errors, malformed packets, or unexpected message types.
    async fn recv_packet(read: &mut (dyn AsyncRecv + Unpin + Send + Sync)) -> Option<Packet> {
        let mut packet = [0; protocol::MAXIMUM_MESSAGE_SIZE];
        let n = read.recv(&mut packet).await.unwrap();
        if n == 0 {
            return None;
        }
        let packet = &packet[..n];
        let (header, rest) = protocol::MessageHeader::read_from_prefix(packet).unwrap(); // TODO: zerocopy: unwrap (https://github.com/microsoft/openvmm/issues/759)
        Some(match header.message_type {
            protocol::MESSAGE_PROTOCOL_RESPONSE => {
                Packet::ProtocolResponse(FromBytes::read_from_prefix(rest).unwrap().0)
                // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
            }
            protocol::MESSAGE_EVENT => Packet::Event(FromBytes::read_from_prefix(rest).unwrap().0), // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
            _ => panic!("unknown packet type {}", header.message_type),
        })
    }

    /// Spawns the keyboard state machine on `channel`, starting from `ReadVersion`.
    ///
    /// A `ConnectionReset` I/O error is treated as a clean shutdown; presumably
    /// that is what the pipe reports once the guest end is dropped. Any other
    /// error is logged and returned.
    fn start_worker<T: RingMem + 'static + Unpin + Send + Sync>(
        driver: &DefaultDriver,
        mut keyboard: Keyboard,
        channel: MessagePipe<T>,
    ) -> Task<Result<(), Error>> {
        driver.spawn("keyboard worker", async move {
            let mut channel = KeyboardChannel::new(channel, ChannelState::ReadVersion);
            channel.process(&mut keyboard).await.or_else(|e| match e {
                Error::Io(err) if err.kind() == ErrorKind::ConnectionReset => {
                    tracing::info!("closed");
                    Ok(())
                }
                _ => {
                    tracing::error!(error = e.as_error());
                    Err(e)
                }
            })
        })
    }

    /// Negotiates the supported version, then checks that input events are
    /// forwarded with the right make code and break flag.
    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, mut guest) = connected_raw_message_pipes(16384);
        let (source, mut sink) = input_pair();
        let worker = start_worker(&driver, Keyboard::new(Box::new(source)), host);

        // Request the only version the device accepts.
        send_packet(
            &mut guest,
            protocol::MESSAGE_PROTOCOL_REQUEST,
            &protocol::MessageProtocolRequest {
                version: protocol::VERSION_WIN8,
            },
        )
        .await
        .unwrap();

        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse { accepted: 1 }) => (),
            p => panic!("unexpected {:?}", p),
        }

        // (code, make) pairs. These codes have no E0/E1 prefix byte, so the only
        // flag that may be set is the break flag (when `make` is false).
        let events = [(3, false), (5, true)];

        for &(code, make) in &events {
            sink.send(KeyboardData { code, make });
        }

        for event in &events {
            match recv_packet(&mut guest).await.unwrap() {
                Packet::Event(protocol::MessageKeystroke {
                    make_code,
                    padding: _padding,
                    flags,
                }) => {
                    assert_eq!(make_code, event.0);
                    assert_eq!(
                        flags,
                        if event.1 {
                            0
                        } else {
                            protocol::KEYSTROKE_IS_BREAK
                        }
                    );
                }
                p => panic!("unexpected {:?}", p),
            }
        }
        // Closing the guest end should stop the worker without an error.
        drop(guest);
        worker.await.unwrap()
    }

    /// An unsupported version request must be answered with `accepted: 0`.
    #[async_test]
    async fn test_channel_negotiation_failed(driver: DefaultDriver) {
        let (host, mut guest) = connected_raw_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Keyboard::new(Box::new(source)), host);

        send_packet(
            &mut guest,
            protocol::MESSAGE_PROTOCOL_REQUEST,
            &protocol::MessageProtocolRequest { version: 0xbadf00d },
        )
        .await
        .unwrap();

        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse { accepted: 0 }) => (),
            p => panic!("unexpected {:?}", p),
        }

        // The worker is back in `ReadVersion`; closing the guest end should stop
        // it without an error.
        drop(guest);
        worker.await.unwrap();
    }
}
433}