uidevices/mouse/
mod.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
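//! VMBus synthetic HID mouse device: negotiates the synthhid protocol with the
//! guest, describes the device with HID attributes and descriptors, and forwards
//! mouse input to the guest as HID input reports.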
5mod protocol;
6
7use async_trait::async_trait;
8use futures::StreamExt;
9use input_core::InputSource;
10use input_core::MouseData;
11use mesh::payload::Protobuf;
12use std::io::IoSlice;
13use std::pin::pin;
14use task_control::StopTask;
15use thiserror::Error;
16use vmbus_async::async_dgram::AsyncRecv;
17use vmbus_async::async_dgram::AsyncRecvExt;
18use vmbus_async::async_dgram::AsyncSend;
19use vmbus_async::async_dgram::AsyncSendExt;
20use vmbus_async::pipe::MessagePipe;
21use vmbus_channel::RawAsyncChannel;
22use vmbus_channel::bus::ChannelType;
23use vmbus_channel::bus::OfferParams;
24use vmbus_channel::channel::ChannelOpenError;
25use vmbus_channel::gpadl_ring::GpadlRingMem;
26use vmbus_channel::simple::SaveRestoreSimpleVmbusDevice;
27use vmbus_channel::simple::SimpleVmbusDevice;
28use vmbus_ring::RingMem;
29use vmcore::save_restore::SavedStateRoot;
30use zerocopy::FromBytes;
31use zerocopy::FromZeros;
32use zerocopy::Immutable;
33use zerocopy::IntoBytes;
34use zerocopy::KnownLayout;
35
36#[derive(Debug, Error)]
37enum Error {
38    #[error("channel i/o error")]
39    Io(#[source] std::io::Error),
40    #[error("received out of order packet")]
41    UnexpectedPacketOrder,
42    #[error("bad packet")]
43    BadPacket,
44    #[error("unknown message type")]
45    UnknownMessageType(u32),
46    #[error("accepting vmbus channel")]
47    Accept(#[from] vmbus_channel::offer::Error),
48}
49
/// A decoded guest-to-host message accepted by this device.
enum Request {
    /// The guest asked to use the protocol version carried in the payload.
    ProtocolRequest(u32),
    /// The guest acknowledged the device-info message.
    DeviceInfoAck,
}
54
55//HID consts- specific to setting up a HID mouse device
56const HID_DEVICE_ATTRIBUTES: protocol::HidAttributes = protocol::HidAttributes {
57    size: size_of::<protocol::HidAttributes>() as u32,
58    vendor_id: protocol::HID_VENDOR_ID,
59    product_id: protocol::HID_PRODUCT_ID,
60    version_id: protocol::HID_VERSION_ID,
61    padding: [0; 11],
62};
63
64const HID_DESCRIPTOR: protocol::HidDescriptor = protocol::HidDescriptor {
65    length: size_of::<protocol::HidDescriptor>() as u8,
66    descriptor_type: 0x21,
67    hid: 0x101,
68    country: 0x00,
69    num_descriptors: 1,
70    descriptor_list: protocol::HidDescriptorList {
71        report_type: 0x22,
72        report_length: 67,
73    },
74};
75
76const MSG_DEVICE_INFO_LENGTH: u32 = size_of::<protocol::HidAttributes>() as u32
77    + size_of::<protocol::HidDescriptor>() as u32
78    + HID_DESCRIPTOR.descriptor_list.report_length as u32;
79
80async fn recv_packet(reader: &mut (impl AsyncRecv + Unpin)) -> Result<Request, Error> {
81    let mut buf = [0; 64];
82    let n = match reader.recv(&mut buf).await {
83        Ok(n) => n,
84        Err(e) => return Err(Error::Io(e)),
85    };
86
87    let buf = &buf[..n];
88    let (header, buf) =
89        protocol::MessageHeader::read_from_prefix(buf).map_err(|_| Error::BadPacket)?; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
90    let request = match header.message_type {
91        protocol::SYNTHHID_PROTOCOL_REQUEST => {
92            let message = protocol::MessageProtocolRequest::read_from_prefix(buf)
93                .map_err(|_| Error::BadPacket)?
94                .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
95            Request::ProtocolRequest(message.version)
96        }
97        protocol::SYNTHHID_INIT_DEVICE_INFO_ACK => {
98            // We don't need the message contents, but we do still want to ensure it's valid.
99            let _message = protocol::MessageDeviceInfoAck::read_from_prefix(buf)
100                .map_err(|_| Error::BadPacket)?
101                .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
102            Request::DeviceInfoAck
103        }
104        typ => return Err(Error::UnknownMessageType(typ)),
105    };
106    Ok(request)
107}
108
109async fn send_packet<T: IntoBytes + Immutable + KnownLayout>(
110    writer: &mut (impl AsyncSend + Unpin),
111    typ: u32,
112    size: u32,
113    packet: &T,
114) -> Result<(), Error> {
115    match writer
116        .send_vectored(&[
117            IoSlice::new(
118                protocol::MessageHeader {
119                    message_type: typ,
120                    message_size: size,
121                }
122                .as_bytes(),
123            ),
124            IoSlice::new(packet.as_bytes()),
125        ])
126        .await
127    {
128        Ok(_) => Ok(()),
129        Err(e) => Err(Error::Io(e)),
130    }
131}
132
133/// Vmbus synthetic mouse device.
134pub struct Mouse {
135    source: Box<dyn InputSource<MouseData>>,
136}
137
138impl Mouse {
139    /// Creates a new mouse device.
140    pub fn new(source: Box<dyn InputSource<MouseData>>) -> Self {
141        Self { source }
142    }
143
144    /// Extracts the mouse input receiver.
145    pub fn into_source(self) -> Box<dyn InputSource<MouseData>> {
146        self.source
147    }
148}
149
150#[derive(Debug, Clone, Protobuf, Default)]
151#[mesh(package = "ui.synthmouse")]
152enum ChannelState {
153    #[mesh(1)]
154    #[default]
155    ReadVersion,
156    #[mesh(2)]
157    WriteVersion {
158        #[mesh(1)]
159        version: u32,
160    },
161    #[mesh(3)]
162    SendDeviceInfo {
163        #[mesh(1)]
164        version: u32,
165    },
166    #[mesh(4)]
167    ReadDeviceInfoAck {
168        #[mesh(1)]
169        version: u32,
170    },
171    #[mesh(5)]
172    Active {
173        #[mesh(1)]
174        version: u32,
175    },
176}
177
178/// Mouse saved state.
179#[derive(Protobuf, SavedStateRoot)]
180#[mesh(package = "ui.synthmouse")]
181pub struct SavedState(#[mesh(1)] ChannelState);
182
183/// The mouse task.
184pub struct MouseChannel<T: RingMem = GpadlRingMem> {
185    channel: MessagePipe<T>,
186    state: ChannelState,
187}
188
189#[async_trait]
190impl SimpleVmbusDevice for Mouse {
191    type Runner = MouseChannel;
192    type SavedState = SavedState;
193
194    fn offer(&self) -> OfferParams {
195        OfferParams {
196            interface_name: "mouse".to_owned(),
197            interface_id: protocol::INTERFACE_GUID,
198            instance_id: protocol::INSTANCE_GUID,
199            channel_type: ChannelType::Device { pipe_packets: true },
200            ..Default::default()
201        }
202    }
203
204    fn inspect(&mut self, req: inspect::Request<'_>, channel: Option<&mut MouseChannel>) {
205        let mut resp = req.respond();
206        if let Some(channel) = channel {
207            let (version, state) = match &channel.state {
208                ChannelState::ReadVersion => (None, "read_version"),
209                ChannelState::WriteVersion { version } => (Some(*version), "write_version"),
210                ChannelState::SendDeviceInfo { version } => (Some(*version), "send_device_info"),
211                ChannelState::ReadDeviceInfoAck { version } => {
212                    (Some(*version), "read_device_info_ack")
213                }
214                ChannelState::Active { version } => (Some(*version), "active"),
215            };
216            resp.field("state", state).field("version", version);
217        }
218    }
219
220    fn open(
221        &mut self,
222        channel: RawAsyncChannel<GpadlRingMem>,
223        _guest_memory: guestmem::GuestMemory,
224    ) -> Result<Self::Runner, ChannelOpenError> {
225        let pipe = MessagePipe::new(channel)?;
226        Ok(MouseChannel::new(pipe, ChannelState::default()))
227    }
228
229    async fn run(
230        &mut self,
231        stop: &mut StopTask<'_>,
232        channel: &mut MouseChannel,
233    ) -> Result<(), task_control::Cancelled> {
234        stop.until_stopped(async {
235            match channel.process(self).await {
236                Ok(()) => {}
237                Err(err) => tracing::error!(error = &err as &dyn std::error::Error, "mouse error"),
238            }
239        })
240        .await
241    }
242
243    async fn close(&mut self) {
244        self.source.set_active(false).await;
245    }
246
247    fn supports_save_restore(
248        &mut self,
249    ) -> Option<
250        &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
251    > {
252        Some(self)
253    }
254}
255
256impl SaveRestoreSimpleVmbusDevice for Mouse {
257    fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState {
258        SavedState(runner.state.clone())
259    }
260
261    fn restore_open(
262        &mut self,
263        state: Self::SavedState,
264        channel: RawAsyncChannel<GpadlRingMem>,
265    ) -> Result<Self::Runner, ChannelOpenError> {
266        let pipe = MessagePipe::new(channel)?;
267        Ok(MouseChannel::new(pipe, state.0))
268    }
269}
270
271impl<T: RingMem + Unpin> MouseChannel<T> {
272    fn new(channel: MessagePipe<T>, state: ChannelState) -> Self {
273        Self { channel, state }
274    }
275
276    //responds to input from the VNC server and sends mouse information to the guest
277    async fn process(&mut self, mouse: &mut Mouse) -> Result<(), Error> {
278        let (mut recv, mut send) = MessagePipe::split(&mut self.channel);
279
280        loop {
281            match self.state {
282                ChannelState::ReadVersion => {
283                    if let Request::ProtocolRequest(version) = recv_packet(&mut recv).await? {
284                        self.state = ChannelState::WriteVersion { version };
285                    } else {
286                        return Err(Error::UnexpectedPacketOrder);
287                    }
288                }
289                ChannelState::WriteVersion { version } => {
290                    let accepted = version == protocol::SYNTHHID_INPUT_VERSION;
291                    send_packet(
292                        &mut send,
293                        protocol::SYNTHHID_PROTOCOL_RESPONSE,
294                        size_of::<protocol::MessageProtocolResponse>() as u32,
295                        &protocol::MessageProtocolResponse {
296                            version_requested: version,
297                            accepted: accepted.into(),
298                        },
299                    )
300                    .await?;
301                    if accepted {
302                        tracelimit::info_ratelimited!(version, "mouse negotiated");
303                        self.state = ChannelState::SendDeviceInfo { version };
304                    } else {
305                        tracelimit::warn_ratelimited!(version, "unknown mouse version");
306                        self.state = ChannelState::ReadVersion;
307                    }
308                }
309                ChannelState::SendDeviceInfo { version } => {
310                    let mut aligned_report_descriptor = [0u8; 128];
311                    aligned_report_descriptor[..67].copy_from_slice(&protocol::REPORT_DESCRIPTOR);
312                    let device_info_packet = protocol::MessageDeviceInfo {
313                        device_attributes: HID_DEVICE_ATTRIBUTES,
314                        descriptor_info: HID_DESCRIPTOR,
315                        report_descriptor: aligned_report_descriptor,
316                    };
317                    send_packet(
318                        &mut send,
319                        protocol::SYNTHHID_INIT_DEVICE_INFO,
320                        MSG_DEVICE_INFO_LENGTH,
321                        &device_info_packet,
322                    )
323                    .await?;
324                    self.state = ChannelState::ReadDeviceInfoAck { version };
325                }
326                ChannelState::ReadDeviceInfoAck { version } => {
327                    if !matches!(recv_packet(&mut recv).await?, Request::DeviceInfoAck) {
328                        return Err(Error::UnexpectedPacketOrder);
329                    }
330                    tracelimit::info_ratelimited!("mouse HID device info sent and acknowledged");
331                    self.state = ChannelState::Active { version };
332                }
333                ChannelState::Active { version: _ } => {
334                    mouse.source.set_active(true).await;
335                    let send_fut = pin!(async {
336                        while let Some(mouse_data) = mouse.source.next().await {
337                            post_mouse_packet(mouse_data, &mut send).await?;
338                        }
339                        Ok(())
340                    });
341                    let recv_fut = pin!(async {
342                        recv_packet(&mut recv).await?;
343                        Result::<(), _>::Err(Error::UnexpectedPacketOrder)
344                    });
345
346                    futures::future::try_join(send_fut, recv_fut).await?;
347                }
348            }
349        }
350    }
351}
352
353// Transforms MouseData from the vnc server to an HID input report (mouse packet) by scaling coordinates and marking button flags
354async fn post_mouse_packet(
355    mouse_data: MouseData,
356    channel: &mut (impl AsyncSend + Unpin),
357) -> Result<(), Error> {
358    let mut scrolled = protocol::ScrollType::NoChange;
359    let mut mouse_packet: protocol::MousePacket = FromZeros::new_zeroed();
360    mouse_packet.x = mouse_data.x;
361    mouse_packet.y = mouse_data.y;
362
363    let button_masks = [
364        protocol::HID_MOUSE_BUTTON_LEFT,
365        protocol::HID_MOUSE_BUTTON_MIDDLE,
366        protocol::HID_MOUSE_BUTTON_RIGHT,
367    ];
368
369    #[expect(clippy::needless_range_loop)] // rare case of a clippy misfire
370    for i in 0..protocol::MOUSE_NUMBER_BUTTONS {
371        if ((1u8 << i) & mouse_data.button_mask) == (1u8 << i) {
372            if i < 3 {
373                mouse_packet.button_data |= button_masks[i];
374            }
375            if i == 3 {
376                //button 4 is a mouse wheel up click
377                scrolled = protocol::ScrollType::Up;
378            }
379            if i == 4 {
380                //button 5 is a mouse wheel down click
381                scrolled = protocol::ScrollType::Down;
382            }
383        }
384    }
385
386    //b/c we want to use the ScrollType enum to move the z in a + or - direction, we cast it into an i16
387    if scrolled as i16 != 0 {
388        mouse_packet.z = scrolled as i16;
389    }
390    send_packet(
391        channel,
392        protocol::SYNTHHID_PROTOCOL_INPUT_REPORT,
393        size_of::<protocol::MessageInputReport>() as u32,
394        &mouse_packet,
395    )
396    .await
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use input_core::mesh_input::input_pair;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use std::io::ErrorKind;
    use test_with_tracing::test;
    use vmbus_async::pipe::connected_message_pipes;

    /// A protocol version the device is expected to reject.
    const BAD_VERSION: u32 = 0xbadf00d;

    /// Host-to-guest packets these tests know how to decode.
    #[derive(Debug)]
    enum Packet {
        ProtocolResponse(protocol::MessageProtocolResponse),
        DeviceInfo(protocol::MessageDeviceInfo),
    }

    /// Reads and decodes one packet on the guest side. Returns `None` on a
    /// zero-length read and panics on any packet type not in [`Packet`].
    async fn recv_packet(read: &mut (dyn AsyncRecv + Unpin + Send)) -> Option<Packet> {
        let mut packet = [0; 256];
        let n = read.recv(&mut packet).await.unwrap();
        if n == 0 {
            return None;
        }
        let packet = &packet[..n];
        let (header, rest) = protocol::MessageHeader::read_from_prefix(packet).unwrap(); // TODO: zerocopy: unwrap (https://github.com/microsoft/openvmm/issues/759)
        Some(match header.message_type {
            protocol::SYNTHHID_PROTOCOL_RESPONSE => {
                Packet::ProtocolResponse(FromBytes::read_from_prefix(rest).unwrap().0)
                // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
            }
            protocol::SYNTHHID_INIT_DEVICE_INFO => {
                Packet::DeviceInfo(FromBytes::read_from_prefix(rest).unwrap().0)
                // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
            }
            _ => panic!("unknown packet type {}", header.message_type),
        })
    }

    /// Spawns the device's processing loop. A connection reset (the guest end
    /// being dropped) is treated as a clean shutdown.
    fn start_worker<T: RingMem + 'static + Unpin + Send + Sync>(
        driver: &DefaultDriver,
        mut mouse: Mouse,
        channel: MessagePipe<T>,
    ) -> Task<Result<(), Error>> {
        driver.spawn("mouse worker", async move {
            MouseChannel::new(channel, Default::default())
                .process(&mut mouse)
                .await
                .or_else(|e| match e {
                    Error::Io(err) if err.kind() == ErrorKind::ConnectionReset => Ok(()),
                    _ => Err(e),
                })
        })
    }

    /// Sends a protocol version request from the guest side.
    async fn request_version(guest: &mut (impl AsyncSend + Unpin), version: u32) {
        send_packet(
            guest,
            protocol::SYNTHHID_PROTOCOL_REQUEST,
            size_of::<protocol::MessageProtocolRequest>() as u32,
            &protocol::MessageProtocolRequest { version },
        )
        .await
        .unwrap();
    }

    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, mut guest) = connected_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Mouse::new(Box::new(source)), host);

        request_version(&mut guest, protocol::SYNTHHID_INPUT_VERSION).await;

        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: protocol::SYNTHHID_INPUT_VERSION,
                accepted: 1,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        match recv_packet(&mut guest).await.unwrap() {
            Packet::DeviceInfo(protocol::MessageDeviceInfo {
                device_attributes: _,
                descriptor_info: _,
                report_descriptor: _,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }

    #[async_test]
    async fn test_channel_negotiation_failed(driver: DefaultDriver) {
        let (host, mut guest) = connected_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Mouse::new(Box::new(source)), host);

        request_version(&mut guest, BAD_VERSION).await;

        // The device must echo the requested (unsupported) version and reject it.
        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: BAD_VERSION,
                accepted: 0,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }

    #[async_test]
    async fn test_channel_renegotiation_after_failure(driver: DefaultDriver) {
        let (host, mut guest) = connected_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Mouse::new(Box::new(source)), host);

        request_version(&mut guest, BAD_VERSION).await;
        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: BAD_VERSION,
                accepted: 0,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        // A rejected version returns the device to waiting for a new request,
        // so a supported version must still be accepted afterwards.
        request_version(&mut guest, protocol::SYNTHHID_INPUT_VERSION).await;
        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: protocol::SYNTHHID_INPUT_VERSION,
                accepted: 1,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }
        match recv_packet(&mut guest).await.unwrap() {
            Packet::DeviceInfo(_) => (),
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }
}