1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7use async_trait::async_trait;
8use guestmem::GuestMemory;
9use parking_lot::Condvar;
10use parking_lot::Mutex;
11use std::io;
12use std::io::ErrorKind;
13use std::ops::DerefMut;
14use std::sync::Arc;
15use std::thread::JoinHandle;
16use virtio::DeviceTraits;
17use virtio::LegacyVirtioDevice;
18use virtio::VirtioQueueCallbackWork;
19use virtio::VirtioQueueWorkerContext;
20use virtio::VirtioState;
21
/// Virtio device ID reported for this console device.
const VIRTIO_DEVICE_TYPE_CONSOLE: u16 = 3;

/// Feature mask for `VIRTIO_CONSOLE_F_MULTIPORT` (feature bit 1).
const VIRTIO_CONSOLE_F_MULTIPORT: u64 = 1 << 1;

// Control-queue event codes carried in `VirtioSerialControl::event`.

/// Guest -> device: driver is ready (nonzero `value` triggers port registration).
const VIRTIO_CONSOLE_DEVICE_READY: u16 = 0;
/// Device -> guest: announces a new port.
const VIRTIO_CONSOLE_DEVICE_ADD: u16 = 1;
/// Guest -> device: a port is ready.
const VIRTIO_CONSOLE_PORT_READY: u16 = 3;
/// Device -> guest: port open state (`value` = 1 when sent by `register_port`).
const VIRTIO_CONSOLE_PORT_OPEN: u16 = 6;
/// Device -> guest: port name follows the 8-byte header.
const VIRTIO_CONSOLE_PORT_NAME: u16 = 7;
36
37enum VirtioSerialPortIoState {
38 Unavailable,
39 Ready(Option<VirtioQueueCallbackWork>),
40 Processing,
41 Disconnected,
42 Exiting,
43}
44
45pub struct VirtioSerialPort {
46 mem: GuestMemory,
47 read_state: (Mutex<VirtioSerialPortIoState>, Condvar),
48 write_state: (Mutex<VirtioSerialPortIoState>, Condvar),
49}
50
51impl VirtioSerialPort {
52 pub fn new(mem: &GuestMemory) -> Self {
53 Self {
54 mem: mem.clone(),
55 read_state: (
56 Mutex::new(VirtioSerialPortIoState::Disconnected),
57 Condvar::new(),
58 ),
59 write_state: (
60 Mutex::new(VirtioSerialPortIoState::Disconnected),
61 Condvar::new(),
62 ),
63 }
64 }
65
66 pub fn read_from_port(&self) -> Option<VirtioQueueCallbackWork> {
67 let mut read_work: Option<VirtioQueueCallbackWork> = None;
68 let (state, state_cvar) = &self.read_state;
69 let mut cur_state = state.lock();
70 while let VirtioSerialPortIoState::Unavailable = *cur_state {
71 state_cvar.wait(&mut cur_state);
72 }
73
74 if let VirtioSerialPortIoState::Ready(work) = cur_state.deref_mut() {
75 assert!(work.is_some());
76 read_work = work.take();
77 *cur_state = VirtioSerialPortIoState::Processing;
78 }
79 read_work
80 }
81
82 pub fn complete_read_from_port(&self, mut work: VirtioQueueCallbackWork) {
83 work.complete(0);
84 let (state, state_cvar) = &self.read_state;
85 let mut cur_state = state.lock();
86 if let VirtioSerialPortIoState::Processing = *cur_state {
87 *cur_state = VirtioSerialPortIoState::Unavailable;
88 state_cvar.notify_one();
89 }
90 }
91
92 pub fn write_to_port(&self, data: &[u8]) -> usize {
93 if data.is_empty() {
94 return 0;
95 }
96 let mut write_work: Option<VirtioQueueCallbackWork> = None;
97 let (state, state_cvar) = &self.write_state;
98 {
99 let mut cur_state = state.lock();
100 while let VirtioSerialPortIoState::Unavailable = *cur_state {
101 state_cvar.wait(&mut cur_state);
102 }
103
104 if let VirtioSerialPortIoState::Ready(work) = cur_state.deref_mut() {
105 assert!(work.is_some());
106 write_work = work.take();
107 *cur_state = VirtioSerialPortIoState::Processing;
108 }
109 }
110
111 let mut bytes_written = 0;
112 if let Some(mut work) = write_work {
113 let mut remaining = data;
114 if remaining.len() > u32::MAX as usize {
115 remaining = &remaining[..u32::MAX as usize];
116 }
117 for payload in work.payload.iter() {
118 if remaining.is_empty() {
119 break;
120 }
121 if !payload.writeable {
122 break;
123 }
124 if payload.length == 0 {
125 continue;
126 }
127 let bytes_to_write = std::cmp::min(payload.length as usize, remaining.len());
128 if let Err(error) = self.mem.write_at(payload.address, remaining) {
129 tracing::error!(
130 error = &error as &dyn std::error::Error,
131 "[virtio_serial] Failed to write to guest memory",
132 );
133 break;
134 }
135 remaining = &remaining[bytes_to_write..];
136 bytes_written += bytes_to_write;
137 }
138 work.complete(u32::try_from(bytes_written).unwrap());
139 let mut cur_state = state.lock();
140 if let VirtioSerialPortIoState::Processing = *cur_state {
141 *cur_state = VirtioSerialPortIoState::Unavailable;
142 state_cvar.notify_one();
143 }
144 }
145 bytes_written
146 }
147
148 pub fn write_all_to_port(&self, data: &[u8]) -> bool {
149 let mut remaining = data;
150 while !remaining.is_empty() {
151 let bytes_written = self.write_to_port(remaining);
152 if bytes_written == 0 {
153 break;
154 }
155 remaining = &remaining[bytes_written..];
156 }
157 remaining.is_empty()
158 }
159
160 pub async fn process_virtio_read(&self, work: VirtioQueueCallbackWork) -> bool {
162 let (state, state_cvar) = &self.write_state;
163 let mut cur_state = state.lock();
164 match *cur_state {
165 VirtioSerialPortIoState::Unavailable | VirtioSerialPortIoState::Disconnected => {
166 *cur_state = VirtioSerialPortIoState::Ready(Some(work));
167 state_cvar.notify_one();
168 }
169 VirtioSerialPortIoState::Exiting => {
170 return false;
171 }
172 _ => panic!("Unexpected serial port IO state"),
173 };
174 loop {
175 match *cur_state {
176 VirtioSerialPortIoState::Unavailable
177 | VirtioSerialPortIoState::Disconnected
178 | VirtioSerialPortIoState::Exiting => {
179 break false;
180 }
181 _ => {
182 state_cvar.wait(&mut cur_state);
183 }
184 };
185 }
186 }
187
188 pub async fn process_virtio_write(&self, mut work: VirtioQueueCallbackWork) -> bool {
190 let (state, state_cvar) = &self.read_state;
191 let mut cur_state = state.lock();
192 match *cur_state {
193 VirtioSerialPortIoState::Unavailable => {
194 *cur_state = VirtioSerialPortIoState::Ready(Some(work));
195 state_cvar.notify_one();
196 }
197 VirtioSerialPortIoState::Disconnected => {
198 work.complete(0);
200 return true;
201 }
202 VirtioSerialPortIoState::Exiting => {
203 return false;
204 }
205 _ => panic!("Unexpected serial port IO state"),
206 };
207 loop {
208 match *cur_state {
209 VirtioSerialPortIoState::Unavailable
210 | VirtioSerialPortIoState::Disconnected
211 | VirtioSerialPortIoState::Exiting => {
212 break false;
213 }
214 _ => {
215 state_cvar.wait(&mut cur_state);
216 }
217 };
218 }
219 }
220
221 pub fn open(&self) {
222 let (state, _) = &self.read_state;
223 let mut cur_state = state.lock();
224 if let VirtioSerialPortIoState::Disconnected = *cur_state {
225 *cur_state = VirtioSerialPortIoState::Unavailable;
226 } else {
227 panic!("Opening a port that has already been opened");
228 }
229
230 let (state, _) = &self.write_state;
231 let mut cur_state = state.lock();
232 if let VirtioSerialPortIoState::Disconnected = *cur_state {
233 *cur_state = VirtioSerialPortIoState::Unavailable;
234 }
235 }
236
237 pub fn close(&self) {
238 let (state, state_cvar) = &self.read_state;
239 let mut cur_state = state.lock();
240 if let VirtioSerialPortIoState::Ready(work) = cur_state.deref_mut() {
241 assert!(work.is_some());
242 let mut work = work.take().expect("[VIRTO SERIAL] empty work");
243 work.complete(0);
244 } else if let VirtioSerialPortIoState::Disconnected = *cur_state {
245 panic!("Closing a port that was not open");
246 }
247 *cur_state = VirtioSerialPortIoState::Disconnected;
248 state_cvar.notify_one();
249
250 let (state, state_cvar) = &self.write_state;
251 let mut cur_state = state.lock();
252 *cur_state = VirtioSerialPortIoState::Disconnected;
253 state_cvar.notify_one();
254 }
255
256 pub fn stop(&self) {
257 let (state, state_cvar) = &self.read_state;
258 *state.lock() = VirtioSerialPortIoState::Exiting;
259 state_cvar.notify_one();
260 let (state, state_cvar) = &self.write_state;
261 *state.lock() = VirtioSerialPortIoState::Exiting;
262 state_cvar.notify_one();
263 }
264}
265
266impl Drop for VirtioSerialPort {
267 fn drop(&mut self) {
268 self.stop();
269 }
270}
271
/// Values exposed through the device-specific configuration registers.
struct VirtioSerialDeviceConfig {
    /// Console width; reported in the low half of register offset 0.
    columns: u16,
    /// Console height; reported in the high half of register offset 0.
    rows: u16,
    /// Number of data ports; reported at register offset 4.
    max_ports: u32,
    /// Emergency-write value; not read anywhere in this device.
    _emergency_write: u32,
}
278
/// Header of a control-queue message: 8 bytes on the wire
/// (`port_number`, `event`, `value`, little-endian).
struct VirtioSerialControl {
    /// Port the event refers to.
    port_number: u32,
    /// One of the `VIRTIO_CONSOLE_*` event codes.
    event: u16,
    /// Event-specific argument (e.g. nonzero means ready for `DEVICE_READY`).
    value: u16,
}
284
285impl VirtioSerialControl {
286 pub fn to_bytes(&self) -> Vec<u8> {
287 let mut data = Vec::new();
288 data.extend_from_slice(&self.port_number.to_le_bytes());
289 data.extend_from_slice(&self.event.to_le_bytes());
290 data.extend_from_slice(&self.value.to_le_bytes());
291 data
292 }
293
294 pub fn from_bytes(data: &[u8]) -> Result<Self, io::Error> {
295 if data.len() < 8 {
296 return Err(io::Error::new(
297 ErrorKind::InvalidInput,
298 format!("Data is too small {} bytes", data.len()),
299 ));
300 }
301 let port_number = u32::from_le_bytes(data[0..4].try_into().unwrap());
302 let event = u16::from_le_bytes(data[4..6].try_into().unwrap());
303 let value = u16::from_le_bytes(data[6..8].try_into().unwrap());
304 Ok(Self {
305 port_number,
306 event,
307 value,
308 })
309 }
310}
311
312struct VirtioSerialControlPort {
318 mem: GuestMemory,
319 thread: Option<JoinHandle<()>>,
320 port: Arc<VirtioSerialPort>,
321}
322
/// Callback run by the control thread when the guest reports `DEVICE_READY`.
type VirtioSerialControlReadyFn = Box<dyn Fn() + Send>;
324
325impl VirtioSerialControlPort {
326 pub fn new(mem: &GuestMemory) -> Self {
327 Self {
328 mem: mem.clone(),
329 thread: None,
330 port: Arc::new(VirtioSerialPort::new(mem)),
331 }
332 }
333
334 pub fn start(&mut self, ready_callback: VirtioSerialControlReadyFn) {
335 self.thread = Some(Self::start_control_thread(
336 &self.port,
337 &self.mem,
338 ready_callback,
339 ));
340 }
341
342 fn start_control_thread(
343 port: &Arc<VirtioSerialPort>,
344 mem: &GuestMemory,
345 ready_callback: VirtioSerialControlReadyFn,
346 ) -> JoinHandle<()> {
347 let read_fn = port_read_fn(port, mem);
348 std::thread::Builder::new()
349 .name("virtio control".into())
350 .spawn(move || {
351 loop {
352 let data = (read_fn)();
353 if data.is_empty() {
354 break;
355 }
356 let message = VirtioSerialControl::from_bytes(&data);
357 if message.is_err() {
358 continue;
359 }
360
361 let message = message.unwrap();
362 match message.event {
363 VIRTIO_CONSOLE_DEVICE_READY => {
364 if message.value != 0 {
365 (ready_callback)()
366 }
367 }
368 VIRTIO_CONSOLE_PORT_READY => (),
369 _ => tracing::warn!(
370 event = message.event,
371 port_number = message.port_number,
372 value = message.value,
373 "[SERIAL] Unhandled control event",
374 ),
375 }
376 }
377 })
378 .unwrap()
379 }
380
381 pub fn register_port(&self, port_number: u16) {
382 let mut control_message = VirtioSerialControl {
383 port_number: port_number as u32,
384 event: VIRTIO_CONSOLE_DEVICE_ADD,
385 value: 0,
386 };
387 self.port
388 .write_all_to_port(control_message.to_bytes().as_slice());
389
390 control_message.event = VIRTIO_CONSOLE_PORT_NAME;
391 let mut name_message = control_message.to_bytes();
392 name_message.extend_from_slice(format!("port{}\0", port_number).as_bytes());
393 self.port.write_all_to_port(name_message.as_slice());
394
395 control_message.event = VIRTIO_CONSOLE_PORT_OPEN;
396 control_message.value = 1;
397 self.port
398 .write_all_to_port(control_message.to_bytes().as_slice());
399 }
400}
401
402pub struct VirtioSerialDevice {
403 mem: GuestMemory,
404 config: VirtioSerialDeviceConfig,
405 ports: Vec<Arc<VirtioSerialPort>>,
406 control_port: Arc<Mutex<VirtioSerialControlPort>>,
407}
408
/// Blocking reader for one port; returns an empty `Vec` when nothing could
/// be read.
type VirtioSerialPortRead = Box<dyn Fn() -> Vec<u8> + Send>;
/// Writer for one port; blocks until the data is written or the port stops
/// accepting it.
type VirtioSerialPortWrite = Box<dyn Fn(&[u8]) + Send>;
411
412impl VirtioSerialDevice {
413 pub fn new(max_ports: u16, gm: &GuestMemory) -> Self {
414 let config = VirtioSerialDeviceConfig {
415 columns: 0,
416 rows: 0,
417 max_ports: max_ports as u32,
418 _emergency_write: 0,
419 };
420
421 let control_port = Arc::new(Mutex::new(VirtioSerialControlPort::new(gm)));
422 let mut ports = Vec::new();
423 ports.resize_with(config.max_ports as usize, || {
424 Arc::new(VirtioSerialPort::new(gm))
425 });
426 VirtioSerialDevice {
427 mem: gm.clone(),
428 config,
429 ports,
430 control_port,
431 }
432 }
433
434 pub fn io(&self) -> SerialIo {
435 SerialIo {
436 ports: self.ports.clone(),
437 mem: self.mem.clone(),
438 }
439 }
440}
441
442fn port_read_fn(port: &Arc<VirtioSerialPort>, mem: &GuestMemory) -> VirtioSerialPortRead {
443 let port = port.clone();
444 let mem = mem.clone();
445 Box::new(move || {
446 let work = port.read_from_port();
447 let mut data = Vec::new();
448 if let Some(work) = work {
449 for payload in work.payload.iter() {
450 if payload.writeable {
451 break;
452 }
453 data.resize(data.len() + payload.length as usize, 0);
454 let dest_index = data.len() - payload.length as usize;
455 let next_chunk = data.as_mut_slice().split_at_mut(dest_index).1;
456 mem.read_at(payload.address, next_chunk).unwrap();
457 }
458 port.complete_read_from_port(work);
459 }
460 data
461 })
462}
463
464#[derive(Clone)]
465pub struct SerialIo {
466 ports: Vec<Arc<VirtioSerialPort>>,
467 mem: GuestMemory,
468}
469
470impl SerialIo {
471 pub fn get_port_read_fn(&self, port: u16) -> VirtioSerialPortRead {
472 assert!((port as usize) < self.ports.len());
473 port_read_fn(&self.ports[port as usize], &self.mem)
474 }
475
476 pub fn port_write_fn(port: &Arc<VirtioSerialPort>) -> VirtioSerialPortWrite {
477 let port = port.clone();
478 Box::new(move |data: &[u8]| {
479 port.write_all_to_port(data);
480 })
481 }
482
483 pub fn get_port_write_fn(&self, port: u16) -> VirtioSerialPortWrite {
484 assert!((port as usize) < self.ports.len());
485 Self::port_write_fn(&self.ports[port as usize])
486 }
487
488 pub fn open_port(&self, port: u16) {
489 assert!((port as usize) < self.ports.len());
490 self.ports[port as usize].open();
491 }
492
493 pub fn close_port(&self, port: u16) {
494 assert!((port as usize) < self.ports.len());
495 self.ports[port as usize].close();
496 }
497
498 pub fn queue_input_bytes(&mut self, c: &[u8]) -> io::Result<()> {
499 self.write_port(0, &c);
500 Ok(())
501 }
502
503 pub fn write_port<T: AsRef<[u8]>>(&self, port: u16, data: &T) {
504 self.write(port, data);
505 }
506
507 pub fn write<T>(&self, port: u16, data: &T)
508 where
509 T: AsRef<[u8]>,
510 {
511 assert!((port as usize) < self.ports.len());
512 let port = &self.ports[port as usize];
513 port.write_all_to_port(data.as_ref());
514 }
515}
516
517impl LegacyVirtioDevice for VirtioSerialDevice {
518 fn traits(&self) -> DeviceTraits {
519 let queue_size = 2 + 2 * self.config.max_ports;
520 let features = if self.config.max_ports > 1 {
521 VIRTIO_CONSOLE_F_MULTIPORT
522 } else {
523 0
524 };
525 DeviceTraits {
526 device_id: VIRTIO_DEVICE_TYPE_CONSOLE,
527 device_features: features,
528 max_queues: queue_size as u16,
529 device_register_length: 12,
530 ..Default::default()
531 }
532 }
533
534 fn read_registers_u32(&self, offset: u16) -> u32 {
535 match offset {
536 0 => {
537 (self.config.rows as u32 & 0xff) << 24
538 | (self.config.rows as u32 >> 16) << 16
539 | (self.config.columns as u32 & 0xff) << 8
540 | (self.config.columns as u32 >> 16)
541 }
542 4 => self.config.max_ports,
543 _ => 0,
544 }
545 }
546
547 fn write_registers_u32(&mut self, offset: u16, val: u32) {
548 tracing::warn!(offset, val, "[VIRTIO SERIAL] Unknown write",);
550 }
551
552 fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send> {
553 let port = match index {
554 0 | 1 => self.ports[0].clone(),
555 2 | 3 => self.control_port.lock().port.clone(),
556 _ => self.ports[index as usize / 2 - 1].clone(),
557 };
558 Box::new(VirtioSerialWorker {
559 index,
560 port,
561 reader: (index & 1 == 0),
562 })
563 }
564
565 fn state_change(&mut self, state: &VirtioState) {
566 match state {
567 VirtioState::Running(run_state) => {
569 if run_state.features & VIRTIO_CONSOLE_F_MULTIPORT != 0 {
570 let enabled_queues = run_state.enabled_queues.clone();
571 if run_state.enabled_queues[2] && run_state.enabled_queues[3] {
572 let max_ports = self.config.max_ports as u16;
574 let register_control_port = self.control_port.clone();
575 self.control_port.lock().start(Box::new(move || {
576 let enabled_queues = enabled_queues.clone();
577 let register_control_port = register_control_port.clone();
578 std::thread::Builder::new()
579 .name("virtio serial register".into())
580 .spawn(move || {
581 for port_index in 0..max_ports {
582 let read_index = if port_index == 0 {
583 0
584 } else {
585 4 + port_index as usize - 1
586 };
587 let write_index = if port_index == 0 {
588 1
589 } else {
590 4 + port_index as usize
591 };
592 if enabled_queues[read_index] && enabled_queues[write_index]
593 {
594 register_control_port.lock().register_port(port_index);
595 }
596 }
597 })
598 .unwrap();
599 }));
600 }
601 }
602 }
603 _ => {
604 for port in self.ports.iter() {
605 port.stop();
606 }
607 }
608 }
609 }
610}
611
612struct VirtioSerialWorker {
613 index: u16,
614 port: Arc<VirtioSerialPort>,
615 reader: bool,
616}
617
618#[async_trait]
619impl VirtioQueueWorkerContext for VirtioSerialWorker {
620 async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
621 if let Err(err) = work {
622 tracing::error!(
623 index = self.index,
624 err = err.as_ref() as &dyn std::error::Error,
625 "queue error"
626 );
627 return false;
628 }
629 let work = work.unwrap();
630 if self.reader {
631 self.port.process_virtio_read(work).await
632 } else {
633 self.port.process_virtio_write(work).await
634 }
635 }
636}