virtio_serial/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7use async_trait::async_trait;
8use guestmem::GuestMemory;
9use parking_lot::Condvar;
10use parking_lot::Mutex;
11use std::io;
12use std::io::ErrorKind;
13use std::ops::DerefMut;
14use std::sync::Arc;
15use std::thread::JoinHandle;
16use virtio::DeviceTraits;
17use virtio::LegacyVirtioDevice;
18use virtio::VirtioQueueCallbackWork;
19use virtio::VirtioQueueWorkerContext;
20use virtio::VirtioState;
21
/// Virtio device ID reported in `DeviceTraits::device_id`.
const VIRTIO_DEVICE_TYPE_CONSOLE: u16 = 3;

// Feature bits, expressed as masks over the device feature word.
// Only multiport is currently offered.
// const VIRTIO_CONSOLE_F_SIZE: u64 = 1 << 0;
const VIRTIO_CONSOLE_F_MULTIPORT: u64 = 1 << 1;
// const VIRTIO_CONSOLE_F_EMERG_WRITE: u64 = 1 << 2;

// Control message event codes carried in `VirtioSerialControl::event`.
// The commented-out codes are not used yet.
const VIRTIO_CONSOLE_DEVICE_READY: u16 = 0;
const VIRTIO_CONSOLE_DEVICE_ADD: u16 = 1;
// const VIRTIO_CONSOLE_DEVICE_REMOVE: u16 = 2;
const VIRTIO_CONSOLE_PORT_READY: u16 = 3;
// const VIRTIO_CONSOLE_CONSOLE_PORT: u16 = 4;
// const VIRTIO_CONSOLE_RESIZE: u16 = 5;
const VIRTIO_CONSOLE_PORT_OPEN: u16 = 6;
const VIRTIO_CONSOLE_PORT_NAME: u16 = 7;
36
/// State of one direction (read or write) of a `VirtioSerialPort`.
///
/// Each direction holds at most one guest buffer at a time.
enum VirtioSerialPortIoState {
    /// The port is open, but no guest buffer is currently posted for this direction.
    Unavailable,
    /// A guest buffer is posted and waiting to be claimed; holds `Some` until claimed.
    Ready(Option<VirtioQueueCallbackWork>),
    /// A claimed buffer is being filled or drained by a host-side user.
    Processing,
    /// The port is closed. This is the initial state.
    Disconnected,
    /// The port is shutting down; waiters give up instead of blocking.
    Exiting,
}
44
/// One virtio-serial port: two single-buffer handoff slots between the virtio
/// queue workers and host-side readers/writers.
pub struct VirtioSerialPort {
    // Guest memory used to copy payload bytes into guest buffers.
    mem: GuestMemory,
    // Guest buffers carrying guest data, consumed by `read_from_port`
    // (filled by `process_virtio_write`).
    read_state: (Mutex<VirtioSerialPortIoState>, Condvar),
    // Guest buffers awaiting host data, filled by `write_to_port`
    // (offered by `process_virtio_read`).
    write_state: (Mutex<VirtioSerialPortIoState>, Condvar),
}
50
51impl VirtioSerialPort {
52    pub fn new(mem: &GuestMemory) -> Self {
53        Self {
54            mem: mem.clone(),
55            read_state: (
56                Mutex::new(VirtioSerialPortIoState::Disconnected),
57                Condvar::new(),
58            ),
59            write_state: (
60                Mutex::new(VirtioSerialPortIoState::Disconnected),
61                Condvar::new(),
62            ),
63        }
64    }
65
66    pub fn read_from_port(&self) -> Option<VirtioQueueCallbackWork> {
67        let mut read_work: Option<VirtioQueueCallbackWork> = None;
68        let (state, state_cvar) = &self.read_state;
69        let mut cur_state = state.lock();
70        while let VirtioSerialPortIoState::Unavailable = *cur_state {
71            state_cvar.wait(&mut cur_state);
72        }
73
74        if let VirtioSerialPortIoState::Ready(work) = cur_state.deref_mut() {
75            assert!(work.is_some());
76            read_work = work.take();
77            *cur_state = VirtioSerialPortIoState::Processing;
78        }
79        read_work
80    }
81
82    pub fn complete_read_from_port(&self, mut work: VirtioQueueCallbackWork) {
83        work.complete(0);
84        let (state, state_cvar) = &self.read_state;
85        let mut cur_state = state.lock();
86        if let VirtioSerialPortIoState::Processing = *cur_state {
87            *cur_state = VirtioSerialPortIoState::Unavailable;
88            state_cvar.notify_one();
89        }
90    }
91
92    pub fn write_to_port(&self, data: &[u8]) -> usize {
93        if data.is_empty() {
94            return 0;
95        }
96        let mut write_work: Option<VirtioQueueCallbackWork> = None;
97        let (state, state_cvar) = &self.write_state;
98        {
99            let mut cur_state = state.lock();
100            while let VirtioSerialPortIoState::Unavailable = *cur_state {
101                state_cvar.wait(&mut cur_state);
102            }
103
104            if let VirtioSerialPortIoState::Ready(work) = cur_state.deref_mut() {
105                assert!(work.is_some());
106                write_work = work.take();
107                *cur_state = VirtioSerialPortIoState::Processing;
108            }
109        }
110
111        let mut bytes_written = 0;
112        if let Some(mut work) = write_work {
113            let mut remaining = data;
114            if remaining.len() > u32::MAX as usize {
115                remaining = &remaining[..u32::MAX as usize];
116            }
117            for payload in work.payload.iter() {
118                if remaining.is_empty() {
119                    break;
120                }
121                if !payload.writeable {
122                    break;
123                }
124                if payload.length == 0 {
125                    continue;
126                }
127                let bytes_to_write = std::cmp::min(payload.length as usize, remaining.len());
128                if let Err(error) = self.mem.write_at(payload.address, remaining) {
129                    tracing::error!(
130                        error = &error as &dyn std::error::Error,
131                        "[virtio_serial] Failed to write to guest memory",
132                    );
133                    break;
134                }
135                remaining = &remaining[bytes_to_write..];
136                bytes_written += bytes_to_write;
137            }
138            work.complete(u32::try_from(bytes_written).unwrap());
139            let mut cur_state = state.lock();
140            if let VirtioSerialPortIoState::Processing = *cur_state {
141                *cur_state = VirtioSerialPortIoState::Unavailable;
142                state_cvar.notify_one();
143            }
144        }
145        bytes_written
146    }
147
148    pub fn write_all_to_port(&self, data: &[u8]) -> bool {
149        let mut remaining = data;
150        while !remaining.is_empty() {
151            let bytes_written = self.write_to_port(remaining);
152            if bytes_written == 0 {
153                break;
154            }
155            remaining = &remaining[bytes_written..];
156        }
157        remaining.is_empty()
158    }
159
160    // transfer data into the virtio port
161    pub async fn process_virtio_read(&self, work: VirtioQueueCallbackWork) -> bool {
162        let (state, state_cvar) = &self.write_state;
163        let mut cur_state = state.lock();
164        match *cur_state {
165            VirtioSerialPortIoState::Unavailable | VirtioSerialPortIoState::Disconnected => {
166                *cur_state = VirtioSerialPortIoState::Ready(Some(work));
167                state_cvar.notify_one();
168            }
169            VirtioSerialPortIoState::Exiting => {
170                return false;
171            }
172            _ => panic!("Unexpected serial port IO state"),
173        };
174        loop {
175            match *cur_state {
176                VirtioSerialPortIoState::Unavailable
177                | VirtioSerialPortIoState::Disconnected
178                | VirtioSerialPortIoState::Exiting => {
179                    break false;
180                }
181                _ => {
182                    state_cvar.wait(&mut cur_state);
183                }
184            };
185        }
186    }
187
188    // transfer data from the virtio port
189    pub async fn process_virtio_write(&self, mut work: VirtioQueueCallbackWork) -> bool {
190        let (state, state_cvar) = &self.read_state;
191        let mut cur_state = state.lock();
192        match *cur_state {
193            VirtioSerialPortIoState::Unavailable => {
194                *cur_state = VirtioSerialPortIoState::Ready(Some(work));
195                state_cvar.notify_one();
196            }
197            VirtioSerialPortIoState::Disconnected => {
198                // if the port is disconnected, drop any writes
199                work.complete(0);
200                return true;
201            }
202            VirtioSerialPortIoState::Exiting => {
203                return false;
204            }
205            _ => panic!("Unexpected serial port IO state"),
206        };
207        loop {
208            match *cur_state {
209                VirtioSerialPortIoState::Unavailable
210                | VirtioSerialPortIoState::Disconnected
211                | VirtioSerialPortIoState::Exiting => {
212                    break false;
213                }
214                _ => {
215                    state_cvar.wait(&mut cur_state);
216                }
217            };
218        }
219    }
220
221    pub fn open(&self) {
222        let (state, _) = &self.read_state;
223        let mut cur_state = state.lock();
224        if let VirtioSerialPortIoState::Disconnected = *cur_state {
225            *cur_state = VirtioSerialPortIoState::Unavailable;
226        } else {
227            panic!("Opening a port that has already been opened");
228        }
229
230        let (state, _) = &self.write_state;
231        let mut cur_state = state.lock();
232        if let VirtioSerialPortIoState::Disconnected = *cur_state {
233            *cur_state = VirtioSerialPortIoState::Unavailable;
234        }
235    }
236
237    pub fn close(&self) {
238        let (state, state_cvar) = &self.read_state;
239        let mut cur_state = state.lock();
240        if let VirtioSerialPortIoState::Ready(work) = cur_state.deref_mut() {
241            assert!(work.is_some());
242            let mut work = work.take().expect("[VIRTO SERIAL] empty work");
243            work.complete(0);
244        } else if let VirtioSerialPortIoState::Disconnected = *cur_state {
245            panic!("Closing a port that was not open");
246        }
247        *cur_state = VirtioSerialPortIoState::Disconnected;
248        state_cvar.notify_one();
249
250        let (state, state_cvar) = &self.write_state;
251        let mut cur_state = state.lock();
252        *cur_state = VirtioSerialPortIoState::Disconnected;
253        state_cvar.notify_one();
254    }
255
256    pub fn stop(&self) {
257        let (state, state_cvar) = &self.read_state;
258        *state.lock() = VirtioSerialPortIoState::Exiting;
259        state_cvar.notify_one();
260        let (state, state_cvar) = &self.write_state;
261        *state.lock() = VirtioSerialPortIoState::Exiting;
262        state_cvar.notify_one();
263    }
264}
265
266impl Drop for VirtioSerialPort {
267    fn drop(&mut self) {
268        self.stop();
269    }
270}
271
/// Values backing the device configuration registers.
struct VirtioSerialDeviceConfig {
    // Console size; currently always 0 (set in `VirtioSerialDevice::new`).
    columns: u16,
    rows: u16,
    // Number of ports advertised to the guest (config register offset 4).
    max_ports: u32,
    // Emergency-write register; not implemented.
    _emergency_write: u32,
}
278
/// Control message header exchanged over the control queues.
/// On the wire it is 8 little-endian bytes (see `to_bytes` / `from_bytes`).
struct VirtioSerialControl {
    port_number: u32,
    // One of the `VIRTIO_CONSOLE_*` event codes.
    event: u16,
    value: u16,
}
284
285impl VirtioSerialControl {
286    pub fn to_bytes(&self) -> Vec<u8> {
287        let mut data = Vec::new();
288        data.extend_from_slice(&self.port_number.to_le_bytes());
289        data.extend_from_slice(&self.event.to_le_bytes());
290        data.extend_from_slice(&self.value.to_le_bytes());
291        data
292    }
293
294    pub fn from_bytes(data: &[u8]) -> Result<Self, io::Error> {
295        if data.len() < 8 {
296            return Err(io::Error::new(
297                ErrorKind::InvalidInput,
298                format!("Data is too small {} bytes", data.len()),
299            ));
300        }
301        let port_number = u32::from_le_bytes(data[0..4].try_into().unwrap());
302        let event = u16::from_le_bytes(data[4..6].try_into().unwrap());
303        let value = u16::from_le_bytes(data[6..8].try_into().unwrap());
304        Ok(Self {
305            port_number,
306            event,
307            value,
308        })
309    }
310}
311
312// struct VirtioSerialControlResize {
313//     columns: u16,
314//     rows: u16,
315// }
316
/// Host side of the virtio-serial control channel (queues 2 and 3).
struct VirtioSerialControlPort {
    // Guest memory used to read control messages sent by the guest.
    mem: GuestMemory,
    // Control thread handle; `None` until `start` is called.
    thread: Option<JoinHandle<()>>,
    // Port carrying control messages in both directions.
    port: Arc<VirtioSerialPort>,
}

