uidevices/keyboard/
mod.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Vmbus synthetic keyboard device.
5
6mod protocol;
7
8use async_trait::async_trait;
9use futures::StreamExt;
10use input_core::InputSource;
11use input_core::KeyboardData;
12use mesh::payload::Protobuf;
13use std::io::IoSlice;
14use std::pin::pin;
15use task_control::StopTask;
16use thiserror::Error;
17use vmbus_async::async_dgram::AsyncRecv;
18use vmbus_async::async_dgram::AsyncRecvExt;
19use vmbus_async::async_dgram::AsyncSend;
20use vmbus_async::async_dgram::AsyncSendExt;
21use vmbus_async::pipe::MessagePipe;
22use vmbus_channel::RawAsyncChannel;
23use vmbus_channel::bus::OfferParams;
24use vmbus_channel::channel::ChannelOpenError;
25use vmbus_channel::gpadl_ring::GpadlRingMem;
26use vmbus_channel::simple::SaveRestoreSimpleVmbusDevice;
27use vmbus_channel::simple::SimpleVmbusDevice;
28use vmbus_ring::RingMem;
29use vmcore::save_restore::SavedStateRoot;
30use zerocopy::FromBytes;
31use zerocopy::Immutable;
32use zerocopy::IntoBytes;
33use zerocopy::KnownLayout;
34
/// A request decoded from a guest packet by `recv_packet`.
#[derive(Debug)]
enum Request {
    /// The guest asked to negotiate the contained protocol version.
    ProtocolRequest(u32),
    /// The guest sent an LED indicator update; it is size-checked but its
    /// contents are ignored.
    SetLedIndicators,
}
40
41#[derive(Debug, Error)]
42enum Error {
43    #[error("channel i/o error")]
44    Io(#[source] std::io::Error),
45    #[error("received out of order packet")]
46    UnexpectedPacketOrder,
47    #[error("bad packet")]
48    BadPacket,
49    #[error("unknown message type")]
50    UnknownMessageType(u32),
51    #[error("accepting vmbus channel")]
52    Accept(#[from] vmbus_channel::offer::Error),
53}
54
55async fn recv_packet(reader: &mut impl AsyncRecv) -> Result<Request, Error> {
56    let mut buf = [0; 64];
57    let n = reader.recv(&mut buf).await.map_err(Error::Io)?;
58    let buf = &buf[..n];
59    let (header, buf) =
60        protocol::MessageHeader::read_from_prefix(buf).map_err(|_| Error::BadPacket)?; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
61    let request = match header.message_type {
62        protocol::MESSAGE_PROTOCOL_REQUEST => {
63            let message = protocol::MessageProtocolRequest::read_from_prefix(buf)
64                .map_err(|_| Error::BadPacket)?
65                .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
66            Request::ProtocolRequest(message.version)
67        }
68        protocol::MESSAGE_SET_LED_INDICATORS => {
69            // We don't have any actual LEDs to set, so check the message for validity but ignore its contents.
70            let _message = protocol::MessageLedIndicatorsState::read_from_prefix(buf)
71                .map_err(|_| Error::BadPacket)?
72                .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
73            Request::SetLedIndicators
74        }
75        typ => return Err(Error::UnknownMessageType(typ)),
76    };
77    Ok(request)
78}
79
80async fn send_packet<T: IntoBytes + Immutable + KnownLayout>(
81    writer: &mut impl AsyncSend,
82    typ: u32,
83    packet: &T,
84) -> Result<(), Error> {
85    writer
86        .send_vectored(&[
87            IoSlice::new(protocol::MessageHeader { message_type: typ }.as_bytes()),
88            IoSlice::new(packet.as_bytes()),
89        ])
90        .await
91        .map_err(Error::Io)?;
92    Ok(())
93}
94
95/// A vmbus synthetic keyboard.
96pub struct Keyboard {
97    source: Box<dyn InputSource<KeyboardData>>,
98}
99
100impl Keyboard {
101    /// Creates a new keyboard.
102    pub fn new(source: Box<dyn InputSource<KeyboardData>>) -> Self {
103        Self { source }
104    }
105
106    /// Extracts the keyboard input source.
107    pub fn into_source(self) -> Box<dyn InputSource<KeyboardData>> {
108        self.source
109    }
110}
111
112#[async_trait]
113impl SimpleVmbusDevice for Keyboard {
114    type Runner = KeyboardChannel<GpadlRingMem>;
115    type SavedState = SavedState;
116
117    fn offer(&self) -> OfferParams {
118        OfferParams {
119            interface_name: "keyboard".to_owned(),
120            interface_id: protocol::INTERFACE_GUID,
121            instance_id: protocol::INSTANCE_GUID,
122            ..Default::default()
123        }
124    }
125
126    fn open(
127        &mut self,
128        channel: RawAsyncChannel<GpadlRingMem>,
129        _guest_memory: guestmem::GuestMemory,
130    ) -> Result<Self::Runner, ChannelOpenError> {
131        let pipe = MessagePipe::new_raw(channel)?;
132        Ok(KeyboardChannel::new(pipe, ChannelState::default()))
133    }
134
135    async fn close(&mut self) {
136        self.source.set_active(false).await;
137    }
138
139    fn inspect(&mut self, req: inspect::Request<'_>, channel: Option<&mut Self::Runner>) {
140        let mut resp = req.respond();
141        if let Some(channel) = channel {
142            let (version, state) = match &channel.state {
143                ChannelState::ReadVersion => (None, "read_version"),
144                ChannelState::WriteVersion { version } => (Some(*version), "write_version"),
145                ChannelState::Active { version } => (Some(*version), "active"),
146            };
147            resp.field("state", state).field("version", version);
148        }
149    }
150
151    async fn run(
152        &mut self,
153        stop: &mut StopTask<'_>,
154        channel: &mut KeyboardChannel,
155    ) -> Result<(), task_control::Cancelled> {
156        stop.until_stopped(async {
157            match channel.process(self).await {
158                Ok(()) => {}
159                Err(err) => {
160                    tracing::error!(error = &err as &dyn std::error::Error, "keyboard error")
161                }
162            }
163        })
164        .await
165    }
166
167    fn supports_save_restore(
168        &mut self,
169    ) -> Option<
170        &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
171    > {
172        Some(self)
173    }
174}
175
176impl SaveRestoreSimpleVmbusDevice for Keyboard {
177    fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState {
178        SavedState(runner.state.clone())
179    }
180
181    fn restore_open(
182        &mut self,
183        state: Self::SavedState,
184        channel: RawAsyncChannel<GpadlRingMem>,
185    ) -> Result<Self::Runner, ChannelOpenError> {
186        let pipe = MessagePipe::new_raw(channel)?;
187        Ok(KeyboardChannel::new(pipe, state.0))
188    }
189}
190
/// Keyboard saved state.
///
/// Wraps the channel's negotiation state (`ChannelState`). `save_open`
/// produces it and `restore_open` resumes the channel from it.
#[derive(Protobuf, SavedStateRoot)]
#[mesh(package = "ui.synthkbd")]
pub struct SavedState(#[mesh(1)] ChannelState);
195
/// The keyboard task.
///
/// Pairs the message pipe to the guest with the channel's current position
/// in the version-negotiation state machine.
pub struct KeyboardChannel<T: RingMem = GpadlRingMem> {
    /// Pipe used to exchange packets with the guest.
    channel: MessagePipe<T>,
    /// Current negotiation state; also what gets saved and restored.
    state: ChannelState,
}
201
/// Negotiation state of a keyboard channel (also persisted in [`SavedState`]).
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "ui.synthkbd")]
enum ChannelState {
    /// Waiting for the guest's protocol version request.
    #[mesh(1)]
    ReadVersion,
    /// A version request was received; the response has not been sent yet.
    #[mesh(2)]
    WriteVersion {
        /// Version requested by the guest.
        #[mesh(1)]
        version: u32,
    },
    /// Negotiation succeeded; keystrokes are forwarded to the guest.
    #[mesh(3)]
    Active {
        /// The accepted protocol version.
        #[mesh(1)]
        version: u32,
    },
}
218
219impl Default for ChannelState {
220    fn default() -> Self {
221        Self::ReadVersion
222    }
223}
224
225impl<T: RingMem + Unpin> KeyboardChannel<T> {
226    fn new(channel: MessagePipe<T>, state: ChannelState) -> Self {
227        Self { channel, state }
228    }
229
230    async fn process(&mut self, keyboard: &mut Keyboard) -> Result<(), Error> {
231        let (mut recv, mut send) = MessagePipe::split(&mut self.channel);
232        loop {
233            match self.state {
234                ChannelState::ReadVersion => {
235                    if let Request::ProtocolRequest(version) = recv_packet(&mut recv).await? {
236                        self.state = ChannelState::WriteVersion { version };
237                    } else {
238                        return Err(Error::UnexpectedPacketOrder);
239                    }
240                }
241                ChannelState::WriteVersion { version } => {
242                    let accepted = version == protocol::VERSION_WIN8;
243                    send_packet(
244                        &mut send,
245                        protocol::MESSAGE_PROTOCOL_RESPONSE,
246                        &protocol::MessageProtocolResponse {
247                            accepted: accepted.into(),
248                        },
249                    )
250                    .await?;
251                    if accepted {
252                        tracelimit::info_ratelimited!(version, "keyboard negotiated, version");
253                        self.state = ChannelState::Active { version };
254                    } else {
255                        tracelimit::warn_ratelimited!(version, "unknown keyboard version");
256                        self.state = ChannelState::ReadVersion;
257                    }
258                }
259                ChannelState::Active { version: _ } => loop {
260                    keyboard.source.set_active(true).await;
261                    let send_fut = pin!(async {
262                        while let Some(input) = keyboard.source.next().await {
263                            let mut flags = 0;
264                            match input.code >> 8 {
265                                0xe0 => {
266                                    flags |= protocol::KEYSTROKE_IS_E0;
267                                }
268                                0xe1 => {
269                                    flags |= protocol::KEYSTROKE_IS_E1;
270                                }
271                                _ => (),
272                            }
273                            if !input.make {
274                                flags |= protocol::KEYSTROKE_IS_BREAK;
275                            }
276                            send_packet(
277                                &mut send,
278                                protocol::MESSAGE_EVENT,
279                                &protocol::MessageKeystroke {
280                                    make_code: input.code & 0x7f,
281                                    padding: 0,
282                                    flags,
283                                },
284                            )
285                            .await?;
286                        }
287                        Ok(())
288                    });
289
290                    let recv_fut = pin!(async {
291                        loop {
292                            match recv_packet(&mut recv).await? {
293                                Request::SetLedIndicators => (),
294                                _ => return Err(Error::UnexpectedPacketOrder),
295                            }
296                        }
297                        #[expect(unreachable_code)]
298                        Ok(())
299                    });
300
301                    futures::future::try_join(send_fut, recv_fut).await?;
302                },
303            }
304        }
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use input_core::mesh_input::input_pair;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use std::io::ErrorKind;
    use test_with_tracing::test;
    use tracing_helpers::ErrorValueExt;
    use vmbus_async::pipe::connected_raw_message_pipes;

    /// A device-to-guest packet, decoded from the guest's side of the pipe.
    #[derive(Debug)]
    enum Packet {
        ProtocolResponse(protocol::MessageProtocolResponse),
        Event(protocol::MessageKeystroke),
    }

    /// Reads and decodes the next packet the device sent; `None` on a
    /// zero-length read. Panics on anything it cannot decode.
    async fn recv_packet(read: &mut (dyn AsyncRecv + Unpin + Send + Sync)) -> Option<Packet> {
        let mut buf = [0; protocol::MAXIMUM_MESSAGE_SIZE];
        let len = read.recv(&mut buf).await.unwrap();
        if len == 0 {
            return None;
        }
        // TODO: zerocopy: unwrap (https://github.com/microsoft/openvmm/issues/759)
        let (header, body) = protocol::MessageHeader::read_from_prefix(&buf[..len]).unwrap();
        // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
        let packet = match header.message_type {
            protocol::MESSAGE_PROTOCOL_RESPONSE => {
                Packet::ProtocolResponse(FromBytes::read_from_prefix(body).unwrap().0)
            }
            protocol::MESSAGE_EVENT => Packet::Event(FromBytes::read_from_prefix(body).unwrap().0),
            other => panic!("unknown packet type {}", other),
        };
        Some(packet)
    }

    /// Spawns the keyboard protocol on `channel`, treating a guest disconnect
    /// (`ConnectionReset`) as a clean shutdown.
    fn start_worker<T: RingMem + 'static + Unpin + Send + Sync>(
        driver: &DefaultDriver,
        mut keyboard: Keyboard,
        channel: MessagePipe<T>,
    ) -> Task<Result<(), Error>> {
        driver.spawn("keyboard worker", async move {
            let mut worker = KeyboardChannel::new(channel, ChannelState::ReadVersion);
            match worker.process(&mut keyboard).await {
                Ok(()) => Ok(()),
                Err(Error::Io(err)) if err.kind() == ErrorKind::ConnectionReset => {
                    tracing::info!("closed");
                    Ok(())
                }
                Err(err) => {
                    tracing::error!(error = err.as_error());
                    Err(err)
                }
            }
        })
    }

    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, mut guest) = connected_raw_message_pipes(16384);
        let (source, mut sink) = input_pair();
        let worker = start_worker(&driver, Keyboard::new(Box::new(source)), host);

        let request = protocol::MessageProtocolRequest {
            version: protocol::VERSION_WIN8,
        };
        send_packet(&mut guest, protocol::MESSAGE_PROTOCOL_REQUEST, &request)
            .await
            .unwrap();

        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse { accepted: 1 }) => {}
            p => panic!("unexpected {:?}", p),
        }

        // (scan code, is make) pairs injected into the input source.
        let events = [(3, false), (5, true)];

        for (code, make) in events.iter().copied() {
            sink.send(KeyboardData { code, make });
        }

        for (code, make) in events.iter().copied() {
            let expected_flags = if make {
                0
            } else {
                protocol::KEYSTROKE_IS_BREAK
            };
            match recv_packet(&mut guest).await.unwrap() {
                Packet::Event(protocol::MessageKeystroke {
                    make_code,
                    padding: _,
                    flags,
                }) => {
                    assert_eq!(make_code, code);
                    assert_eq!(flags, expected_flags);
                }
                p => panic!("unexpected {:?}", p),
            }
        }
        drop(guest);
        worker.await.unwrap()
    }

    #[async_test]
    async fn test_channel_negotiation_failed(driver: DefaultDriver) {
        let (host, mut guest) = connected_raw_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Keyboard::new(Box::new(source)), host);

        let request = protocol::MessageProtocolRequest { version: 0xbadf00d };
        send_packet(&mut guest, protocol::MESSAGE_PROTOCOL_REQUEST, &request)
            .await
            .unwrap();

        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse { accepted: 0 }) => {}
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }
}