1use crate::queue::QueueCore;
5use crate::queue::QueueError;
6use crate::queue::QueueParams;
7use crate::queue::VirtioQueuePayload;
8use async_trait::async_trait;
9use futures::FutureExt;
10use futures::Stream;
11use futures::StreamExt;
12use guestmem::DoorbellRegistration;
13use guestmem::GuestMemory;
14use guestmem::GuestMemoryError;
15use guestmem::MappedMemoryRegion;
16use pal_async::DefaultPool;
17use pal_async::driver::Driver;
18use pal_async::task::Spawn;
19use pal_async::wait::PolledWait;
20use pal_event::Event;
21use parking_lot::Mutex;
22use std::io::Error;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::Context;
26use std::task::Poll;
27use std::task::ready;
28use task_control::AsyncRun;
29use task_control::StopTask;
30use task_control::TaskControl;
31use thiserror::Error;
32use vmcore::interrupt::Interrupt;
33use vmcore::vm_task::VmTaskDriver;
34use vmcore::vm_task::VmTaskDriverSource;
35
/// Trait implemented by consumers of virtio queue work items.
#[async_trait]
pub trait VirtioQueueWorkerContext {
    /// Processes one unit of queue work, or the error produced while trying
    /// to retrieve it. Returns `false` to stop the worker loop.
    async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool;
}
40
/// Writes descriptor completions to the used ring and tracks how many
/// descriptors are still outstanding.
#[derive(Debug)]
pub struct VirtioQueueUsedHandler {
    core: QueueCore,
    // Next used-ring index to write; advanced on each completion.
    last_used_index: u16,
    // Count of descriptors handed out but not yet completed, paired with an
    // event that is notified whenever the count reaches zero.
    outstanding_desc_count: Arc<Mutex<(u16, event_listener::Event)>>,
    // Interrupt delivered to the guest when a completion requires notification.
    notify_guest: Interrupt,
}
48
impl VirtioQueueUsedHandler {
    /// Creates a handler that completes descriptors through `core` and
    /// delivers `notify_guest` when the guest needs an interrupt.
    fn new(core: QueueCore, notify_guest: Interrupt) -> Self {
        Self {
            core,
            last_used_index: 0,
            outstanding_desc_count: Arc::new(Mutex::new((0, event_listener::Event::new()))),
            notify_guest,
        }
    }

    /// Records that one more descriptor has been handed out and is not yet
    /// completed.
    pub fn add_outstanding_descriptor(&self) {
        let (count, _) = &mut *self.outstanding_desc_count.lock();
        *count += 1;
    }

    /// Returns a listener that resolves once all outstanding descriptors have
    /// been completed.
    ///
    /// The listener is registered while holding the lock and before checking
    /// the count, so a completion racing with this call cannot be missed. If
    /// the count is already zero, the event is notified immediately so the
    /// returned listener resolves right away.
    pub fn await_outstanding_descriptors(&self) -> event_listener::EventListener {
        let (count, event) = &*self.outstanding_desc_count.lock();
        let listener = event.listen();
        if *count == 0 {
            event.notify(usize::MAX);
        }
        listener
    }

