1mod protocol;
6
7use async_trait::async_trait;
8use futures::StreamExt;
9use input_core::InputSource;
10use input_core::MouseData;
11use mesh::payload::Protobuf;
12use std::io::IoSlice;
13use std::pin::pin;
14use task_control::StopTask;
15use thiserror::Error;
16use vmbus_async::async_dgram::AsyncRecv;
17use vmbus_async::async_dgram::AsyncRecvExt;
18use vmbus_async::async_dgram::AsyncSend;
19use vmbus_async::async_dgram::AsyncSendExt;
20use vmbus_async::pipe::MessagePipe;
21use vmbus_channel::RawAsyncChannel;
22use vmbus_channel::bus::ChannelType;
23use vmbus_channel::bus::OfferParams;
24use vmbus_channel::channel::ChannelOpenError;
25use vmbus_channel::gpadl_ring::GpadlRingMem;
26use vmbus_channel::simple::SaveRestoreSimpleVmbusDevice;
27use vmbus_channel::simple::SimpleVmbusDevice;
28use vmbus_ring::RingMem;
29use vmcore::save_restore::SavedStateRoot;
30use zerocopy::FromBytes;
31use zerocopy::FromZeros;
32use zerocopy::Immutable;
33use zerocopy::IntoBytes;
34use zerocopy::KnownLayout;
35
36#[derive(Debug, Error)]
37enum Error {
38 #[error("channel i/o error")]
39 Io(#[source] std::io::Error),
40 #[error("received out of order packet")]
41 UnexpectedPacketOrder,
42 #[error("bad packet")]
43 BadPacket,
44 #[error("unknown message type")]
45 UnknownMessageType(u32),
46 #[error("accepting vmbus channel")]
47 Accept(#[from] vmbus_channel::offer::Error),
48}
49
/// A decoded guest-to-host message, as produced by `recv_packet`.
enum Request {
    /// `SYNTHHID_PROTOCOL_REQUEST`, carrying the protocol version the guest asked for.
    ProtocolRequest(u32),
    /// `SYNTHHID_INIT_DEVICE_INFO_ACK`: the guest acknowledged the device-info message.
    DeviceInfoAck,
}
54
/// HID device attributes placed at the start of the device-info message sent
/// to the guest during the handshake.
const HID_DEVICE_ATTRIBUTES: protocol::HidAttributes = protocol::HidAttributes {
    // Size of this attributes structure itself, in bytes.
    size: size_of::<protocol::HidAttributes>() as u32,
    vendor_id: protocol::HID_VENDOR_ID,
    product_id: protocol::HID_PRODUCT_ID,
    version_id: protocol::HID_VERSION_ID,
    padding: [0; 11],
};
63
/// HID class descriptor included in the device-info message alongside the
/// report descriptor.
///
/// NOTE(review): the magic values appear to follow the USB HID class
/// definition (0x21 = HID descriptor type, 0x22 = report descriptor type,
/// 0x101 = HID release 1.01) — confirm against the protocol definition.
const HID_DESCRIPTOR: protocol::HidDescriptor = protocol::HidDescriptor {
    length: size_of::<protocol::HidDescriptor>() as u8,
    descriptor_type: 0x21,
    hid: 0x101,
    country: 0x00,
    num_descriptors: 1,
    descriptor_list: protocol::HidDescriptorList {
        report_type: 0x22,
        // Number of meaningful report-descriptor bytes; must agree with the
        // bytes copied into the device-info message and with
        // `MSG_DEVICE_INFO_LENGTH`.
        report_length: 67,
    },
};
75
/// Payload size written into the header of the device-info message.
///
/// Counts the attributes, the HID descriptor, and only the advertised
/// `report_length` bytes of report descriptor. It does not count the full
/// fixed-size report buffer that is actually transmitted.
const MSG_DEVICE_INFO_LENGTH: u32 = size_of::<protocol::HidAttributes>() as u32
    + size_of::<protocol::HidDescriptor>() as u32
    + HID_DESCRIPTOR.descriptor_list.report_length as u32;
79
80async fn recv_packet(reader: &mut (impl AsyncRecv + Unpin)) -> Result<Request, Error> {
81 let mut buf = [0; 64];
82 let n = match reader.recv(&mut buf).await {
83 Ok(n) => n,
84 Err(e) => return Err(Error::Io(e)),
85 };
86
87 let buf = &buf[..n];
88 let (header, buf) =
89 protocol::MessageHeader::read_from_prefix(buf).map_err(|_| Error::BadPacket)?; let request = match header.message_type {
91 protocol::SYNTHHID_PROTOCOL_REQUEST => {
92 let message = protocol::MessageProtocolRequest::read_from_prefix(buf)
93 .map_err(|_| Error::BadPacket)?
94 .0; Request::ProtocolRequest(message.version)
96 }
97 protocol::SYNTHHID_INIT_DEVICE_INFO_ACK => {
98 let _message = protocol::MessageDeviceInfoAck::read_from_prefix(buf)
100 .map_err(|_| Error::BadPacket)?
101 .0; Request::DeviceInfoAck
103 }
104 typ => return Err(Error::UnknownMessageType(typ)),
105 };
106 Ok(request)
107}
108
109async fn send_packet<T: IntoBytes + Immutable + KnownLayout>(
110 writer: &mut (impl AsyncSend + Unpin),
111 typ: u32,
112 size: u32,
113 packet: &T,
114) -> Result<(), Error> {
115 match writer
116 .send_vectored(&[
117 IoSlice::new(
118 protocol::MessageHeader {
119 message_type: typ,
120 message_size: size,
121 }
122 .as_bytes(),
123 ),
124 IoSlice::new(packet.as_bytes()),
125 ])
126 .await
127 {
128 Ok(_) => Ok(()),
129 Err(e) => Err(Error::Io(e)),
130 }
131}
132
133pub struct Mouse {
135 source: Box<dyn InputSource<MouseData>>,
136}
137
138impl Mouse {
139 pub fn new(source: Box<dyn InputSource<MouseData>>) -> Self {
141 Self { source }
142 }
143
144 pub fn into_source(self) -> Box<dyn InputSource<MouseData>> {
146 self.source
147 }
148}
149
/// Protocol state of an open mouse channel.
///
/// This value is persisted through `SavedState`. Changing the `mesh` variant or
/// field numbers would presumably break restoring previously saved state.
#[derive(Debug, Clone, Protobuf, Default)]
#[mesh(package = "ui.synthmouse")]
enum ChannelState {
    /// Waiting for the guest's protocol version request (initial state).
    #[mesh(1)]
    #[default]
    ReadVersion,
    /// The guest requested `version`; the accept/reject response is still to be sent.
    #[mesh(2)]
    WriteVersion {
        #[mesh(1)]
        version: u32,
    },
    /// `version` was accepted; the device-info message is still to be sent.
    #[mesh(3)]
    SendDeviceInfo {
        #[mesh(1)]
        version: u32,
    },
    /// Device info was sent; waiting for the guest's acknowledgement.
    #[mesh(4)]
    ReadDeviceInfoAck {
        #[mesh(1)]
        version: u32,
    },
    /// Handshake complete; mouse input is forwarded to the guest.
    #[mesh(5)]
    Active {
        #[mesh(1)]
        version: u32,
    },
}
177
/// Saved state of an open mouse channel: just its current protocol state.
#[derive(Protobuf, SavedStateRoot)]
#[mesh(package = "ui.synthmouse")]
pub struct SavedState(#[mesh(1)] ChannelState);
182
/// Runner for an open mouse channel.
pub struct MouseChannel<T: RingMem = GpadlRingMem> {
    // Message pipe to the guest.
    channel: MessagePipe<T>,
    // Current protocol state; captured on save and resumed from on restore.
    state: ChannelState,
}
188
189#[async_trait]
190impl SimpleVmbusDevice for Mouse {
191 type Runner = MouseChannel;
192 type SavedState = SavedState;
193
194 fn offer(&self) -> OfferParams {
195 OfferParams {
196 interface_name: "mouse".to_owned(),
197 interface_id: protocol::INTERFACE_GUID,
198 instance_id: protocol::INSTANCE_GUID,
199 channel_type: ChannelType::Device { pipe_packets: true },
200 ..Default::default()
201 }
202 }
203
204 fn inspect(&mut self, req: inspect::Request<'_>, channel: Option<&mut MouseChannel>) {
205 let mut resp = req.respond();
206 if let Some(channel) = channel {
207 let (version, state) = match &channel.state {
208 ChannelState::ReadVersion => (None, "read_version"),
209 ChannelState::WriteVersion { version } => (Some(*version), "write_version"),
210 ChannelState::SendDeviceInfo { version } => (Some(*version), "send_device_info"),
211 ChannelState::ReadDeviceInfoAck { version } => {
212 (Some(*version), "read_device_info_ack")
213 }
214 ChannelState::Active { version } => (Some(*version), "active"),
215 };
216 resp.field("state", state).field("version", version);
217 }
218 }
219
220 fn open(
221 &mut self,
222 channel: RawAsyncChannel<GpadlRingMem>,
223 _guest_memory: guestmem::GuestMemory,
224 ) -> Result<Self::Runner, ChannelOpenError> {
225 let pipe = MessagePipe::new(channel)?;
226 Ok(MouseChannel::new(pipe, ChannelState::default()))
227 }
228
229 async fn run(
230 &mut self,
231 stop: &mut StopTask<'_>,
232 channel: &mut MouseChannel,
233 ) -> Result<(), task_control::Cancelled> {
234 stop.until_stopped(async {
235 match channel.process(self).await {
236 Ok(()) => {}
237 Err(err) => tracing::error!(error = &err as &dyn std::error::Error, "mouse error"),
238 }
239 })
240 .await
241 }
242
243 async fn close(&mut self) {
244 self.source.set_active(false).await;
245 }
246
247 fn supports_save_restore(
248 &mut self,
249 ) -> Option<
250 &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
251 > {
252 Some(self)
253 }
254}
255
256impl SaveRestoreSimpleVmbusDevice for Mouse {
257 fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState {
258 SavedState(runner.state.clone())
259 }
260
261 fn restore_open(
262 &mut self,
263 state: Self::SavedState,
264 channel: RawAsyncChannel<GpadlRingMem>,
265 ) -> Result<Self::Runner, ChannelOpenError> {
266 let pipe = MessagePipe::new(channel)?;
267 Ok(MouseChannel::new(pipe, state.0))
268 }
269}
270
271impl<T: RingMem + Unpin> MouseChannel<T> {
272 fn new(channel: MessagePipe<T>, state: ChannelState) -> Self {
273 Self { channel, state }
274 }
275
276 async fn process(&mut self, mouse: &mut Mouse) -> Result<(), Error> {
278 let (mut recv, mut send) = MessagePipe::split(&mut self.channel);
279
280 loop {
281 match self.state {
282 ChannelState::ReadVersion => {
283 if let Request::ProtocolRequest(version) = recv_packet(&mut recv).await? {
284 self.state = ChannelState::WriteVersion { version };
285 } else {
286 return Err(Error::UnexpectedPacketOrder);
287 }
288 }
289 ChannelState::WriteVersion { version } => {
290 let accepted = version == protocol::SYNTHHID_INPUT_VERSION;
291 send_packet(
292 &mut send,
293 protocol::SYNTHHID_PROTOCOL_RESPONSE,
294 size_of::<protocol::MessageProtocolResponse>() as u32,
295 &protocol::MessageProtocolResponse {
296 version_requested: version,
297 accepted: accepted.into(),
298 },
299 )
300 .await?;
301 if accepted {
302 tracelimit::info_ratelimited!(version, "mouse negotiated");
303 self.state = ChannelState::SendDeviceInfo { version };
304 } else {
305 tracelimit::warn_ratelimited!(version, "unknown mouse version");
306 self.state = ChannelState::ReadVersion;
307 }
308 }
309 ChannelState::SendDeviceInfo { version } => {
310 let mut aligned_report_descriptor = [0u8; 128];
311 aligned_report_descriptor[..67].copy_from_slice(&protocol::REPORT_DESCRIPTOR);
312 let device_info_packet = protocol::MessageDeviceInfo {
313 device_attributes: HID_DEVICE_ATTRIBUTES,
314 descriptor_info: HID_DESCRIPTOR,
315 report_descriptor: aligned_report_descriptor,
316 };
317 send_packet(
318 &mut send,
319 protocol::SYNTHHID_INIT_DEVICE_INFO,
320 MSG_DEVICE_INFO_LENGTH,
321 &device_info_packet,
322 )
323 .await?;
324 self.state = ChannelState::ReadDeviceInfoAck { version };
325 }
326 ChannelState::ReadDeviceInfoAck { version } => {
327 if !matches!(recv_packet(&mut recv).await?, Request::DeviceInfoAck) {
328 return Err(Error::UnexpectedPacketOrder);
329 }
330 tracelimit::info_ratelimited!("mouse HID device info sent and acknowledged");
331 self.state = ChannelState::Active { version };
332 }
333 ChannelState::Active { version: _ } => {
334 mouse.source.set_active(true).await;
335 let send_fut = pin!(async {
336 while let Some(mouse_data) = mouse.source.next().await {
337 post_mouse_packet(mouse_data, &mut send).await?;
338 }
339 Ok(())
340 });
341 let recv_fut = pin!(async {
342 recv_packet(&mut recv).await?;
343 Result::<(), _>::Err(Error::UnexpectedPacketOrder)
344 });
345
346 futures::future::try_join(send_fut, recv_fut).await?;
347 }
348 }
349 }
350 }
351}
352
353async fn post_mouse_packet(
355 mouse_data: MouseData,
356 channel: &mut (impl AsyncSend + Unpin),
357) -> Result<(), Error> {
358 let mut scrolled = protocol::ScrollType::NoChange;
359 let mut mouse_packet: protocol::MousePacket = FromZeros::new_zeroed();
360 mouse_packet.x = mouse_data.x;
361 mouse_packet.y = mouse_data.y;
362
363 let button_masks = [
364 protocol::HID_MOUSE_BUTTON_LEFT,
365 protocol::HID_MOUSE_BUTTON_MIDDLE,
366 protocol::HID_MOUSE_BUTTON_RIGHT,
367 ];
368
369 #[expect(clippy::needless_range_loop)] for i in 0..protocol::MOUSE_NUMBER_BUTTONS {
371 if ((1u8 << i) & mouse_data.button_mask) == (1u8 << i) {
372 if i < 3 {
373 mouse_packet.button_data |= button_masks[i];
374 }
375 if i == 3 {
376 scrolled = protocol::ScrollType::Up;
378 }
379 if i == 4 {
380 scrolled = protocol::ScrollType::Down;
382 }
383 }
384 }
385
386 if scrolled as i16 != 0 {
388 mouse_packet.z = scrolled as i16;
389 }
390 send_packet(
391 channel,
392 protocol::SYNTHHID_PROTOCOL_INPUT_REPORT,
393 size_of::<protocol::MessageInputReport>() as u32,
394 &mouse_packet,
395 )
396 .await
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use input_core::mesh_input::input_pair;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use std::io::ErrorKind;
    use test_with_tracing::test;
    use vmbus_async::pipe::connected_message_pipes;

    /// Host-to-guest messages the tests expect to observe.
    #[derive(Debug)]
    enum Packet {
        ProtocolResponse(protocol::MessageProtocolResponse),
        DeviceInfo(protocol::MessageDeviceInfo),
    }

    /// Guest-side receive: reads one message and decodes it.
    ///
    /// Returns `None` on a zero-length read and panics on anything
    /// unparseable or of an unexpected type.
    async fn recv_packet(read: &mut (dyn AsyncRecv + Unpin + Send)) -> Option<Packet> {
        let mut packet = [0; 256];
        let n = read.recv(&mut packet).await.unwrap();
        if n == 0 {
            return None;
        }
        let packet = &packet[..n];
        let (header, rest) = protocol::MessageHeader::read_from_prefix(packet).unwrap();
        Some(match header.message_type {
            protocol::SYNTHHID_PROTOCOL_RESPONSE => {
                Packet::ProtocolResponse(FromBytes::read_from_prefix(rest).unwrap().0)
            }
            protocol::SYNTHHID_INIT_DEVICE_INFO => {
                Packet::DeviceInfo(FromBytes::read_from_prefix(rest).unwrap().0)
            }
            _ => panic!("unknown packet type {}", header.message_type),
        })
    }

    /// Spawns the host-side channel processor.
    ///
    /// A connection reset (the guest end being dropped) is treated as a clean
    /// exit.
    fn start_worker<T: RingMem + 'static + Unpin + Send + Sync>(
        driver: &DefaultDriver,
        mut mouse: Mouse,
        channel: MessagePipe<T>,
    ) -> Task<Result<(), Error>> {
        driver.spawn("mouse worker", async move {
            MouseChannel::new(channel, Default::default())
                .process(&mut mouse)
                .await
                .or_else(|e| match e {
                    Error::Io(err) if err.kind() == ErrorKind::ConnectionReset => Ok(()),
                    _ => Err(e),
                })
        })
    }

    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, mut guest) = connected_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Mouse::new(Box::new(source)), host);

        send_packet(
            &mut guest,
            protocol::SYNTHHID_PROTOCOL_REQUEST,
            size_of::<protocol::MessageProtocolRequest>() as u32,
            &protocol::MessageProtocolRequest {
                version: protocol::SYNTHHID_INPUT_VERSION,
            },
        )
        .await
        .unwrap();

        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: protocol::SYNTHHID_INPUT_VERSION,
                accepted: 1,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        match recv_packet(&mut guest).await.unwrap() {
            Packet::DeviceInfo(protocol::MessageDeviceInfo {
                device_attributes: _,
                descriptor_info: _,
                report_descriptor: _,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }

    #[async_test]
    async fn test_channel_negotiation_failed(driver: DefaultDriver) {
        // Any version other than SYNTHHID_INPUT_VERSION must be rejected.
        const BAD_VERSION: u32 = 0xbadf00d;

        let (host, mut guest) = connected_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Mouse::new(Box::new(source)), host);

        send_packet(
            &mut guest,
            protocol::SYNTHHID_PROTOCOL_REQUEST,
            size_of::<protocol::MessageProtocolRequest>() as u32,
            &protocol::MessageProtocolRequest {
                version: BAD_VERSION,
            },
        )
        .await
        .unwrap();

        // The host echoes back the version the guest asked for and marks it
        // as not accepted.
        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: BAD_VERSION,
                accepted: 0,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }
}