/// Callback run on the control thread when the guest reports DEVICE_READY
/// with a non-zero value.
type VirtioSerialControlReadyFn = Box<dyn Fn() + Send>;
324
325impl VirtioSerialControlPort {
326    pub fn new(mem: &GuestMemory) -> Self {
327        Self {
328            mem: mem.clone(),
329            thread: None,
330            port: Arc::new(VirtioSerialPort::new(mem)),
331        }
332    }
333
334    pub fn start(&mut self, ready_callback: VirtioSerialControlReadyFn) {
335        self.thread = Some(Self::start_control_thread(
336            &self.port,
337            &self.mem,
338            ready_callback,
339        ));
340    }
341
342    fn start_control_thread(
343        port: &Arc<VirtioSerialPort>,
344        mem: &GuestMemory,
345        ready_callback: VirtioSerialControlReadyFn,
346    ) -> JoinHandle<()> {
347        let read_fn = port_read_fn(port, mem);
348        std::thread::Builder::new()
349            .name("virtio control".into())
350            .spawn(move || {
351                loop {
352                    let data = (read_fn)();
353                    if data.is_empty() {
354                        break;
355                    }
356                    let message = VirtioSerialControl::from_bytes(&data);
357                    if message.is_err() {
358                        continue;
359                    }
360
361                    let message = message.unwrap();
362                    match message.event {
363                        VIRTIO_CONSOLE_DEVICE_READY => {
364                            if message.value != 0 {
365                                (ready_callback)()
366                            }
367                        }
368                        VIRTIO_CONSOLE_PORT_READY => (),
369                        _ => tracing::warn!(
370                            event = message.event,
371                            port_number = message.port_number,
372                            value = message.value,
373                            "[SERIAL] Unhandled control event",
374                        ),
375                    }
376                }
377            })
378            .unwrap()
379    }
380
381    pub fn register_port(&self, port_number: u16) {
382        let mut control_message = VirtioSerialControl {
383            port_number: port_number as u32,
384            event: VIRTIO_CONSOLE_DEVICE_ADD,
385            value: 0,
386        };
387        self.port
388            .write_all_to_port(control_message.to_bytes().as_slice());
389
390        control_message.event = VIRTIO_CONSOLE_PORT_NAME;
391        let mut name_message = control_message.to_bytes();
392        name_message.extend_from_slice(format!("port{}\0", port_number).as_bytes());
393        self.port.write_all_to_port(name_message.as_slice());
394
395        control_message.event = VIRTIO_CONSOLE_PORT_OPEN;
396        control_message.value = 1;
397        self.port
398            .write_all_to_port(control_message.to_bytes().as_slice());
399    }
400}
401
/// Legacy virtio console device exposing `max_ports` data ports.
pub struct VirtioSerialDevice {
    mem: GuestMemory,
    config: VirtioSerialDeviceConfig,
    // Data ports, indexed by port number.
    ports: Vec<Arc<VirtioSerialPort>>,
    // Control channel; its thread is started only once multiport is negotiated.
    control_port: Arc<Mutex<VirtioSerialControlPort>>,
}

