1mod protocol;
7
8use async_trait::async_trait;
9use futures::StreamExt;
10use input_core::InputSource;
11use input_core::KeyboardData;
12use mesh::payload::Protobuf;
13use std::io::IoSlice;
14use std::pin::pin;
15use task_control::StopTask;
16use thiserror::Error;
17use vmbus_async::async_dgram::AsyncRecv;
18use vmbus_async::async_dgram::AsyncRecvExt;
19use vmbus_async::async_dgram::AsyncSend;
20use vmbus_async::async_dgram::AsyncSendExt;
21use vmbus_async::pipe::MessagePipe;
22use vmbus_channel::RawAsyncChannel;
23use vmbus_channel::bus::OfferParams;
24use vmbus_channel::channel::ChannelOpenError;
25use vmbus_channel::gpadl_ring::GpadlRingMem;
26use vmbus_channel::simple::SaveRestoreSimpleVmbusDevice;
27use vmbus_channel::simple::SimpleVmbusDevice;
28use vmbus_ring::RingMem;
29use vmcore::save_restore::SavedStateRoot;
30use zerocopy::FromBytes;
31use zerocopy::Immutable;
32use zerocopy::IntoBytes;
33use zerocopy::KnownLayout;
34
/// A decoded guest-to-host message.
#[derive(Debug)]
enum Request {
    /// The guest proposes the protocol version carried in the payload.
    ProtocolRequest(u32),
    /// The guest reports its LED indicator state; the payload is not retained.
    SetLedIndicators,
}
40
41#[derive(Debug, Error)]
42enum Error {
43 #[error("channel i/o error")]
44 Io(#[source] std::io::Error),
45 #[error("received out of order packet")]
46 UnexpectedPacketOrder,
47 #[error("bad packet")]
48 BadPacket,
49 #[error("unknown message type")]
50 UnknownMessageType(u32),
51 #[error("accepting vmbus channel")]
52 Accept(#[from] vmbus_channel::offer::Error),
53}
54
55async fn recv_packet(reader: &mut impl AsyncRecv) -> Result<Request, Error> {
56 let mut buf = [0; 64];
57 let n = reader.recv(&mut buf).await.map_err(Error::Io)?;
58 let buf = &buf[..n];
59 let (header, buf) =
60 protocol::MessageHeader::read_from_prefix(buf).map_err(|_| Error::BadPacket)?; let request = match header.message_type {
62 protocol::MESSAGE_PROTOCOL_REQUEST => {
63 let message = protocol::MessageProtocolRequest::read_from_prefix(buf)
64 .map_err(|_| Error::BadPacket)?
65 .0; Request::ProtocolRequest(message.version)
67 }
68 protocol::MESSAGE_SET_LED_INDICATORS => {
69 let _message = protocol::MessageLedIndicatorsState::read_from_prefix(buf)
71 .map_err(|_| Error::BadPacket)?
72 .0; Request::SetLedIndicators
74 }
75 typ => return Err(Error::UnknownMessageType(typ)),
76 };
77 Ok(request)
78}
79
80async fn send_packet<T: IntoBytes + Immutable + KnownLayout>(
81 writer: &mut impl AsyncSend,
82 typ: u32,
83 packet: &T,
84) -> Result<(), Error> {
85 writer
86 .send_vectored(&[
87 IoSlice::new(protocol::MessageHeader { message_type: typ }.as_bytes()),
88 IoSlice::new(packet.as_bytes()),
89 ])
90 .await
91 .map_err(Error::Io)?;
92 Ok(())
93}
94
95pub struct Keyboard {
97 source: Box<dyn InputSource<KeyboardData>>,
98}
99
100impl Keyboard {
101 pub fn new(source: Box<dyn InputSource<KeyboardData>>) -> Self {
103 Self { source }
104 }
105
106 pub fn into_source(self) -> Box<dyn InputSource<KeyboardData>> {
108 self.source
109 }
110}
111
112#[async_trait]
113impl SimpleVmbusDevice for Keyboard {
114 type Runner = KeyboardChannel<GpadlRingMem>;
115 type SavedState = SavedState;
116
117 fn offer(&self) -> OfferParams {
118 OfferParams {
119 interface_name: "keyboard".to_owned(),
120 interface_id: protocol::INTERFACE_GUID,
121 instance_id: protocol::INSTANCE_GUID,
122 ..Default::default()
123 }
124 }
125
126 fn open(
127 &mut self,
128 channel: RawAsyncChannel<GpadlRingMem>,
129 _guest_memory: guestmem::GuestMemory,
130 ) -> Result<Self::Runner, ChannelOpenError> {
131 let pipe = MessagePipe::new_raw(channel)?;
132 Ok(KeyboardChannel::new(pipe, ChannelState::default()))
133 }
134
135 async fn close(&mut self) {
136 self.source.set_active(false).await;
137 }
138
139 fn inspect(&mut self, req: inspect::Request<'_>, channel: Option<&mut Self::Runner>) {
140 let mut resp = req.respond();
141 if let Some(channel) = channel {
142 let (version, state) = match &channel.state {
143 ChannelState::ReadVersion => (None, "read_version"),
144 ChannelState::WriteVersion { version } => (Some(*version), "write_version"),
145 ChannelState::Active { version } => (Some(*version), "active"),
146 };
147 resp.field("state", state).field("version", version);
148 }
149 }
150
151 async fn run(
152 &mut self,
153 stop: &mut StopTask<'_>,
154 channel: &mut KeyboardChannel,
155 ) -> Result<(), task_control::Cancelled> {
156 stop.until_stopped(async {
157 match channel.process(self).await {
158 Ok(()) => {}
159 Err(err) => {
160 tracing::error!(error = &err as &dyn std::error::Error, "keyboard error")
161 }
162 }
163 })
164 .await
165 }
166
167 fn supports_save_restore(
168 &mut self,
169 ) -> Option<
170 &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
171 > {
172 Some(self)
173 }
174}
175
176impl SaveRestoreSimpleVmbusDevice for Keyboard {
177 fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState {
178 SavedState(runner.state.clone())
179 }
180
181 fn restore_open(
182 &mut self,
183 state: Self::SavedState,
184 channel: RawAsyncChannel<GpadlRingMem>,
185 ) -> Result<Self::Runner, ChannelOpenError> {
186 let pipe = MessagePipe::new_raw(channel)?;
187 Ok(KeyboardChannel::new(pipe, state.0))
188 }
189}
190
191#[derive(Protobuf, SavedStateRoot)]
193#[mesh(package = "ui.synthkbd")]
194pub struct SavedState(#[mesh(1)] ChannelState);
195
196pub struct KeyboardChannel<T: RingMem = GpadlRingMem> {
198 channel: MessagePipe<T>,
199 state: ChannelState,
200}
201
202#[derive(Debug, Clone, Protobuf)]
203#[mesh(package = "ui.synthkbd")]
204enum ChannelState {
205 #[mesh(1)]
206 ReadVersion,
207 #[mesh(2)]
208 WriteVersion {
209 #[mesh(1)]
210 version: u32,
211 },
212 #[mesh(3)]
213 Active {
214 #[mesh(1)]
215 version: u32,
216 },
217}
218
219impl Default for ChannelState {
220 fn default() -> Self {
221 Self::ReadVersion
222 }
223}
224
225impl<T: RingMem + Unpin> KeyboardChannel<T> {
226 fn new(channel: MessagePipe<T>, state: ChannelState) -> Self {
227 Self { channel, state }
228 }
229
230 async fn process(&mut self, keyboard: &mut Keyboard) -> Result<(), Error> {
231 let (mut recv, mut send) = MessagePipe::split(&mut self.channel);
232 loop {
233 match self.state {
234 ChannelState::ReadVersion => {
235 if let Request::ProtocolRequest(version) = recv_packet(&mut recv).await? {
236 self.state = ChannelState::WriteVersion { version };
237 } else {
238 return Err(Error::UnexpectedPacketOrder);
239 }
240 }
241 ChannelState::WriteVersion { version } => {
242 let accepted = version == protocol::VERSION_WIN8;
243 send_packet(
244 &mut send,
245 protocol::MESSAGE_PROTOCOL_RESPONSE,
246 &protocol::MessageProtocolResponse {
247 accepted: accepted.into(),
248 },
249 )
250 .await?;
251 if accepted {
252 tracelimit::info_ratelimited!(version, "keyboard negotiated, version");
253 self.state = ChannelState::Active { version };
254 } else {
255 tracelimit::warn_ratelimited!(version, "unknown keyboard version");
256 self.state = ChannelState::ReadVersion;
257 }
258 }
259 ChannelState::Active { version: _ } => loop {
260 keyboard.source.set_active(true).await;
261 let send_fut = pin!(async {
262 while let Some(input) = keyboard.source.next().await {
263 let mut flags = 0;
264 match input.code >> 8 {
265 0xe0 => {
266 flags |= protocol::KEYSTROKE_IS_E0;
267 }
268 0xe1 => {
269 flags |= protocol::KEYSTROKE_IS_E1;
270 }
271 _ => (),
272 }
273 if !input.make {
274 flags |= protocol::KEYSTROKE_IS_BREAK;
275 }
276 send_packet(
277 &mut send,
278 protocol::MESSAGE_EVENT,
279 &protocol::MessageKeystroke {
280 make_code: input.code & 0x7f,
281 padding: 0,
282 flags,
283 },
284 )
285 .await?;
286 }
287 Ok(())
288 });
289
290 let recv_fut = pin!(async {
291 loop {
292 match recv_packet(&mut recv).await? {
293 Request::SetLedIndicators => (),
294 _ => return Err(Error::UnexpectedPacketOrder),
295 }
296 }
297 #[expect(unreachable_code)]
298 Ok(())
299 });
300
301 futures::future::try_join(send_fut, recv_fut).await?;
302 },
303 }
304 }
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use input_core::mesh_input::input_pair;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use std::io::ErrorKind;
    use test_with_tracing::test;
    use tracing_helpers::ErrorValueExt;
    use vmbus_async::pipe::connected_raw_message_pipes;

    /// A host-to-guest message as decoded by the test guest.
    #[derive(Debug)]
    enum Packet {
        ProtocolResponse(protocol::MessageProtocolResponse),
        Event(protocol::MessageKeystroke),
    }

    /// Receives and decodes one message sent by the device.
    ///
    /// Returns `None` on a zero-length read. Panics on a malformed message or an
    /// unknown message type.
    async fn recv_packet(read: &mut (dyn AsyncRecv + Unpin + Send + Sync)) -> Option<Packet> {
        let mut storage = [0; protocol::MAXIMUM_MESSAGE_SIZE];
        let len = read.recv(&mut storage).await.unwrap();
        if len == 0 {
            return None;
        }
        let (header, body) = protocol::MessageHeader::read_from_prefix(&storage[..len]).unwrap();
        let packet = match header.message_type {
            protocol::MESSAGE_PROTOCOL_RESPONSE => {
                Packet::ProtocolResponse(FromBytes::read_from_prefix(body).unwrap().0)
            }
            protocol::MESSAGE_EVENT => Packet::Event(FromBytes::read_from_prefix(body).unwrap().0),
            other => panic!("unknown packet type {}", other),
        };
        Some(packet)
    }

    /// Spawns the device side of the channel.
    ///
    /// A `ConnectionReset` I/O error (the guest end going away) is treated as a
    /// clean shutdown.
    fn start_worker<T: RingMem + 'static + Unpin + Send + Sync>(
        driver: &DefaultDriver,
        mut keyboard: Keyboard,
        channel: MessagePipe<T>,
    ) -> Task<Result<(), Error>> {
        driver.spawn("keyboard worker", async move {
            let mut runner = KeyboardChannel::new(channel, ChannelState::ReadVersion);
            match runner.process(&mut keyboard).await {
                Ok(()) => Ok(()),
                Err(Error::Io(err)) if err.kind() == ErrorKind::ConnectionReset => {
                    tracing::info!("closed");
                    Ok(())
                }
                Err(e) => {
                    tracing::error!(error = e.as_error());
                    Err(e)
                }
            }
        })
    }

    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, mut guest) = connected_raw_message_pipes(16384);
        let (source, mut sink) = input_pair();
        let worker = start_worker(&driver, Keyboard::new(Box::new(source)), host);

        let request = protocol::MessageProtocolRequest {
            version: protocol::VERSION_WIN8,
        };
        send_packet(&mut guest, protocol::MESSAGE_PROTOCOL_REQUEST, &request)
            .await
            .unwrap();

        let response = recv_packet(&mut guest).await.unwrap();
        match response {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse { accepted: 1 }) => {}
            p => panic!("unexpected {:?}", p),
        }

        // (scan code, is make) pairs; a break (key release) sets KEYSTROKE_IS_BREAK.
        let events = [(3, false), (5, true)];

        for (code, make) in events {
            sink.send(KeyboardData { code, make });
        }

        for (code, make) in events {
            let expected_flags = if make { 0 } else { protocol::KEYSTROKE_IS_BREAK };
            match recv_packet(&mut guest).await.unwrap() {
                Packet::Event(protocol::MessageKeystroke {
                    make_code,
                    padding: _,
                    flags,
                }) => {
                    assert_eq!(make_code, code);
                    assert_eq!(flags, expected_flags);
                }
                p => panic!("unexpected {:?}", p),
            }
        }
        drop(guest);
        worker.await.unwrap()
    }

    #[async_test]
    async fn test_channel_negotiation_failed(driver: DefaultDriver) {
        let (host, mut guest) = connected_raw_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Keyboard::new(Box::new(source)), host);

        // Propose a version the device does not support.
        let request = protocol::MessageProtocolRequest { version: 0xbadf00d };
        send_packet(&mut guest, protocol::MESSAGE_PROTOCOL_REQUEST, &request)
            .await
            .unwrap();

        let response = recv_packet(&mut guest).await.unwrap();
        match response {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse { accepted: 0 }) => {}
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }
}