1mod protocol;
7
8use async_trait::async_trait;
9use futures::StreamExt;
10use input_core::InputSource;
11use input_core::KeyboardData;
12use mesh::payload::Protobuf;
13use std::io::IoSlice;
14use std::pin::pin;
15use task_control::StopTask;
16use thiserror::Error;
17use vmbus_async::async_dgram::AsyncRecv;
18use vmbus_async::async_dgram::AsyncRecvExt;
19use vmbus_async::async_dgram::AsyncSend;
20use vmbus_async::async_dgram::AsyncSendExt;
21use vmbus_async::pipe::MessagePipe;
22use vmbus_channel::RawAsyncChannel;
23use vmbus_channel::bus::OfferParams;
24use vmbus_channel::channel::ChannelOpenError;
25use vmbus_channel::gpadl_ring::GpadlRingMem;
26use vmbus_channel::simple::SaveRestoreSimpleVmbusDevice;
27use vmbus_channel::simple::SimpleVmbusDevice;
28use vmbus_ring::RingMem;
29use vmcore::save_restore::SavedStateRoot;
30use zerocopy::FromBytes;
31use zerocopy::Immutable;
32use zerocopy::IntoBytes;
33use zerocopy::KnownLayout;
34
/// A message received from the other end of the channel, parsed by `recv_packet`.
#[derive(Debug)]
enum Request {
    /// A protocol request carrying the requested version.
    ProtocolRequest(u32),
    /// An LED indicator update; the payload is size-checked but not retained.
    SetLedIndicators,
}
40
41#[derive(Debug, Error)]
42enum Error {
43 #[error("channel i/o error")]
44 Io(#[source] std::io::Error),
45 #[error("received out of order packet")]
46 UnexpectedPacketOrder,
47 #[error("bad packet")]
48 BadPacket,
49 #[error("unknown message type")]
50 UnknownMessageType(u32),
51 #[error("accepting vmbus channel")]
52 Accept(#[from] vmbus_channel::offer::Error),
53}
54
55async fn recv_packet(reader: &mut impl AsyncRecv) -> Result<Request, Error> {
56 let mut buf = [0; 64];
57 let n = reader.recv(&mut buf).await.map_err(Error::Io)?;
58 let buf = &buf[..n];
59 let (header, buf) =
60 protocol::MessageHeader::read_from_prefix(buf).map_err(|_| Error::BadPacket)?; let request = match header.message_type {
62 protocol::MESSAGE_PROTOCOL_REQUEST => {
63 let message = protocol::MessageProtocolRequest::read_from_prefix(buf)
64 .map_err(|_| Error::BadPacket)?
65 .0; Request::ProtocolRequest(message.version)
67 }
68 protocol::MESSAGE_SET_LED_INDICATORS => {
69 let _message = protocol::MessageLedIndicatorsState::read_from_prefix(buf)
71 .map_err(|_| Error::BadPacket)?
72 .0; Request::SetLedIndicators
74 }
75 typ => return Err(Error::UnknownMessageType(typ)),
76 };
77 Ok(request)
78}
79
80async fn send_packet<T: IntoBytes + Immutable + KnownLayout>(
81 writer: &mut impl AsyncSend,
82 typ: u32,
83 packet: &T,
84) -> Result<(), Error> {
85 writer
86 .send_vectored(&[
87 IoSlice::new(protocol::MessageHeader { message_type: typ }.as_bytes()),
88 IoSlice::new(packet.as_bytes()),
89 ])
90 .await
91 .map_err(Error::Io)?;
92 Ok(())
93}
94
95pub struct Keyboard {
97 source: Box<dyn InputSource<KeyboardData>>,
98}
99
100impl Keyboard {
101 pub fn new(source: Box<dyn InputSource<KeyboardData>>) -> Self {
103 Self { source }
104 }
105
106 pub fn into_source(self) -> Box<dyn InputSource<KeyboardData>> {
108 self.source
109 }
110}
111
112#[async_trait]
113impl SimpleVmbusDevice for Keyboard {
114 type Runner = KeyboardChannel<GpadlRingMem>;
115 type SavedState = SavedState;
116
117 fn offer(&self) -> OfferParams {
118 OfferParams {
119 interface_name: "keyboard".to_owned(),
120 interface_id: protocol::INTERFACE_GUID,
121 instance_id: protocol::INSTANCE_GUID,
122 ..Default::default()
123 }
124 }
125
126 fn open(
127 &mut self,
128 channel: RawAsyncChannel<GpadlRingMem>,
129 _guest_memory: guestmem::GuestMemory,
130 ) -> Result<Self::Runner, ChannelOpenError> {
131 let pipe = MessagePipe::new_raw(channel)?;
132 Ok(KeyboardChannel::new(pipe, ChannelState::default()))
133 }
134
135 async fn close(&mut self) {
136 self.source.set_active(false).await;
137 }
138
139 fn inspect(&mut self, req: inspect::Request<'_>, channel: Option<&mut Self::Runner>) {
140 let mut resp = req.respond();
141 if let Some(channel) = channel {
142 let (version, state) = match &channel.state {
143 ChannelState::ReadVersion => (None, "read_version"),
144 ChannelState::WriteVersion { version } => (Some(*version), "write_version"),
145 ChannelState::Active { version } => (Some(*version), "active"),
146 };
147 resp.field("state", state).field("version", version);
148 }
149 }
150
151 async fn run(
152 &mut self,
153 stop: &mut StopTask<'_>,
154 channel: &mut KeyboardChannel,
155 ) -> Result<(), task_control::Cancelled> {
156 stop.until_stopped(async {
157 match channel.process(self).await {
158 Ok(()) => {}
159 Err(err) => {
160 tracing::error!(error = &err as &dyn std::error::Error, "keyboard error")
161 }
162 }
163 })
164 .await
165 }
166
167 fn supports_save_restore(
168 &mut self,
169 ) -> Option<
170 &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
171 > {
172 Some(self)
173 }
174}
175
176impl SaveRestoreSimpleVmbusDevice for Keyboard {
177 fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState {
178 SavedState(runner.state.clone())
179 }
180
181 fn restore_open(
182 &mut self,
183 state: Self::SavedState,
184 channel: RawAsyncChannel<GpadlRingMem>,
185 ) -> Result<Self::Runner, ChannelOpenError> {
186 let pipe = MessagePipe::new_raw(channel)?;
187 Ok(KeyboardChannel::new(pipe, state.0))
188 }
189}
190
/// Saved state for an open keyboard channel: its negotiation state.
// NOTE(review): `#[mesh(N)]` numbers identify fields in the encoded saved
// state; changing them likely breaks restoring previously saved state.
#[derive(Protobuf, SavedStateRoot)]
#[mesh(package = "ui.synthkbd")]
pub struct SavedState(#[mesh(1)] ChannelState);
195
/// Runner for an open keyboard channel: the message pipe plus the current
/// protocol state.
pub struct KeyboardChannel<T: RingMem = GpadlRingMem> {
    /// Pipe used for all traffic on the channel.
    channel: MessagePipe<T>,
    /// Current negotiation state; this is what `save_open` persists.
    state: ChannelState,
}
201
/// Protocol negotiation state of a keyboard channel, persisted in
/// `SavedState`.
// NOTE(review): the `#[mesh(N)]` numbers form the saved-state encoding;
// keep them stable.
#[derive(Debug, Clone, Protobuf, Default)]
#[mesh(package = "ui.synthkbd")]
enum ChannelState {
    /// Waiting for a protocol request (the initial state).
    #[mesh(1)]
    #[default]
    ReadVersion,
    /// A protocol request for `version` was received; the response has not
    /// been sent yet.
    #[mesh(2)]
    WriteVersion {
        #[mesh(1)]
        version: u32,
    },
    /// `version` was accepted; keystrokes are forwarded on the channel.
    #[mesh(3)]
    Active {
        #[mesh(1)]
        version: u32,
    },
}
219
impl<T: RingMem + Unpin> KeyboardChannel<T> {
    /// Creates a runner over `channel` that resumes from `state`.
    fn new(channel: MessagePipe<T>, state: ChannelState) -> Self {
        Self { channel, state }
    }

    /// Drives the keyboard protocol state machine on this channel.
    ///
    /// - `ReadVersion`: waits for a protocol request.
    /// - `WriteVersion`: replies accepted/rejected. Only
    ///   `protocol::VERSION_WIN8` is accepted. On success the state moves to
    ///   `Active`; otherwise it returns to `ReadVersion`.
    /// - `Active`: activates `keyboard.source`. It then concurrently forwards
    ///   keystrokes from the source and drains incoming packets.
    ///
    /// No path returns `Ok`; this runs until an error occurs or the future is
    /// dropped.
    ///
    /// # Errors
    ///
    /// Propagates errors from `recv_packet`/`send_packet`. Returns
    /// [`Error::UnexpectedPacketOrder`] when a request is not valid in the
    /// current state.
    async fn process(&mut self, keyboard: &mut Keyboard) -> Result<(), Error> {
        // Split the pipe so the active state can receive and send concurrently.
        // Only `self.channel` is borrowed, which leaves `self.state` assignable.
        let (mut recv, mut send) = MessagePipe::split(&mut self.channel);
        loop {
            match self.state {
                ChannelState::ReadVersion => {
                    if let Request::ProtocolRequest(version) = recv_packet(&mut recv).await? {
                        self.state = ChannelState::WriteVersion { version };
                    } else {
                        return Err(Error::UnexpectedPacketOrder);
                    }
                }
                ChannelState::WriteVersion { version } => {
                    let accepted = version == protocol::VERSION_WIN8;
                    send_packet(
                        &mut send,
                        protocol::MESSAGE_PROTOCOL_RESPONSE,
                        &protocol::MessageProtocolResponse {
                            accepted: accepted.into(),
                        },
                    )
                    .await?;
                    // The state changes only after the response has been sent.
                    if accepted {
                        tracelimit::info_ratelimited!(version, "keyboard negotiated, version");
                        self.state = ChannelState::Active { version };
                    } else {
                        tracelimit::warn_ratelimited!(version, "unknown keyboard version");
                        self.state = ChannelState::ReadVersion;
                    }
                }
                ChannelState::Active { version: _ } => loop {
                    // Enable the input source; `Keyboard::close` disables it again.
                    keyboard.source.set_active(true).await;
                    // Forward every keystroke from the source until the source ends.
                    let send_fut = pin!(async {
                        while let Some(input) = keyboard.source.next().await {
                            let mut flags = 0;
                            // The high byte of the code is treated as an
                            // extended prefix (0xE0 / 0xE1) and sent as a flag.
                            match input.code >> 8 {
                                0xe0 => {
                                    flags |= protocol::KEYSTROKE_IS_E0;
                                }
                                0xe1 => {
                                    flags |= protocol::KEYSTROKE_IS_E1;
                                }
                                _ => (),
                            }
                            if !input.make {
                                flags |= protocol::KEYSTROKE_IS_BREAK;
                            }
                            send_packet(
                                &mut send,
                                protocol::MESSAGE_EVENT,
                                &protocol::MessageKeystroke {
                                    // Only the low 7 bits are sent as the make code.
                                    // The prefix and break state travel in `flags`.
                                    make_code: input.code & 0x7f,
                                    padding: 0,
                                    flags,
                                },
                            )
                            .await?;
                        }
                        Ok(())
                    });

                    // Drain incoming packets. LED updates are ignored; any
                    // other request (e.g. a new protocol request) is an error.
                    let recv_fut = pin!(async {
                        loop {
                            match recv_packet(&mut recv).await? {
                                Request::SetLedIndicators => (),
                                _ => return Err(Error::UnexpectedPacketOrder),
                            }
                        }
                        // Never reached, but it gives the async block its
                        // `Result<(), Error>` output type.
                        #[expect(unreachable_code)]
                        Ok(())
                    });

                    // `recv_fut` never completes with `Ok`, so this only
                    // returns through `?` when either side fails.
                    futures::future::try_join(send_fut, recv_fut).await?;
                },
            }
        }
    }
}
302
#[cfg(test)]
mod tests {
    use super::*;
    use input_core::mesh_input::input_pair;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use std::io::ErrorKind;
    use test_with_tracing::test;
    use tracing_helpers::ErrorValueExt;
    use vmbus_async::pipe::connected_raw_message_pipes;

    /// A message sent by the device, as decoded at the guest end of the pipe.
    #[derive(Debug)]
    enum Packet {
        ProtocolResponse(protocol::MessageProtocolResponse),
        Event(protocol::MessageKeystroke),
    }

    /// Reads and decodes one message, returning `None` on a zero-length read.
    ///
    /// Panics on receive errors, truncated messages, or unknown message types.
    async fn recv_packet(read: &mut (dyn AsyncRecv + Unpin + Send + Sync)) -> Option<Packet> {
        let mut packet = [0; protocol::MAXIMUM_MESSAGE_SIZE];
        let n = read.recv(&mut packet).await.unwrap();
        if n == 0 {
            return None;
        }
        let packet = &packet[..n];
        let (header, rest) = protocol::MessageHeader::read_from_prefix(packet).unwrap();
        Some(match header.message_type {
            protocol::MESSAGE_PROTOCOL_RESPONSE => {
                Packet::ProtocolResponse(FromBytes::read_from_prefix(rest).unwrap().0)
            }
            protocol::MESSAGE_EVENT => Packet::Event(FromBytes::read_from_prefix(rest).unwrap().0),
            _ => panic!("unknown packet type {}", header.message_type),
        })
    }

    /// Spawns `process` on `channel`, starting from `ReadVersion`. A
    /// `ConnectionReset` I/O error is treated as a clean shutdown.
    fn start_worker<T: RingMem + 'static + Unpin + Send + Sync>(
        driver: &DefaultDriver,
        mut keyboard: Keyboard,
        channel: MessagePipe<T>,
    ) -> Task<Result<(), Error>> {
        driver.spawn("keyboard worker", async move {
            let mut channel = KeyboardChannel::new(channel, ChannelState::ReadVersion);
            channel.process(&mut keyboard).await.or_else(|e| match e {
                Error::Io(err) if err.kind() == ErrorKind::ConnectionReset => {
                    tracing::info!("closed");
                    Ok(())
                }
                _ => {
                    tracing::error!(error = e.as_error());
                    Err(e)
                }
            })
        })
    }

    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, mut guest) = connected_raw_message_pipes(16384);
        let (source, mut sink) = input_pair();
        let worker = start_worker(&driver, Keyboard::new(Box::new(source)), host);

        send_packet(
            &mut guest,
            protocol::MESSAGE_PROTOCOL_REQUEST,
            &protocol::MessageProtocolRequest {
                version: protocol::VERSION_WIN8,
            },
        )
        .await
        .unwrap();

        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse { accepted: 1 }) => (),
            p => panic!("unexpected {:?}", p),
        }

        let events = [(3, false), (5, true)];

        for &(code, make) in &events {
            sink.send(KeyboardData { code, make });
        }

        for event in &events {
            match recv_packet(&mut guest).await.unwrap() {
                Packet::Event(protocol::MessageKeystroke {
                    make_code,
                    padding: _padding,
                    flags,
                }) => {
                    assert_eq!(make_code, event.0);
                    assert_eq!(
                        flags,
                        if event.1 {
                            0
                        } else {
                            protocol::KEYSTROKE_IS_BREAK
                        }
                    );
                }
                p => panic!("unexpected {:?}", p),
            }
        }
        drop(guest);
        worker.await.unwrap()
    }

    /// Codes with a 0xE0/0xE1 prefix in the high byte must be sent with the
    /// matching prefix flag. Only the low 7 bits are sent as the make code.
    #[async_test]
    async fn test_channel_extended_keys(driver: DefaultDriver) {
        let (host, mut guest) = connected_raw_message_pipes(16384);
        let (source, mut sink) = input_pair();
        let worker = start_worker(&driver, Keyboard::new(Box::new(source)), host);

        send_packet(
            &mut guest,
            protocol::MESSAGE_PROTOCOL_REQUEST,
            &protocol::MessageProtocolRequest {
                version: protocol::VERSION_WIN8,
            },
        )
        .await
        .unwrap();

        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse { accepted: 1 }) => (),
            p => panic!("unexpected {:?}", p),
        }

        // (input code, make, expected make_code, expected flags)
        let cases = [
            (0xe01d, true, 0x1d, protocol::KEYSTROKE_IS_E0),
            (
                0xe01d,
                false,
                0x1d,
                protocol::KEYSTROKE_IS_E0 | protocol::KEYSTROKE_IS_BREAK,
            ),
            (0xe11d, true, 0x1d, protocol::KEYSTROKE_IS_E1),
        ];

        for &(code, make, _, _) in &cases {
            sink.send(KeyboardData { code, make });
        }

        for &(_, _, expected_code, expected_flags) in &cases {
            match recv_packet(&mut guest).await.unwrap() {
                Packet::Event(protocol::MessageKeystroke {
                    make_code,
                    padding: _padding,
                    flags,
                }) => {
                    assert_eq!(make_code, expected_code);
                    assert_eq!(flags, expected_flags);
                }
                p => panic!("unexpected {:?}", p),
            }
        }
        drop(guest);
        worker.await.unwrap()
    }

    #[async_test]
    async fn test_channel_negotiation_failed(driver: DefaultDriver) {
        let (host, mut guest) = connected_raw_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Keyboard::new(Box::new(source)), host);

        send_packet(
            &mut guest,
            protocol::MESSAGE_PROTOCOL_REQUEST,
            &protocol::MessageProtocolRequest { version: 0xbadf00d },
        )
        .await
        .unwrap();

        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse { accepted: 0 }) => (),
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }
}