    /// Writes a completion for `descriptor_index` (with `bytes_written`) to
    /// the used ring, delivering a guest interrupt if the core indicates
    /// notification is required.
    pub fn complete_descriptor(&mut self, descriptor_index: u16, bytes_written: u32) {
        match self.core.complete_descriptor(
            &mut self.last_used_index,
            descriptor_index,
            bytes_written,
        ) {
            Ok(true) => {
                // The core reported that the guest should be notified.
                self.notify_guest.deliver();
            }
            Ok(false) => {}
            Err(err) => {
                // Log (rate-limited) and fall through: the outstanding count
                // must still be decremented below even on failure.
                tracelimit::error_ratelimited!(
                    error = &err as &dyn std::error::Error,
                    "failed to complete descriptor"
                );
            }
        }
        {
            // Decrement the outstanding count and wake waiters (see
            // `await_outstanding_descriptors`) when it reaches zero.
            let (count, event) = &mut *self.outstanding_desc_count.lock();
            *count -= 1;
            if *count == 0 {
                event.notify(usize::MAX);
            }
        }
    }
}
99
/// One unit of queue work: a descriptor chain's buffers plus the state needed
/// to complete the descriptor back to the guest.
pub struct VirtioQueueCallbackWork {
    /// The buffers making up the descriptor chain, in order.
    pub payload: Vec<VirtioQueuePayload>,
    // Shared handler used to write the completion to the used ring.
    used_queue_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    // Index of the head descriptor for this chain.
    descriptor_index: u16,
    // Set once `complete` has been called; checked by `Drop` so the
    // descriptor is always returned to the guest.
    completed: bool,
}
106
107impl VirtioQueueCallbackWork {
108 pub fn new(
109 payload: Vec<VirtioQueuePayload>,
110 used_queue_handler: &Arc<Mutex<VirtioQueueUsedHandler>>,
111 descriptor_index: u16,
112 ) -> Self {
113 let used_queue_handler = used_queue_handler.clone();
114 used_queue_handler.lock().add_outstanding_descriptor();
115 Self {
116 payload,
117 used_queue_handler,
118 descriptor_index,
119 completed: false,
120 }
121 }
122
123 pub fn complete(&mut self, bytes_written: u32) {
124 assert!(!self.completed);
125 self.used_queue_handler
126 .lock()
127 .complete_descriptor(self.descriptor_index, bytes_written);
128 self.completed = true;
129 }
130
131 pub fn descriptor_index(&self) -> u16 {
132 self.descriptor_index
133 }
134
135 pub fn get_payload_length(&self, writeable: bool) -> u64 {
137 self.payload
138 .iter()
139 .filter(|x| x.writeable == writeable)
140 .fold(0, |acc, x| acc + x.length as u64)
141 }
142
143 pub fn read(&self, mem: &GuestMemory, target: &mut [u8]) -> Result<usize, GuestMemoryError> {
145 let mut remaining = target;
146 let mut read_bytes: usize = 0;
147 for payload in &self.payload {
148 if payload.writeable {
149 continue;
150 }
151
152 let size = std::cmp::min(payload.length as usize, remaining.len());
153 let (current, next) = remaining.split_at_mut(size);
154 mem.read_at(payload.address, current)?;
155 read_bytes += size;
156 if next.is_empty() {
157 break;
158 }
159
160 remaining = next;
161 }
162
163 Ok(read_bytes)
164 }
165
166 pub fn write_at_offset(
168 &self,
169 offset: u64,
170 mem: &GuestMemory,
171 source: &[u8],
172 ) -> Result<(), VirtioWriteError> {
173 let mut skip_bytes = offset;
174 let mut remaining = source;
175 for payload in &self.payload {
176 if !payload.writeable {
177 continue;
178 }
179
180 let payload_length = payload.length as u64;
181 if skip_bytes >= payload_length {
182 skip_bytes -= payload_length;
183 continue;
184 }
185
186 let size = std::cmp::min(
187 payload_length as usize - skip_bytes as usize,
188 remaining.len(),
189 );
190 let (current, next) = remaining.split_at(size);
191 mem.write_at(payload.address + skip_bytes, current)?;
192 remaining = next;
193 if remaining.is_empty() {
194 break;
195 }
196 skip_bytes = 0;
197 }
198
199 if !remaining.is_empty() {
200 return Err(VirtioWriteError::NotAllWritten(source.len()));
201 }
202
203 Ok(())
204 }
205
206 pub fn write(&self, mem: &GuestMemory, source: &[u8]) -> Result<(), VirtioWriteError> {
207 self.write_at_offset(0, mem, source)
208 }
209}
210
/// Errors returned when writing into a work item's writeable payload.
#[derive(Debug, Error)]
pub enum VirtioWriteError {
    /// A guest memory access failed.
    #[error(transparent)]
    Memory(#[from] GuestMemoryError),
    /// The writeable payload space was exhausted before all source bytes
    /// were written; carries a byte count.
    #[error("{0:#x} bytes not written")]
    NotAllWritten(usize),
}
218
219impl Drop for VirtioQueueCallbackWork {
220 fn drop(&mut self) {
221 if !self.completed {
222 self.complete(0);
223 }
224 }
225}
226
/// A virtio queue that yields guest-supplied descriptor chains as a stream of
/// [`VirtioQueueCallbackWork`] items.
#[derive(Debug)]
pub struct VirtioQueue {
    core: QueueCore,
    // Next available-ring index to consume; wraps on overflow.
    last_avail_index: u16,
    // Shared with each outstanding work item so completions can be written
    // back to the used ring.
    used_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    // Waited on when no descriptors are available; expected to be signaled
    // when the guest notifies the queue.
    queue_event: PolledWait<Event>,
}
234
impl VirtioQueue {
    /// Creates a queue over `params` and `mem`, delivering used-ring
    /// notifications via `notify` and waking on `queue_event`.
    ///
    /// # Errors
    /// Returns an error if the underlying queue core cannot be constructed.
    pub fn new(
        features: u64,
        params: QueueParams,
        mem: GuestMemory,
        notify: Interrupt,
        queue_event: PolledWait<Event>,
    ) -> Result<Self, QueueError> {
        let core = QueueCore::new(features, mem, params)?;
        let used_handler = Arc::new(Mutex::new(VirtioQueueUsedHandler::new(
            core.clone(),
            notify,
        )));
        Ok(Self {
            core,
            last_avail_index: 0,
            used_handler,
            queue_event,
        })
    }

