1mod protocol;
6
7use async_trait::async_trait;
8use futures::StreamExt;
9use input_core::InputSource;
10use input_core::MouseData;
11use mesh::payload::Protobuf;
12use std::io::IoSlice;
13use std::pin::pin;
14use task_control::StopTask;
15use thiserror::Error;
16use vmbus_async::async_dgram::AsyncRecv;
17use vmbus_async::async_dgram::AsyncRecvExt;
18use vmbus_async::async_dgram::AsyncSend;
19use vmbus_async::async_dgram::AsyncSendExt;
20use vmbus_async::pipe::MessagePipe;
21use vmbus_channel::RawAsyncChannel;
22use vmbus_channel::bus::ChannelType;
23use vmbus_channel::bus::OfferParams;
24use vmbus_channel::channel::ChannelOpenError;
25use vmbus_channel::gpadl_ring::GpadlRingMem;
26use vmbus_channel::simple::SaveRestoreSimpleVmbusDevice;
27use vmbus_channel::simple::SimpleVmbusDevice;
28use vmbus_ring::RingMem;
29use vmcore::save_restore::SavedStateRoot;
30use zerocopy::FromBytes;
31use zerocopy::FromZeros;
32use zerocopy::Immutable;
33use zerocopy::IntoBytes;
34use zerocopy::KnownLayout;
35
36#[derive(Debug, Error)]
37enum Error {
38 #[error("channel i/o error")]
39 Io(#[source] std::io::Error),
40 #[error("received out of order packet")]
41 UnexpectedPacketOrder,
42 #[error("bad packet")]
43 BadPacket,
44 #[error("unknown message type")]
45 UnknownMessageType(u32),
46 #[error("accepting vmbus channel")]
47 Accept(#[from] vmbus_channel::offer::Error),
48}
49
/// A decoded guest-to-host message that the channel state machine acts on.
enum Request {
    /// The guest acknowledged the device-info message
    /// (`SYNTHHID_INIT_DEVICE_INFO_ACK`).
    DeviceInfoAck,
    /// The guest asked to use the given protocol version
    /// (`SYNTHHID_PROTOCOL_REQUEST`).
    ProtocolRequest(u32),
}
54
55const HID_DEVICE_ATTRIBUTES: protocol::HidAttributes = protocol::HidAttributes {
57 size: size_of::<protocol::HidAttributes>() as u32,
58 vendor_id: protocol::HID_VENDOR_ID,
59 product_id: protocol::HID_PRODUCT_ID,
60 version_id: protocol::HID_VERSION_ID,
61 padding: [0; 11],
62};
63
64const HID_DESCRIPTOR: protocol::HidDescriptor = protocol::HidDescriptor {
65 length: size_of::<protocol::HidDescriptor>() as u8,
66 descriptor_type: 0x21,
67 hid: 0x101,
68 country: 0x00,
69 num_descriptors: 1,
70 descriptor_list: protocol::HidDescriptorList {
71 report_type: 0x22,
72 report_length: 67,
73 },
74};
75
76const MSG_DEVICE_INFO_LENGTH: u32 = size_of::<protocol::HidAttributes>() as u32
77 + size_of::<protocol::HidDescriptor>() as u32
78 + HID_DESCRIPTOR.descriptor_list.report_length as u32;
79
80async fn recv_packet(reader: &mut (impl AsyncRecv + Unpin)) -> Result<Request, Error> {
81 let mut buf = [0; 64];
82 let n = match reader.recv(&mut buf).await {
83 Ok(n) => n,
84 Err(e) => return Err(Error::Io(e)),
85 };
86
87 let buf = &buf[..n];
88 let (header, buf) =
89 protocol::MessageHeader::read_from_prefix(buf).map_err(|_| Error::BadPacket)?; let request = match header.message_type {
91 protocol::SYNTHHID_PROTOCOL_REQUEST => {
92 let message = protocol::MessageProtocolRequest::read_from_prefix(buf)
93 .map_err(|_| Error::BadPacket)?
94 .0; Request::ProtocolRequest(message.version)
96 }
97 protocol::SYNTHHID_INIT_DEVICE_INFO_ACK => {
98 let _message = protocol::MessageDeviceInfoAck::read_from_prefix(buf)
100 .map_err(|_| Error::BadPacket)?
101 .0; Request::DeviceInfoAck
103 }
104 typ => return Err(Error::UnknownMessageType(typ)),
105 };
106 Ok(request)
107}
108
109async fn send_packet<T: IntoBytes + Immutable + KnownLayout>(
110 writer: &mut (impl AsyncSend + Unpin),
111 typ: u32,
112 size: u32,
113 packet: &T,
114) -> Result<(), Error> {
115 match writer
116 .send_vectored(&[
117 IoSlice::new(
118 protocol::MessageHeader {
119 message_type: typ,
120 message_size: size,
121 }
122 .as_bytes(),
123 ),
124 IoSlice::new(packet.as_bytes()),
125 ])
126 .await
127 {
128 Ok(_) => Ok(()),
129 Err(e) => Err(Error::Io(e)),
130 }
131}
132
133pub struct Mouse {
135 source: Box<dyn InputSource<MouseData>>,
136}
137
138impl Mouse {
139 pub fn new(source: Box<dyn InputSource<MouseData>>) -> Self {
141 Self { source }
142 }
143
144 pub fn into_source(self) -> Box<dyn InputSource<MouseData>> {
146 self.source
147 }
148}
149
/// Position of a channel in the synthhid handshake state machine.
///
/// This value is what gets saved and restored for an open channel, so the
/// `#[mesh(N)]` tags presumably must stay stable to keep older saved states
/// restorable.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "ui.synthmouse")]
enum ChannelState {
    /// Waiting for the guest's protocol version request.
    #[mesh(1)]
    ReadVersion,
    /// A version request was received; the response still needs to be sent.
    #[mesh(2)]
    WriteVersion {
        /// Version requested by the guest.
        #[mesh(1)]
        version: u32,
    },
    /// The version was accepted; the HID device info still needs to be sent.
    #[mesh(3)]
    SendDeviceInfo {
        /// Negotiated protocol version.
        #[mesh(1)]
        version: u32,
    },
    /// Device info was sent; waiting for the guest's acknowledgement.
    #[mesh(4)]
    ReadDeviceInfoAck {
        /// Negotiated protocol version.
        #[mesh(1)]
        version: u32,
    },
    /// Handshake complete; mouse input is forwarded to the guest.
    #[mesh(5)]
    Active {
        /// Negotiated protocol version.
        #[mesh(1)]
        version: u32,
    },
}
176
177impl Default for ChannelState {
178 fn default() -> Self {
179 Self::ReadVersion
180 }
181}
182
/// Saved state of an open mouse channel: the runner's handshake state.
#[derive(Protobuf, SavedStateRoot)]
#[mesh(package = "ui.synthmouse")]
pub struct SavedState(#[mesh(1)] ChannelState);
187
188pub struct MouseChannel<T: RingMem = GpadlRingMem> {
190 channel: MessagePipe<T>,
191 state: ChannelState,
192}
193
194#[async_trait]
195impl SimpleVmbusDevice for Mouse {
196 type Runner = MouseChannel;
197 type SavedState = SavedState;
198
199 fn offer(&self) -> OfferParams {
200 OfferParams {
201 interface_name: "mouse".to_owned(),
202 interface_id: protocol::INTERFACE_GUID,
203 instance_id: protocol::INSTANCE_GUID,
204 channel_type: ChannelType::Device { pipe_packets: true },
205 ..Default::default()
206 }
207 }
208
209 fn inspect(&mut self, req: inspect::Request<'_>, channel: Option<&mut MouseChannel>) {
210 let mut resp = req.respond();
211 if let Some(channel) = channel {
212 let (version, state) = match &channel.state {
213 ChannelState::ReadVersion => (None, "read_version"),
214 ChannelState::WriteVersion { version } => (Some(*version), "write_version"),
215 ChannelState::SendDeviceInfo { version } => (Some(*version), "send_device_info"),
216 ChannelState::ReadDeviceInfoAck { version } => {
217 (Some(*version), "read_device_info_ack")
218 }
219 ChannelState::Active { version } => (Some(*version), "active"),
220 };
221 resp.field("state", state).field("version", version);
222 }
223 }
224
225 fn open(
226 &mut self,
227 channel: RawAsyncChannel<GpadlRingMem>,
228 _guest_memory: guestmem::GuestMemory,
229 ) -> Result<Self::Runner, ChannelOpenError> {
230 let pipe = MessagePipe::new(channel)?;
231 Ok(MouseChannel::new(pipe, ChannelState::default()))
232 }
233
234 async fn run(
235 &mut self,
236 stop: &mut StopTask<'_>,
237 channel: &mut MouseChannel,
238 ) -> Result<(), task_control::Cancelled> {
239 stop.until_stopped(async {
240 match channel.process(self).await {
241 Ok(()) => {}
242 Err(err) => tracing::error!(error = &err as &dyn std::error::Error, "mouse error"),
243 }
244 })
245 .await
246 }
247
248 async fn close(&mut self) {
249 self.source.set_active(false).await;
250 }
251
252 fn supports_save_restore(
253 &mut self,
254 ) -> Option<
255 &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
256 > {
257 Some(self)
258 }
259}
260
261impl SaveRestoreSimpleVmbusDevice for Mouse {
262 fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState {
263 SavedState(runner.state.clone())
264 }
265
266 fn restore_open(
267 &mut self,
268 state: Self::SavedState,
269 channel: RawAsyncChannel<GpadlRingMem>,
270 ) -> Result<Self::Runner, ChannelOpenError> {
271 let pipe = MessagePipe::new(channel)?;
272 Ok(MouseChannel::new(pipe, state.0))
273 }
274}
275
276impl<T: RingMem + Unpin> MouseChannel<T> {
277 fn new(channel: MessagePipe<T>, state: ChannelState) -> Self {
278 Self { channel, state }
279 }
280
281 async fn process(&mut self, mouse: &mut Mouse) -> Result<(), Error> {
283 let (mut recv, mut send) = MessagePipe::split(&mut self.channel);
284
285 loop {
286 match self.state {
287 ChannelState::ReadVersion => {
288 if let Request::ProtocolRequest(version) = recv_packet(&mut recv).await? {
289 self.state = ChannelState::WriteVersion { version };
290 } else {
291 return Err(Error::UnexpectedPacketOrder);
292 }
293 }
294 ChannelState::WriteVersion { version } => {
295 let accepted = version == protocol::SYNTHHID_INPUT_VERSION;
296 send_packet(
297 &mut send,
298 protocol::SYNTHHID_PROTOCOL_RESPONSE,
299 size_of::<protocol::MessageProtocolResponse>() as u32,
300 &protocol::MessageProtocolResponse {
301 version_requested: version,
302 accepted: accepted.into(),
303 },
304 )
305 .await?;
306 if accepted {
307 tracelimit::info_ratelimited!(version, "mouse negotiated");
308 self.state = ChannelState::SendDeviceInfo { version };
309 } else {
310 tracelimit::warn_ratelimited!(version, "unknown mouse version");
311 self.state = ChannelState::ReadVersion;
312 }
313 }
314 ChannelState::SendDeviceInfo { version } => {
315 let mut aligned_report_descriptor = [0u8; 128];
316 aligned_report_descriptor[..67].copy_from_slice(&protocol::REPORT_DESCRIPTOR);
317 let device_info_packet = protocol::MessageDeviceInfo {
318 device_attributes: HID_DEVICE_ATTRIBUTES,
319 descriptor_info: HID_DESCRIPTOR,
320 report_descriptor: aligned_report_descriptor,
321 };
322 send_packet(
323 &mut send,
324 protocol::SYNTHHID_INIT_DEVICE_INFO,
325 MSG_DEVICE_INFO_LENGTH,
326 &device_info_packet,
327 )
328 .await?;
329 self.state = ChannelState::ReadDeviceInfoAck { version };
330 }
331 ChannelState::ReadDeviceInfoAck { version } => {
332 if !matches!(recv_packet(&mut recv).await?, Request::DeviceInfoAck) {
333 return Err(Error::UnexpectedPacketOrder);
334 }
335 tracelimit::info_ratelimited!("mouse HID device info sent and acknowledged");
336 self.state = ChannelState::Active { version };
337 }
338 ChannelState::Active { version: _ } => {
339 mouse.source.set_active(true).await;
340 let send_fut = pin!(async {
341 while let Some(mouse_data) = mouse.source.next().await {
342 post_mouse_packet(mouse_data, &mut send).await?;
343 }
344 Ok(())
345 });
346 let recv_fut = pin!(async {
347 recv_packet(&mut recv).await?;
348 Result::<(), _>::Err(Error::UnexpectedPacketOrder)
349 });
350
351 futures::future::try_join(send_fut, recv_fut).await?;
352 }
353 }
354 }
355 }
356}
357
358async fn post_mouse_packet(
360 mouse_data: MouseData,
361 channel: &mut (impl AsyncSend + Unpin),
362) -> Result<(), Error> {
363 let mut scrolled = protocol::ScrollType::NoChange;
364 let mut mouse_packet: protocol::MousePacket = FromZeros::new_zeroed();
365 mouse_packet.x = mouse_data.x;
366 mouse_packet.y = mouse_data.y;
367
368 let button_masks = [
369 protocol::HID_MOUSE_BUTTON_LEFT,
370 protocol::HID_MOUSE_BUTTON_MIDDLE,
371 protocol::HID_MOUSE_BUTTON_RIGHT,
372 ];
373
374 #[expect(clippy::needless_range_loop)] for i in 0..protocol::MOUSE_NUMBER_BUTTONS {
376 if ((1u8 << i) & mouse_data.button_mask) == (1u8 << i) {
377 if i < 3 {
378 mouse_packet.button_data |= button_masks[i];
379 }
380 if i == 3 {
381 scrolled = protocol::ScrollType::Up;
383 }
384 if i == 4 {
385 scrolled = protocol::ScrollType::Down;
387 }
388 }
389 }
390
391 if scrolled as i16 != 0 {
393 mouse_packet.z = scrolled as i16;
394 }
395 send_packet(
396 channel,
397 protocol::SYNTHHID_PROTOCOL_INPUT_REPORT,
398 size_of::<protocol::MessageInputReport>() as u32,
399 &mouse_packet,
400 )
401 .await
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use input_core::mesh_input::input_pair;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use std::io::ErrorKind;
    use test_with_tracing::test;
    use vmbus_async::pipe::connected_message_pipes;

    /// A protocol version the device is expected to reject.
    const BAD_VERSION: u32 = 0xbadf00d;

    /// Host-to-guest messages the tests expect to observe.
    #[derive(Debug)]
    enum Packet {
        ProtocolResponse(protocol::MessageProtocolResponse),
        DeviceInfo(protocol::MessageDeviceInfo),
    }

    /// Guest-side receive. Returns `None` on a zero-length read, otherwise
    /// the decoded host message; panics on I/O errors or unexpected types.
    async fn recv_packet(read: &mut (dyn AsyncRecv + Unpin + Send)) -> Option<Packet> {
        let mut packet = [0; 256];
        let n = read.recv(&mut packet).await.unwrap();
        if n == 0 {
            return None;
        }
        let packet = &packet[..n];
        let (header, rest) = protocol::MessageHeader::read_from_prefix(packet).unwrap();
        Some(match header.message_type {
            protocol::SYNTHHID_PROTOCOL_RESPONSE => {
                Packet::ProtocolResponse(FromBytes::read_from_prefix(rest).unwrap().0)
            }
            protocol::SYNTHHID_INIT_DEVICE_INFO => {
                Packet::DeviceInfo(FromBytes::read_from_prefix(rest).unwrap().0)
            }
            _ => panic!("unknown packet type {}", header.message_type),
        })
    }

    /// Sends a guest `SYNTHHID_PROTOCOL_REQUEST` for `version`.
    async fn send_version_request(guest: &mut (impl AsyncSend + Unpin), version: u32) {
        send_packet(
            guest,
            protocol::SYNTHHID_PROTOCOL_REQUEST,
            size_of::<protocol::MessageProtocolRequest>() as u32,
            &protocol::MessageProtocolRequest { version },
        )
        .await
        .unwrap();
    }

    /// Runs the channel state machine on a spawned task. A connection reset
    /// (the guest end being dropped) counts as a clean exit.
    fn start_worker<T: RingMem + 'static + Unpin + Send + Sync>(
        driver: &DefaultDriver,
        mut mouse: Mouse,
        channel: MessagePipe<T>,
    ) -> Task<Result<(), Error>> {
        driver.spawn("mouse worker", async move {
            MouseChannel::new(channel, Default::default())
                .process(&mut mouse)
                .await
                .or_else(|e| match e {
                    Error::Io(err) if err.kind() == ErrorKind::ConnectionReset => Ok(()),
                    _ => Err(e),
                })
        })
    }

    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, mut guest) = connected_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Mouse::new(Box::new(source)), host);

        send_version_request(&mut guest, protocol::SYNTHHID_INPUT_VERSION).await;

        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: protocol::SYNTHHID_INPUT_VERSION,
                accepted: 1,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        match recv_packet(&mut guest).await.unwrap() {
            Packet::DeviceInfo(protocol::MessageDeviceInfo {
                device_attributes: _,
                descriptor_info: _,
                report_descriptor: _,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }

    #[async_test]
    async fn test_channel_negotiation_failed(driver: DefaultDriver) {
        let (host, mut guest) = connected_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Mouse::new(Box::new(source)), host);

        send_version_request(&mut guest, BAD_VERSION).await;

        // The device must echo the rejected version and refuse it.
        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: BAD_VERSION,
                accepted: 0,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }

    #[async_test]
    async fn test_channel_negotiation_retry(driver: DefaultDriver) {
        let (host, mut guest) = connected_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Mouse::new(Box::new(source)), host);

        send_version_request(&mut guest, BAD_VERSION).await;
        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: BAD_VERSION,
                accepted: 0,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        // After a rejection the device waits for another request.
        send_version_request(&mut guest, protocol::SYNTHHID_INPUT_VERSION).await;
        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: protocol::SYNTHHID_INPUT_VERSION,
                accepted: 1,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        match recv_packet(&mut guest).await.unwrap() {
            Packet::DeviceInfo(_) => (),
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }
}