uidevices/mouse/
mod.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
//! Vmbus synthetic HID mouse device: negotiates the synthhid protocol with the guest, sends the HID device description, and then forwards mouse input to the guest as HID input reports.
5mod protocol;
6
7use async_trait::async_trait;
8use futures::StreamExt;
9use input_core::InputSource;
10use input_core::MouseData;
11use mesh::payload::Protobuf;
12use std::io::IoSlice;
13use std::pin::pin;
14use task_control::StopTask;
15use thiserror::Error;
16use vmbus_async::async_dgram::AsyncRecv;
17use vmbus_async::async_dgram::AsyncRecvExt;
18use vmbus_async::async_dgram::AsyncSend;
19use vmbus_async::async_dgram::AsyncSendExt;
20use vmbus_async::pipe::MessagePipe;
21use vmbus_channel::RawAsyncChannel;
22use vmbus_channel::bus::ChannelType;
23use vmbus_channel::bus::OfferParams;
24use vmbus_channel::channel::ChannelOpenError;
25use vmbus_channel::gpadl_ring::GpadlRingMem;
26use vmbus_channel::simple::SaveRestoreSimpleVmbusDevice;
27use vmbus_channel::simple::SimpleVmbusDevice;
28use vmbus_ring::RingMem;
29use vmcore::save_restore::SavedStateRoot;
30use zerocopy::FromBytes;
31use zerocopy::FromZeros;
32use zerocopy::Immutable;
33use zerocopy::IntoBytes;
34use zerocopy::KnownLayout;
35
36#[derive(Debug, Error)]
37enum Error {
38    #[error("channel i/o error")]
39    Io(#[source] std::io::Error),
40    #[error("received out of order packet")]
41    UnexpectedPacketOrder,
42    #[error("bad packet")]
43    BadPacket,
44    #[error("unknown message type")]
45    UnknownMessageType(u32),
46    #[error("accepting vmbus channel")]
47    Accept(#[from] vmbus_channel::offer::Error),
48}
49
/// A guest-to-host message recognized by `recv_packet`.
enum Request {
    /// The guest asks to use the given protocol version.
    ProtocolRequest(u32),
    /// The guest acknowledged the device-info message.
    DeviceInfoAck,
}
54
// HID constants that describe the synthetic mouse to the guest.

/// Device attributes sent in the device-info message during the handshake.
///
/// `size` is the byte size of `protocol::HidAttributes` itself; the vendor,
/// product, and version IDs come from the `protocol` module.
const HID_DEVICE_ATTRIBUTES: protocol::HidAttributes = protocol::HidAttributes {
    size: size_of::<protocol::HidAttributes>() as u32,
    vendor_id: protocol::HID_VENDOR_ID,
    product_id: protocol::HID_PRODUCT_ID,
    version_id: protocol::HID_VERSION_ID,
    padding: [0; 11],
};
63
/// HID class descriptor sent in the device-info message.
///
/// NOTE(review): the numeric fields appear to follow the USB HID class
/// descriptor layout (0x21 = HID descriptor type, 0x101 = HID spec 1.01,
/// 0x22 = report descriptor type); confirm against the `protocol` module.
const HID_DESCRIPTOR: protocol::HidDescriptor = protocol::HidDescriptor {
    length: size_of::<protocol::HidDescriptor>() as u8,
    descriptor_type: 0x21,
    hid: 0x101,
    country: 0x00,
    num_descriptors: 1,
    descriptor_list: protocol::HidDescriptorList {
        report_type: 0x22,
        // Must equal the length of `protocol::REPORT_DESCRIPTOR`, which is the
        // report descriptor sent to the guest alongside this header.
        report_length: 67,
    },
};
75
/// `message_size` advertised in the header of the device-info message.
///
/// This is the size of the attributes, plus the size of the HID descriptor,
/// plus `report_length` bytes of report descriptor.
///
/// NOTE(review): `send_packet` transmits every byte of the `MessageDeviceInfo`
/// value. Its report descriptor is filled from a zero-padded 128-byte buffer,
/// so the bytes on the wire can exceed this advertised size. Presumably this
/// is intentional; confirm.
const MSG_DEVICE_INFO_LENGTH: u32 = size_of::<protocol::HidAttributes>() as u32
    + size_of::<protocol::HidDescriptor>() as u32
    + HID_DESCRIPTOR.descriptor_list.report_length as u32;
79
80async fn recv_packet(reader: &mut (impl AsyncRecv + Unpin)) -> Result<Request, Error> {
81    let mut buf = [0; 64];
82    let n = match reader.recv(&mut buf).await {
83        Ok(n) => n,
84        Err(e) => return Err(Error::Io(e)),
85    };
86
87    let buf = &buf[..n];
88    let (header, buf) =
89        protocol::MessageHeader::read_from_prefix(buf).map_err(|_| Error::BadPacket)?; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
90    let request = match header.message_type {
91        protocol::SYNTHHID_PROTOCOL_REQUEST => {
92            let message = protocol::MessageProtocolRequest::read_from_prefix(buf)
93                .map_err(|_| Error::BadPacket)?
94                .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
95            Request::ProtocolRequest(message.version)
96        }
97        protocol::SYNTHHID_INIT_DEVICE_INFO_ACK => {
98            // We don't need the message contents, but we do still want to ensure it's valid.
99            let _message = protocol::MessageDeviceInfoAck::read_from_prefix(buf)
100                .map_err(|_| Error::BadPacket)?
101                .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
102            Request::DeviceInfoAck
103        }
104        typ => return Err(Error::UnknownMessageType(typ)),
105    };
106    Ok(request)
107}
108
109async fn send_packet<T: IntoBytes + Immutable + KnownLayout>(
110    writer: &mut (impl AsyncSend + Unpin),
111    typ: u32,
112    size: u32,
113    packet: &T,
114) -> Result<(), Error> {
115    match writer
116        .send_vectored(&[
117            IoSlice::new(
118                protocol::MessageHeader {
119                    message_type: typ,
120                    message_size: size,
121                }
122                .as_bytes(),
123            ),
124            IoSlice::new(packet.as_bytes()),
125        ])
126        .await
127    {
128        Ok(_) => Ok(()),
129        Err(e) => Err(Error::Io(e)),
130    }
131}
132
133/// Vmbus synthetic mouse device.
134pub struct Mouse {
135    source: Box<dyn InputSource<MouseData>>,
136}
137
138impl Mouse {
139    /// Creates a new mouse device.
140    pub fn new(source: Box<dyn InputSource<MouseData>>) -> Self {
141        Self { source }
142    }
143
144    /// Extracts the mouse input receiver.
145    pub fn into_source(self) -> Box<dyn InputSource<MouseData>> {
146        self.source
147    }
148}
149
/// Progress of the synthhid handshake on one channel, as driven by
/// `MouseChannel::process`.
///
/// This enum is also the device's saved state. The `#[mesh]` numbers
/// presumably define the saved-state encoding and should not be renumbered.
#[derive(Debug, Clone, Protobuf)]
#[mesh(package = "ui.synthmouse")]
enum ChannelState {
    /// Waiting for the guest's protocol version request (the default state).
    #[mesh(1)]
    ReadVersion,
    /// The guest requested `version`; the protocol response has not been sent
    /// yet.
    ///
    /// The version is accepted only if it equals
    /// `protocol::SYNTHHID_INPUT_VERSION`; otherwise the channel returns to
    /// `ReadVersion`.
    #[mesh(2)]
    WriteVersion {
        #[mesh(1)]
        version: u32,
    },
    /// `version` was accepted; the device-info message has not been sent yet.
    #[mesh(3)]
    SendDeviceInfo {
        #[mesh(1)]
        version: u32,
    },
    /// Device info was sent; waiting for the guest's acknowledgement.
    #[mesh(4)]
    ReadDeviceInfoAck {
        #[mesh(1)]
        version: u32,
    },
    /// Handshake complete; mouse input is being forwarded to the guest.
    #[mesh(5)]
    Active {
        #[mesh(1)]
        version: u32,
    },
}
176
177impl Default for ChannelState {
178    fn default() -> Self {
179        Self::ReadVersion
180    }
181}
182
/// Mouse saved state.
///
/// Holds only the channel's handshake state (`ChannelState`). It is produced
/// by `save_open` and consumed by `restore_open`.
#[derive(Protobuf, SavedStateRoot)]
#[mesh(package = "ui.synthmouse")]
pub struct SavedState(#[mesh(1)] ChannelState);
187
/// The mouse task.
///
/// Per-channel runner: the vmbus message pipe to the guest plus the current
/// handshake state. It is created by `open`/`restore_open` and driven by
/// `process`.
pub struct MouseChannel<T: RingMem = GpadlRingMem> {
    // Message pipe connected to the guest.
    channel: MessagePipe<T>,
    // Current handshake state; snapshotted by `save_open`.
    state: ChannelState,
}
193
194#[async_trait]
195impl SimpleVmbusDevice for Mouse {
196    type Runner = MouseChannel;
197    type SavedState = SavedState;
198
199    fn offer(&self) -> OfferParams {
200        OfferParams {
201            interface_name: "mouse".to_owned(),
202            interface_id: protocol::INTERFACE_GUID,
203            instance_id: protocol::INSTANCE_GUID,
204            channel_type: ChannelType::Device { pipe_packets: true },
205            ..Default::default()
206        }
207    }
208
209    fn inspect(&mut self, req: inspect::Request<'_>, channel: Option<&mut MouseChannel>) {
210        let mut resp = req.respond();
211        if let Some(channel) = channel {
212            let (version, state) = match &channel.state {
213                ChannelState::ReadVersion => (None, "read_version"),
214                ChannelState::WriteVersion { version } => (Some(*version), "write_version"),
215                ChannelState::SendDeviceInfo { version } => (Some(*version), "send_device_info"),
216                ChannelState::ReadDeviceInfoAck { version } => {
217                    (Some(*version), "read_device_info_ack")
218                }
219                ChannelState::Active { version } => (Some(*version), "active"),
220            };
221            resp.field("state", state).field("version", version);
222        }
223    }
224
225    fn open(
226        &mut self,
227        channel: RawAsyncChannel<GpadlRingMem>,
228        _guest_memory: guestmem::GuestMemory,
229    ) -> Result<Self::Runner, ChannelOpenError> {
230        let pipe = MessagePipe::new(channel)?;
231        Ok(MouseChannel::new(pipe, ChannelState::default()))
232    }
233
234    async fn run(
235        &mut self,
236        stop: &mut StopTask<'_>,
237        channel: &mut MouseChannel,
238    ) -> Result<(), task_control::Cancelled> {
239        stop.until_stopped(async {
240            match channel.process(self).await {
241                Ok(()) => {}
242                Err(err) => tracing::error!(error = &err as &dyn std::error::Error, "mouse error"),
243            }
244        })
245        .await
246    }
247
248    async fn close(&mut self) {
249        self.source.set_active(false).await;
250    }
251
252    fn supports_save_restore(
253        &mut self,
254    ) -> Option<
255        &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
256    > {
257        Some(self)
258    }
259}
260
261impl SaveRestoreSimpleVmbusDevice for Mouse {
262    fn save_open(&mut self, runner: &Self::Runner) -> Self::SavedState {
263        SavedState(runner.state.clone())
264    }
265
266    fn restore_open(
267        &mut self,
268        state: Self::SavedState,
269        channel: RawAsyncChannel<GpadlRingMem>,
270    ) -> Result<Self::Runner, ChannelOpenError> {
271        let pipe = MessagePipe::new(channel)?;
272        Ok(MouseChannel::new(pipe, state.0))
273    }
274}
275
276impl<T: RingMem + Unpin> MouseChannel<T> {
277    fn new(channel: MessagePipe<T>, state: ChannelState) -> Self {
278        Self { channel, state }
279    }
280
281    //responds to input from the VNC server and sends mouse information to the guest
282    async fn process(&mut self, mouse: &mut Mouse) -> Result<(), Error> {
283        let (mut recv, mut send) = MessagePipe::split(&mut self.channel);
284
285        loop {
286            match self.state {
287                ChannelState::ReadVersion => {
288                    if let Request::ProtocolRequest(version) = recv_packet(&mut recv).await? {
289                        self.state = ChannelState::WriteVersion { version };
290                    } else {
291                        return Err(Error::UnexpectedPacketOrder);
292                    }
293                }
294                ChannelState::WriteVersion { version } => {
295                    let accepted = version == protocol::SYNTHHID_INPUT_VERSION;
296                    send_packet(
297                        &mut send,
298                        protocol::SYNTHHID_PROTOCOL_RESPONSE,
299                        size_of::<protocol::MessageProtocolResponse>() as u32,
300                        &protocol::MessageProtocolResponse {
301                            version_requested: version,
302                            accepted: accepted.into(),
303                        },
304                    )
305                    .await?;
306                    if accepted {
307                        tracelimit::info_ratelimited!(version, "mouse negotiated");
308                        self.state = ChannelState::SendDeviceInfo { version };
309                    } else {
310                        tracelimit::warn_ratelimited!(version, "unknown mouse version");
311                        self.state = ChannelState::ReadVersion;
312                    }
313                }
314                ChannelState::SendDeviceInfo { version } => {
315                    let mut aligned_report_descriptor = [0u8; 128];
316                    aligned_report_descriptor[..67].copy_from_slice(&protocol::REPORT_DESCRIPTOR);
317                    let device_info_packet = protocol::MessageDeviceInfo {
318                        device_attributes: HID_DEVICE_ATTRIBUTES,
319                        descriptor_info: HID_DESCRIPTOR,
320                        report_descriptor: aligned_report_descriptor,
321                    };
322                    send_packet(
323                        &mut send,
324                        protocol::SYNTHHID_INIT_DEVICE_INFO,
325                        MSG_DEVICE_INFO_LENGTH,
326                        &device_info_packet,
327                    )
328                    .await?;
329                    self.state = ChannelState::ReadDeviceInfoAck { version };
330                }
331                ChannelState::ReadDeviceInfoAck { version } => {
332                    if !matches!(recv_packet(&mut recv).await?, Request::DeviceInfoAck) {
333                        return Err(Error::UnexpectedPacketOrder);
334                    }
335                    tracelimit::info_ratelimited!("mouse HID device info sent and acknowledged");
336                    self.state = ChannelState::Active { version };
337                }
338                ChannelState::Active { version: _ } => {
339                    mouse.source.set_active(true).await;
340                    let send_fut = pin!(async {
341                        while let Some(mouse_data) = mouse.source.next().await {
342                            post_mouse_packet(mouse_data, &mut send).await?;
343                        }
344                        Ok(())
345                    });
346                    let recv_fut = pin!(async {
347                        recv_packet(&mut recv).await?;
348                        Result::<(), _>::Err(Error::UnexpectedPacketOrder)
349                    });
350
351                    futures::future::try_join(send_fut, recv_fut).await?;
352                }
353            }
354        }
355    }
356}
357
358// Transforms MouseData from the vnc server to an HID input report (mouse packet) by scaling coordinates and marking button flags
359async fn post_mouse_packet(
360    mouse_data: MouseData,
361    channel: &mut (impl AsyncSend + Unpin),
362) -> Result<(), Error> {
363    let mut scrolled = protocol::ScrollType::NoChange;
364    let mut mouse_packet: protocol::MousePacket = FromZeros::new_zeroed();
365    mouse_packet.x = mouse_data.x;
366    mouse_packet.y = mouse_data.y;
367
368    let button_masks = [
369        protocol::HID_MOUSE_BUTTON_LEFT,
370        protocol::HID_MOUSE_BUTTON_MIDDLE,
371        protocol::HID_MOUSE_BUTTON_RIGHT,
372    ];
373
374    #[expect(clippy::needless_range_loop)] // rare case of a clippy misfire
375    for i in 0..protocol::MOUSE_NUMBER_BUTTONS {
376        if ((1u8 << i) & mouse_data.button_mask) == (1u8 << i) {
377            if i < 3 {
378                mouse_packet.button_data |= button_masks[i];
379            }
380            if i == 3 {
381                //button 4 is a mouse wheel up click
382                scrolled = protocol::ScrollType::Up;
383            }
384            if i == 4 {
385                //button 5 is a mouse wheel down click
386                scrolled = protocol::ScrollType::Down;
387            }
388        }
389    }
390
391    //b/c we want to use the ScrollType enum to move the z in a + or - direction, we cast it into an i16
392    if scrolled as i16 != 0 {
393        mouse_packet.z = scrolled as i16;
394    }
395    send_packet(
396        channel,
397        protocol::SYNTHHID_PROTOCOL_INPUT_REPORT,
398        size_of::<protocol::MessageInputReport>() as u32,
399        &mouse_packet,
400    )
401    .await
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use input_core::mesh_input::input_pair;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use std::io::ErrorKind;
    use test_with_tracing::test;
    use vmbus_async::pipe::connected_message_pipes;

    /// A host-to-guest message decoded on the guest end of the test pipe.
    #[derive(Debug)]
    enum Packet {
        ProtocolResponse(protocol::MessageProtocolResponse),
        DeviceInfo(protocol::MessageDeviceInfo),
    }

    /// Receives and decodes one host-to-guest message.
    ///
    /// Returns `None` for a zero-length receive. Panics on any message type
    /// other than a protocol response or device info.
    async fn recv_packet(read: &mut (dyn AsyncRecv + Unpin + Send)) -> Option<Packet> {
        let mut packet = [0; 256];
        let n = read.recv(&mut packet).await.unwrap();
        if n == 0 {
            return None;
        }
        let packet = &packet[..n];
        let (header, rest) = protocol::MessageHeader::read_from_prefix(packet).unwrap(); // TODO: zerocopy: unwrap (https://github.com/microsoft/openvmm/issues/759)
        Some(match header.message_type {
            protocol::SYNTHHID_PROTOCOL_RESPONSE => {
                Packet::ProtocolResponse(FromBytes::read_from_prefix(rest).unwrap().0)
                // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
            }
            protocol::SYNTHHID_INIT_DEVICE_INFO => {
                Packet::DeviceInfo(FromBytes::read_from_prefix(rest).unwrap().0)
                // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
            }
            _ => panic!("unknown packet type {}", header.message_type),
        })
    }

    /// Spawns the device side of the channel, starting from the initial state.
    ///
    /// A `ConnectionReset` I/O error (the guest end being dropped) counts as a
    /// clean shutdown.
    fn start_worker<T: RingMem + 'static + Unpin + Send + Sync>(
        driver: &DefaultDriver,
        mut mouse: Mouse,
        channel: MessagePipe<T>,
    ) -> Task<Result<(), Error>> {
        driver.spawn("mouse worker", async move {
            MouseChannel::new(channel, Default::default())
                .process(&mut mouse)
                .await
                .or_else(|e| match e {
                    Error::Io(err) if err.kind() == ErrorKind::ConnectionReset => Ok(()),
                    _ => Err(e),
                })
        })
    }

    /// A guest that requests the supported version has it accepted and then
    /// receives the device info.
    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, mut guest) = connected_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Mouse::new(Box::new(source)), host);

        send_packet(
            &mut guest,
            protocol::SYNTHHID_PROTOCOL_REQUEST,
            size_of::<protocol::MessageProtocolRequest>() as u32,
            &protocol::MessageProtocolRequest {
                version: protocol::SYNTHHID_INPUT_VERSION,
            },
        )
        .await
        .unwrap();

        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: protocol::SYNTHHID_INPUT_VERSION,
                accepted: 1,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        match recv_packet(&mut guest).await.unwrap() {
            Packet::DeviceInfo(protocol::MessageDeviceInfo {
                device_attributes: _,
                descriptor_info: _,
                report_descriptor: _,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }

    /// A guest that requests an unsupported version gets a response that
    /// echoes the requested version and rejects it.
    #[async_test]
    async fn test_channel_negotiation_failed(driver: DefaultDriver) {
        const UNSUPPORTED_VERSION: u32 = 0xbadf00d;

        let (host, mut guest) = connected_message_pipes(16384);
        let (source, _sink) = input_pair();
        let worker = start_worker(&driver, Mouse::new(Box::new(source)), host);

        send_packet(
            &mut guest,
            protocol::SYNTHHID_PROTOCOL_REQUEST,
            size_of::<protocol::MessageProtocolRequest>() as u32,
            &protocol::MessageProtocolRequest {
                version: UNSUPPORTED_VERSION,
            },
        )
        .await
        .unwrap();

        // The response must echo the requested version and report rejection.
        match recv_packet(&mut guest).await.unwrap() {
            Packet::ProtocolResponse(protocol::MessageProtocolResponse {
                version_requested: UNSUPPORTED_VERSION,
                accepted: 0,
            }) => (),
            p => panic!("unexpected {:?}", p),
        }

        drop(guest);
        worker.await.unwrap();
    }
}