    /// Waits until every descriptor handed out by this queue has been
    /// completed.
    async fn wait_for_outstanding_descriptors(&self) {
        let wait_for_descriptors = self.used_handler.lock().await_outstanding_descriptors();
        wait_for_descriptors.await;
    }

    /// Polls for the next available descriptor chain, waiting on the queue
    /// event whenever the available ring is empty.
    ///
    /// Note: this never resolves to `Ok(None)`; the stream built on top of it
    /// is unending.
    fn poll_next_buffer(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<VirtioQueueCallbackWork>, QueueError>> {
        let descriptor_index = loop {
            if let Some(descriptor_index) = self.core.descriptor_index(self.last_avail_index)? {
                break descriptor_index;
            };
            // No descriptor available: park on the queue event, then re-check
            // once it fires.
            ready!(self.queue_event.wait().poll_unpin(cx)).expect("waits on Event cannot fail");
        };
        // Gather the chain's buffers, failing on a malformed chain.
        let payload = self
            .core
            .reader(descriptor_index)
            .collect::<Result<Vec<_>, _>>()?;

        self.last_avail_index = self.last_avail_index.wrapping_add(1);
        Poll::Ready(Ok(Some(VirtioQueueCallbackWork::new(
            payload,
            &self.used_handler,
            descriptor_index,
        ))))
    }
}
284
285impl Drop for VirtioQueue {
286 fn drop(&mut self) {
287 if Arc::get_mut(&mut self.used_handler).is_none() {
288 tracing::error!("Virtio queue dropped with outstanding work pending")
289 }
290 }
291}
292
293impl Stream for VirtioQueue {
294 type Item = Result<VirtioQueueCallbackWork, Error>;
295
296 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
297 let Some(r) = ready!(self.get_mut().poll_next_buffer(cx)).transpose() else {
298 return Poll::Ready(None);
299 };
300
301 Poll::Ready(Some(r.map_err(Error::other)))
302 }
303}
304
/// Lifecycle state of a queue worker task.
enum VirtioQueueStateInner {
    /// Queue not yet constructed; holds everything needed to build it on the
    /// worker's first run.
    Initializing {
        mem: GuestMemory,
        features: u64,
        params: QueueParams,
        event: Event,
        notify: Interrupt,
        exit_event: event_listener::EventListener,
    },
    /// Transient placeholder installed while moving out of `Initializing`;
    /// never observed across worker iterations.
    InitializationInProgress,
    /// Queue constructed and processing work.
    Running {
        queue: VirtioQueue,
        exit_event: event_listener::EventListener,
    },
}
320
/// State owned by a queue worker task (see [`VirtioQueueWorker`]).
pub struct VirtioQueueState {
    // Wrapped so the state machine's variants stay private to this module.
    inner: VirtioQueueStateInner,
}
324
/// Worker that pumps a single virtio queue, forwarding each work item to a
/// [`VirtioQueueWorkerContext`].
pub struct VirtioQueueWorker {
    // Driver used to register the queue event for async waiting.
    driver: Box<dyn Driver>,
    // Device-supplied callback that processes the work items.
    context: Box<dyn VirtioQueueWorkerContext + Send>,
}
329
330impl VirtioQueueWorker {
331 pub fn new(driver: impl Driver, context: Box<dyn VirtioQueueWorkerContext + Send>) -> Self {
332 Self {
333 driver: Box::new(driver),
334 context,
335 }
336 }
337
338 pub fn into_running_task(
339 self,
340 name: impl Into<String>,
341 mem: GuestMemory,
342 features: u64,
343 queue_resources: QueueResources,
344 exit_event: event_listener::EventListener,
345 ) -> TaskControl<VirtioQueueWorker, VirtioQueueState> {
346 let name = name.into();
347 let (_, driver) = DefaultPool::spawn_on_thread(&name);
348
349 let mut task = TaskControl::new(self);
350 task.insert(
351 driver,
352 name,
353 VirtioQueueState {
354 inner: VirtioQueueStateInner::Initializing {
355 mem,
356 features,
357 params: queue_resources.params,
358 event: queue_resources.event,
359 notify: queue_resources.notify,
360 exit_event,
361 },
362 },
363 );
364 task.start();
365 task
366 }
367
368 async fn run_queue(&mut self, state: &mut VirtioQueueState) -> bool {
369 match &mut state.inner {
370 VirtioQueueStateInner::InitializationInProgress => unreachable!(),
371 VirtioQueueStateInner::Initializing { .. } => {
372 let VirtioQueueStateInner::Initializing {
373 mem,
374 features,
375 params,
376 event,
377 notify,
378 exit_event,
379 } = std::mem::replace(
380 &mut state.inner,
381 VirtioQueueStateInner::InitializationInProgress,
382 )
383 else {
384 unreachable!()
385 };
386 let queue_event = PolledWait::new(&self.driver, event).unwrap();
387 let queue = VirtioQueue::new(features, params, mem, notify, queue_event);
388 if let Err(err) = queue {
389 tracing::error!(
390 err = &err as &dyn std::error::Error,
391 "Failed to start queue"
392 );
393 false
394 } else {
395 state.inner = VirtioQueueStateInner::Running {
396 queue: queue.unwrap(),
397 exit_event,
398 };
399 true
400 }
401 }
402 VirtioQueueStateInner::Running { queue, exit_event } => {
403 let mut exit = exit_event.fuse();
404 let mut queue_ready = queue.next().fuse();
405 let work = futures::select_biased! {
406 _ = exit => return false,
407 work = queue_ready => work.expect("queue will never complete").map_err(anyhow::Error::from),
408 };
409 self.context.process_work(work).await
410 }
411 }
412 }
413}
414
415impl AsyncRun<VirtioQueueState> for VirtioQueueWorker {
416 async fn run(
417 &mut self,
418 stop: &mut StopTask<'_>,
419 state: &mut VirtioQueueState,
420 ) -> Result<(), task_control::Cancelled> {
421 while stop.until_stopped(self.run_queue(state)).await? {}
422 Ok(())
423 }
424}
425
/// Parameters describing a device that has been enabled.
pub struct VirtioRunningState {
    /// Feature bits in effect for the device.
    pub features: u64,
    /// Per-queue enable flags, indexed by queue number.
    pub enabled_queues: Vec<bool>,
}
430
/// Device lifecycle state, reported to devices via
/// [`LegacyVirtioDevice::state_change`].
pub enum VirtioState {
    /// State has not been determined.
    Unknown,
    /// Device is enabled and its queues may be processing work.
    Running(VirtioRunningState),
    /// Device has been disabled.
    Stopped,
}
436
/// Tracks doorbell registrations for a virtio device.
pub(crate) struct VirtioDoorbells {
    // Optional backend for registering doorbells; when `None`, `add` is a
    // no-op.
    registration: Option<Arc<dyn DoorbellRegistration>>,
    // Live registration handles, released on `clear`. NOTE(review): dropping
    // a handle presumably unregisters the doorbell — confirm with the
    // `DoorbellRegistration` implementation.
    doorbells: Vec<Box<dyn Send + Sync>>,
}
441
442impl VirtioDoorbells {
443 pub fn new(registration: Option<Arc<dyn DoorbellRegistration>>) -> Self {
444 Self {
445 registration,
446 doorbells: Vec::new(),
447 }
448 }
449
450 pub fn add(&mut self, address: u64, value: Option<u64>, length: Option<u32>, event: &Event) {
451 if let Some(registration) = &mut self.registration {
452 let doorbell = registration.register_doorbell(address, value, length, event);
453 if let Ok(doorbell) = doorbell {
454 self.doorbells.push(doorbell);
455 }
456 }
457 }
458
459 pub fn clear(&mut self) {
460 self.doorbells.clear();
461 }
462}
463
/// Shared-memory region description advertised in [`DeviceTraits`].
#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraitsSharedMemory {
    /// Shared-memory region id.
    pub id: u8,
    /// Region size in bytes.
    pub size: u64,
}
469
/// Static traits describing a virtio device.
#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraits {
    /// Virtio device id.
    pub device_id: u16,
    /// Feature bits offered by the device.
    pub device_features: u64,
    /// Maximum number of queues the device supports.
    pub max_queues: u16,
    /// Length in bytes of the device-specific register region.
    pub device_register_length: u32,
    /// Shared-memory region description, if any.
    pub shared_memory: DeviceTraitsSharedMemory,
}
478
/// Trait implemented by devices that use per-queue worker callbacks; adapted
/// to [`VirtioDevice`] by [`LegacyWrapper`].
pub trait LegacyVirtioDevice: Send {
    /// Returns the device's static traits.
    fn traits(&self) -> DeviceTraits;
    /// Reads a 32-bit device register at `offset`.
    fn read_registers_u32(&self, offset: u16) -> u32;
    /// Writes a 32-bit device register at `offset`.
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    /// Returns the work-processing callback for queue `index`.
    fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send>;
    /// Notifies the device of a lifecycle state change.
    fn state_change(&mut self, state: &VirtioState);
}
486
/// Trait implemented by virtio devices.
pub trait VirtioDevice: Send {
    /// Returns the device's static traits.
    fn traits(&self) -> DeviceTraits;
    /// Reads a 32-bit device register at `offset`.
    fn read_registers_u32(&self, offset: u16) -> u32;
    /// Writes a 32-bit device register at `offset`.
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    /// Enables the device with the given queue and memory resources.
    fn enable(&mut self, resources: Resources);
    /// Disables the device, stopping queue processing.
    fn disable(&mut self);
}
494
/// Per-queue resources handed to a device on enable.
pub struct QueueResources {
    /// Queue parameters (includes the `enable` flag consulted when starting
    /// workers).
    pub params: QueueParams,
    /// Interrupt used to notify the guest of used-ring updates.
    pub notify: Interrupt,
    /// Event signaled to wake the queue worker.
    pub event: Event,
}
500
/// Resources handed to a device on enable.
pub struct Resources {
    /// Feature bits in effect.
    pub features: u64,
    /// Per-queue resources, indexed by queue number.
    pub queues: Vec<QueueResources>,
    /// Optional mapped shared-memory region.
    pub shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
    /// Size of the shared-memory region in bytes.
    pub shared_memory_size: u64,
}
507
/// Adapts a [`LegacyVirtioDevice`] to the [`VirtioDevice`] interface by
/// running one worker task per enabled queue.
pub struct LegacyWrapper<T: LegacyVirtioDevice> {
    device: T,
    // Driver used to spawn the async queue-shutdown task.
    driver: VmTaskDriver,
    mem: GuestMemory,
    // One task per enabled queue; populated by `enable`, drained by
    // `disable`.
    workers: Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>,
    // Notified on disable to tell every worker to exit.
    exit_event: event_listener::Event,
}
516
517impl<T: LegacyVirtioDevice> LegacyWrapper<T> {
518 pub fn new(driver_source: &VmTaskDriverSource, device: T, mem: &GuestMemory) -> Self {
519 Self {
520 device,
521 driver: driver_source.simple(),
522 mem: mem.clone(),
523 workers: Vec::new(),
524 exit_event: event_listener::Event::new(),
525 }
526 }
527}
528
529impl<T: LegacyVirtioDevice> VirtioDevice for LegacyWrapper<T> {
530 fn traits(&self) -> DeviceTraits {
531 self.device.traits()
532 }
533
534 fn read_registers_u32(&self, offset: u16) -> u32 {
535 self.device.read_registers_u32(offset)
536 }
537
538 fn write_registers_u32(&mut self, offset: u16, val: u32) {
539 self.device.write_registers_u32(offset, val)
540 }
541
542 fn enable(&mut self, resources: Resources) {
543 let running_state = VirtioRunningState {
544 features: resources.features,
545 enabled_queues: resources
546 .queues
547 .iter()
548 .map(|QueueResources { params, .. }| params.enable)
549 .collect(),
550 };
551
552 self.device
553 .state_change(&VirtioState::Running(running_state));
554 self.workers = resources
555 .queues
556 .into_iter()
557 .enumerate()
558 .filter_map(|(i, queue_resources)| {
559 if !queue_resources.params.enable {
560 return None;
561 }
562 let worker = VirtioQueueWorker::new(
563 self.driver.clone(),
564 self.device.get_work_callback(i as u16),
565 );
566 Some(worker.into_running_task(
567 "virtio-queue".to_string(),
568 self.mem.clone(),
569 resources.features,
570 queue_resources,
571 self.exit_event.listen(),
572 ))
573 })
574 .collect();
575 }
576
577 fn disable(&mut self) {
578 if self.workers.is_empty() {
579 return;
580 }
581 self.exit_event.notify(usize::MAX);
582 self.device.state_change(&VirtioState::Stopped);
583 let mut workers = self.workers.drain(..).collect::<Vec<_>>();
584 self.driver
585 .spawn("shutdown-legacy-virtio-queues".to_owned(), async move {
586 futures::future::join_all(workers.iter_mut().map(async |worker| {
587 worker.stop().await;
588 if let Some(VirtioQueueStateInner::Running { queue, .. }) =
589 worker.state_mut().map(|s| &s.inner)
590 {
591 queue.wait_for_outstanding_descriptors().await;
592 }
593 }))
594 .await;
595 })
596 .detach();
597 }
598}
599
impl<T: LegacyVirtioDevice> Drop for LegacyWrapper<T> {
    fn drop(&mut self) {
        // Ensure queue workers are shut down when the wrapper goes away;
        // `disable` returns immediately if there are no workers.
        self.disable();
    }
}