/// Blocking read callback. It returns an empty `Vec` when no guest data could
/// be obtained (e.g. the port is disconnected or exiting).
type VirtioSerialPortRead = Box<dyn Fn() -> Vec<u8> + Send>;
/// Best-effort write callback that delivers the given bytes to the guest.
type VirtioSerialPortWrite = Box<dyn Fn(&[u8]) + Send>;
411
412impl VirtioSerialDevice {
413    pub fn new(max_ports: u16, gm: &GuestMemory) -> Self {
414        let config = VirtioSerialDeviceConfig {
415            columns: 0,
416            rows: 0,
417            max_ports: max_ports as u32,
418            _emergency_write: 0,
419        };
420
421        let control_port = Arc::new(Mutex::new(VirtioSerialControlPort::new(gm)));
422        let mut ports = Vec::new();
423        ports.resize_with(config.max_ports as usize, || {
424            Arc::new(VirtioSerialPort::new(gm))
425        });
426        VirtioSerialDevice {
427            mem: gm.clone(),
428            config,
429            ports,
430            control_port,
431        }
432    }
433
434    pub fn io(&self) -> SerialIo {
435        SerialIo {
436            ports: self.ports.clone(),
437            mem: self.mem.clone(),
438        }
439    }
440}
441
442fn port_read_fn(port: &Arc<VirtioSerialPort>, mem: &GuestMemory) -> VirtioSerialPortRead {
443    let port = port.clone();
444    let mem = mem.clone();
445    Box::new(move || {
446        let work = port.read_from_port();
447        let mut data = Vec::new();
448        if let Some(work) = work {
449            for payload in work.payload.iter() {
450                if payload.writeable {
451                    break;
452                }
453                data.resize(data.len() + payload.length as usize, 0);
454                let dest_index = data.len() - payload.length as usize;
455                let next_chunk = data.as_mut_slice().split_at_mut(dest_index).1;
456                mem.read_at(payload.address, next_chunk).unwrap();
457            }
458            port.complete_read_from_port(work);
459        }
460        data
461    })
462}
463
/// Cloneable handle that gives host code access to the device's data ports.
#[derive(Clone)]
pub struct SerialIo {
    // Shared with the owning `VirtioSerialDevice`; indexed by port number.
    ports: Vec<Arc<VirtioSerialPort>>,
    mem: GuestMemory,
}
469
470impl SerialIo {
471    pub fn get_port_read_fn(&self, port: u16) -> VirtioSerialPortRead {
472        assert!((port as usize) < self.ports.len());
473        port_read_fn(&self.ports[port as usize], &self.mem)
474    }
475
476    pub fn port_write_fn(port: &Arc<VirtioSerialPort>) -> VirtioSerialPortWrite {
477        let port = port.clone();
478        Box::new(move |data: &[u8]| {
479            port.write_all_to_port(data);
480        })
481    }
482
483    pub fn get_port_write_fn(&self, port: u16) -> VirtioSerialPortWrite {
484        assert!((port as usize) < self.ports.len());
485        Self::port_write_fn(&self.ports[port as usize])
486    }
487
488    pub fn open_port(&self, port: u16) {
489        assert!((port as usize) < self.ports.len());
490        self.ports[port as usize].open();
491    }
492
493    pub fn close_port(&self, port: u16) {
494        assert!((port as usize) < self.ports.len());
495        self.ports[port as usize].close();
496    }
497
498    pub fn queue_input_bytes(&mut self, c: &[u8]) -> io::Result<()> {
499        self.write_port(0, &c);
500        Ok(())
501    }
502
503    pub fn write_port<T: AsRef<[u8]>>(&self, port: u16, data: &T) {
504        self.write(port, data);
505    }
506
507    pub fn write<T>(&self, port: u16, data: &T)
508    where
509        T: AsRef<[u8]>,
510    {
511        assert!((port as usize) < self.ports.len());
512        let port = &self.ports[port as usize];
513        port.write_all_to_port(data.as_ref());
514    }
515}
516
517impl LegacyVirtioDevice for VirtioSerialDevice {
518    fn traits(&self) -> DeviceTraits {
519        let queue_size = 2 + 2 * self.config.max_ports;
520        let features = if self.config.max_ports > 1 {
521            VIRTIO_CONSOLE_F_MULTIPORT
522        } else {
523            0
524        };
525        DeviceTraits {
526            device_id: VIRTIO_DEVICE_TYPE_CONSOLE,
527            device_features: features,
528            max_queues: queue_size as u16,
529            device_register_length: 12,
530            ..Default::default()
531        }
532    }
533
534    fn read_registers_u32(&self, offset: u16) -> u32 {
535        match offset {
536            0 => {
537                (self.config.rows as u32 & 0xff) << 24
538                    | (self.config.rows as u32 >> 16) << 16
539                    | (self.config.columns as u32 & 0xff) << 8
540                    | (self.config.columns as u32 >> 16)
541            }
542            4 => self.config.max_ports,
543            _ => 0,
544        }
545    }
546
547    fn write_registers_u32(&mut self, offset: u16, val: u32) {
548        // TODO: implement emergency_write (offset 8)
549        tracing::warn!(offset, val, "[VIRTIO SERIAL] Unknown write",);
550    }
551
552    fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send> {
553        let port = match index {
554            0 | 1 => self.ports[0].clone(),
555            2 | 3 => self.control_port.lock().port.clone(),
556            _ => self.ports[index as usize / 2 - 1].clone(),
557        };
558        Box::new(VirtioSerialWorker {
559            index,
560            port,
561            reader: (index & 1 == 0),
562        })
563    }
564
565    fn state_change(&mut self, state: &VirtioState) {
566        match state {
567            // if multi-port is set, start the control port thread
568            VirtioState::Running(run_state) => {
569                if run_state.features & VIRTIO_CONSOLE_F_MULTIPORT != 0 {
570                    let enabled_queues = run_state.enabled_queues.clone();
571                    if run_state.enabled_queues[2] && run_state.enabled_queues[3] {
572                        // on ready callback, asynchronously register the available ports with the guest.
573                        let max_ports = self.config.max_ports as u16;
574                        let register_control_port = self.control_port.clone();
575                        self.control_port.lock().start(Box::new(move || {
576                            let enabled_queues = enabled_queues.clone();
577                            let register_control_port = register_control_port.clone();
578                            std::thread::Builder::new()
579                                .name("virtio serial register".into())
580                                .spawn(move || {
581                                    for port_index in 0..max_ports {
582                                        let read_index = if port_index == 0 {
583                                            0
584                                        } else {
585                                            4 + port_index as usize - 1
586                                        };
587                                        let write_index = if port_index == 0 {
588                                            1
589                                        } else {
590                                            4 + port_index as usize
591                                        };
592                                        if enabled_queues[read_index] && enabled_queues[write_index]
593                                        {
594                                            register_control_port.lock().register_port(port_index);
595                                        }
596                                    }
597                                })
598                                .unwrap();
599                        }));
600                    }
601                }
602            }
603            _ => {
604                for port in self.ports.iter() {
605                    port.stop();
606                }
607            }
608        }
609    }
610}
611
/// Queue worker that routes the descriptor chains of one queue to its port.
struct VirtioSerialWorker {
    // Queue index; used in log messages.
    index: u16,
    port: Arc<VirtioSerialPort>,
    // True for even-numbered queues, whose guest buffers the host fills
    // (dispatched to `process_virtio_read`).
    reader: bool,
}
617
618#[async_trait]
619impl VirtioQueueWorkerContext for VirtioSerialWorker {
620    async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
621        if let Err(err) = work {
622            tracing::error!(
623                index = self.index,
624                err = err.as_ref() as &dyn std::error::Error,
625                "queue error"
626            );
627            return false;
628        }
629        let work = work.unwrap();
630        if self.reader {
631            self.port.process_virtio_read(work).await
632        } else {
633            self.port.process_virtio_write(work).await
634        }
635    }
636}