1use crate::queue::QueueCore;
5use crate::queue::QueueError;
6use crate::queue::QueueParams;
7use crate::queue::VirtioQueuePayload;
8use async_trait::async_trait;
9use futures::FutureExt;
10use futures::Stream;
11use futures::StreamExt;
12use guestmem::DoorbellRegistration;
13use guestmem::GuestMemory;
14use guestmem::GuestMemoryError;
15use guestmem::MappedMemoryRegion;
16use pal_async::DefaultPool;
17use pal_async::driver::Driver;
18use pal_async::task::Spawn;
19use pal_async::wait::PolledWait;
20use pal_event::Event;
21use parking_lot::Mutex;
22use std::io::Error;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::Context;
26use std::task::Poll;
27use std::task::ready;
28use task_control::AsyncRun;
29use task_control::StopTask;
30use task_control::TaskControl;
31use thiserror::Error;
32use vmcore::interrupt::Interrupt;
33use vmcore::vm_task::VmTaskDriver;
34use vmcore::vm_task::VmTaskDriverSource;
35
36#[async_trait]
37pub trait VirtioQueueWorkerContext {
38 async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool;
39}
40
41#[derive(Debug)]
42pub struct VirtioQueueUsedHandler {
43 core: QueueCore,
44 last_used_index: u16,
45 outstanding_desc_count: Arc<Mutex<(u16, event_listener::Event)>>,
46 notify_guest: Interrupt,
47}
48
49impl VirtioQueueUsedHandler {
50 fn new(core: QueueCore, notify_guest: Interrupt) -> Self {
51 Self {
52 core,
53 last_used_index: 0,
54 outstanding_desc_count: Arc::new(Mutex::new((0, event_listener::Event::new()))),
55 notify_guest,
56 }
57 }
58
59 pub fn add_outstanding_descriptor(&self) {
60 let (count, _) = &mut *self.outstanding_desc_count.lock();
61 *count += 1;
62 }
63
64 pub fn await_outstanding_descriptors(&self) -> event_listener::EventListener {
65 let (count, event) = &*self.outstanding_desc_count.lock();
66 let listener = event.listen();
67 if *count == 0 {
68 event.notify(usize::MAX);
69 }
70 listener
71 }
72
73 pub fn complete_descriptor(&mut self, descriptor_index: u16, bytes_written: u32) {
74 match self.core.complete_descriptor(
75 &mut self.last_used_index,
76 descriptor_index,
77 bytes_written,
78 ) {
79 Ok(true) => {
80 self.notify_guest.deliver();
81 }
82 Ok(false) => {}
83 Err(err) => {
84 tracelimit::error_ratelimited!(
85 error = &err as &dyn std::error::Error,
86 "failed to complete descriptor"
87 );
88 }
89 }
90 {
91 let (count, event) = &mut *self.outstanding_desc_count.lock();
92 *count -= 1;
93 if *count == 0 {
94 event.notify(usize::MAX);
95 }
96 }
97 }
98}
99
100pub struct VirtioQueueCallbackWork {
101 pub payload: Vec<VirtioQueuePayload>,
102 used_queue_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
103 descriptor_index: u16,
104 completed: bool,
105}
106
107impl VirtioQueueCallbackWork {
108 pub fn new(
109 payload: Vec<VirtioQueuePayload>,
110 used_queue_handler: &Arc<Mutex<VirtioQueueUsedHandler>>,
111 descriptor_index: u16,
112 ) -> Self {
113 let used_queue_handler = used_queue_handler.clone();
114 used_queue_handler.lock().add_outstanding_descriptor();
115 Self {
116 payload,
117 used_queue_handler,
118 descriptor_index,
119 completed: false,
120 }
121 }
122
123 pub fn complete(&mut self, bytes_written: u32) {
124 assert!(!self.completed);
125 self.used_queue_handler
126 .lock()
127 .complete_descriptor(self.descriptor_index, bytes_written);
128 self.completed = true;
129 }
130
131 pub fn descriptor_index(&self) -> u16 {
132 self.descriptor_index
133 }
134
135 pub fn get_payload_length(&self, writeable: bool) -> u64 {
137 self.payload
138 .iter()
139 .filter(|x| x.writeable == writeable)
140 .fold(0, |acc, x| acc + x.length as u64)
141 }
142
143 pub fn read(&self, mem: &GuestMemory, target: &mut [u8]) -> Result<usize, GuestMemoryError> {
145 let mut remaining = target;
146 let mut read_bytes: usize = 0;
147 for payload in &self.payload {
148 if payload.writeable {
149 continue;
150 }
151
152 let size = std::cmp::min(payload.length as usize, remaining.len());
153 let (current, next) = remaining.split_at_mut(size);
154 mem.read_at(payload.address, current)?;
155 read_bytes += size;
156 if next.is_empty() {
157 break;
158 }
159
160 remaining = next;
161 }
162
163 Ok(read_bytes)
164 }
165
166 pub fn write_at_offset(
168 &self,
169 offset: u64,
170 mem: &GuestMemory,
171 source: &[u8],
172 ) -> Result<(), VirtioWriteError> {
173 let mut skip_bytes = offset;
174 let mut remaining = source;
175 for payload in &self.payload {
176 if !payload.writeable {
177 continue;
178 }
179
180 let payload_length = payload.length as u64;
181 if skip_bytes >= payload_length {
182 skip_bytes -= payload_length;
183 continue;
184 }
185
186 let size = std::cmp::min(
187 payload_length as usize - skip_bytes as usize,
188 remaining.len(),
189 );
190 let (current, next) = remaining.split_at(size);
191 mem.write_at(payload.address + skip_bytes, current)?;
192 remaining = next;
193 if remaining.is_empty() {
194 break;
195 }
196 skip_bytes = 0;
197 }
198
199 if !remaining.is_empty() {
200 return Err(VirtioWriteError::NotAllWritten(source.len()));
201 }
202
203 Ok(())
204 }
205
206 pub fn write(&self, mem: &GuestMemory, source: &[u8]) -> Result<(), VirtioWriteError> {
207 self.write_at_offset(0, mem, source)
208 }
209}
210
211#[derive(Debug, Error)]
212pub enum VirtioWriteError {
213 #[error(transparent)]
214 Memory(#[from] GuestMemoryError),
215 #[error("{0:#x} bytes not written")]
216 NotAllWritten(usize),
217}
218
219impl Drop for VirtioQueueCallbackWork {
220 fn drop(&mut self) {
221 if !self.completed {
222 self.complete(0);
223 }
224 }
225}
226
227#[derive(Debug)]
228pub struct VirtioQueue {
229 core: QueueCore,
230 last_avail_index: u16,
231 used_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
232 queue_event: PolledWait<Event>,
233}
234
235impl VirtioQueue {
236 pub fn new(
237 features: u64,
238 params: QueueParams,
239 mem: GuestMemory,
240 notify: Interrupt,
241 queue_event: PolledWait<Event>,
242 ) -> Result<Self, QueueError> {
243 let core = QueueCore::new(features, mem, params)?;
244 let used_handler = Arc::new(Mutex::new(VirtioQueueUsedHandler::new(
245 core.clone(),
246 notify,
247 )));
248 Ok(Self {
249 core,
250 last_avail_index: 0,
251 used_handler,
252 queue_event,
253 })
254 }
255
256 async fn wait_for_outstanding_descriptors(&self) {
257 let wait_for_descriptors = self.used_handler.lock().await_outstanding_descriptors();
258 wait_for_descriptors.await;
259 }
260
261 fn poll_next_buffer(
262 &mut self,
263 cx: &mut Context<'_>,
264 ) -> Poll<Result<Option<VirtioQueueCallbackWork>, QueueError>> {
265 let descriptor_index = loop {
266 if let Some(descriptor_index) = self.core.descriptor_index(self.last_avail_index)? {
267 break descriptor_index;
268 };
269 ready!(self.queue_event.wait().poll_unpin(cx)).expect("waits on Event cannot fail");
270 };
271 let payload = self
272 .core
273 .reader(descriptor_index)
274 .collect::<Result<Vec<_>, _>>()?;
275
276 self.last_avail_index = self.last_avail_index.wrapping_add(1);
277 Poll::Ready(Ok(Some(VirtioQueueCallbackWork::new(
278 payload,
279 &self.used_handler,
280 descriptor_index,
281 ))))
282 }
283}
284
285impl Drop for VirtioQueue {
286 fn drop(&mut self) {
287 if Arc::get_mut(&mut self.used_handler).is_none() {
288 tracing::error!("Virtio queue dropped with outstanding work pending")
289 }
290 }
291}
292
293impl Stream for VirtioQueue {
294 type Item = Result<VirtioQueueCallbackWork, Error>;
295
296 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
297 let Some(r) = ready!(self.get_mut().poll_next_buffer(cx)).transpose() else {
298 return Poll::Ready(None);
299 };
300
301 Poll::Ready(Some(
302 r.map_err(|err| Error::new(std::io::ErrorKind::Other, err)),
303 ))
304 }
305}
306
307enum VirtioQueueStateInner {
308 Initializing {
309 mem: GuestMemory,
310 features: u64,
311 params: QueueParams,
312 event: Event,
313 notify: Interrupt,
314 exit_event: event_listener::EventListener,
315 },
316 InitializationInProgress,
317 Running {
318 queue: VirtioQueue,
319 exit_event: event_listener::EventListener,
320 },
321}
322
323pub struct VirtioQueueState {
324 inner: VirtioQueueStateInner,
325}
326
327pub struct VirtioQueueWorker {
328 driver: Box<dyn Driver>,
329 context: Box<dyn VirtioQueueWorkerContext + Send>,
330}
331
332impl VirtioQueueWorker {
333 pub fn new(driver: impl Driver, context: Box<dyn VirtioQueueWorkerContext + Send>) -> Self {
334 Self {
335 driver: Box::new(driver),
336 context,
337 }
338 }
339
340 pub fn into_running_task(
341 self,
342 name: impl Into<String>,
343 mem: GuestMemory,
344 features: u64,
345 queue_resources: QueueResources,
346 exit_event: event_listener::EventListener,
347 ) -> TaskControl<VirtioQueueWorker, VirtioQueueState> {
348 let name = name.into();
349 let (_, driver) = DefaultPool::spawn_on_thread(&name);
350
351 let mut task = TaskControl::new(self);
352 task.insert(
353 driver,
354 name,
355 VirtioQueueState {
356 inner: VirtioQueueStateInner::Initializing {
357 mem,
358 features,
359 params: queue_resources.params,
360 event: queue_resources.event,
361 notify: queue_resources.notify,
362 exit_event,
363 },
364 },
365 );
366 task.start();
367 task
368 }
369
370 async fn run_queue(&mut self, state: &mut VirtioQueueState) -> bool {
371 match &mut state.inner {
372 VirtioQueueStateInner::InitializationInProgress => unreachable!(),
373 VirtioQueueStateInner::Initializing { .. } => {
374 let VirtioQueueStateInner::Initializing {
375 mem,
376 features,
377 params,
378 event,
379 notify,
380 exit_event,
381 } = std::mem::replace(
382 &mut state.inner,
383 VirtioQueueStateInner::InitializationInProgress,
384 )
385 else {
386 unreachable!()
387 };
388 let queue_event = PolledWait::new(&self.driver, event).unwrap();
389 let queue = VirtioQueue::new(features, params, mem, notify, queue_event);
390 if let Err(err) = queue {
391 tracing::error!(
392 err = &err as &dyn std::error::Error,
393 "Failed to start queue"
394 );
395 false
396 } else {
397 state.inner = VirtioQueueStateInner::Running {
398 queue: queue.unwrap(),
399 exit_event,
400 };
401 true
402 }
403 }
404 VirtioQueueStateInner::Running { queue, exit_event } => {
405 let mut exit = exit_event.fuse();
406 let mut queue_ready = queue.next().fuse();
407 let work = futures::select_biased! {
408 _ = exit => return false,
409 work = queue_ready => work.expect("queue will never complete").map_err(anyhow::Error::from),
410 };
411 self.context.process_work(work).await
412 }
413 }
414 }
415}
416
417impl AsyncRun<VirtioQueueState> for VirtioQueueWorker {
418 async fn run(
419 &mut self,
420 stop: &mut StopTask<'_>,
421 state: &mut VirtioQueueState,
422 ) -> Result<(), task_control::Cancelled> {
423 while stop.until_stopped(self.run_queue(state)).await? {}
424 Ok(())
425 }
426}
427
/// Details passed to a legacy device when it enters the running state.
pub struct VirtioRunningState {
    /// Feature bits the device is enabled with.
    pub features: u64,
    /// One flag per queue, in queue order, telling whether it is enabled.
    pub enabled_queues: Vec<bool>,
}
432
433pub enum VirtioState {
434 Unknown,
435 Running(VirtioRunningState),
436 Stopped,
437}
438
439pub(crate) struct VirtioDoorbells {
440 registration: Option<Arc<dyn DoorbellRegistration>>,
441 doorbells: Vec<Box<dyn Send + Sync>>,
442}
443
444impl VirtioDoorbells {
445 pub fn new(registration: Option<Arc<dyn DoorbellRegistration>>) -> Self {
446 Self {
447 registration,
448 doorbells: Vec::new(),
449 }
450 }
451
452 pub fn add(&mut self, address: u64, value: Option<u64>, length: Option<u32>, event: &Event) {
453 if let Some(registration) = &mut self.registration {
454 let doorbell = registration.register_doorbell(address, value, length, event);
455 if let Ok(doorbell) = doorbell {
456 self.doorbells.push(doorbell);
457 }
458 }
459 }
460
461 pub fn clear(&mut self) {
462 self.doorbells.clear();
463 }
464}
465
/// Shared-memory region advertised by a device.
#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraitsSharedMemory {
    /// Identifier of the region.
    pub id: u8,
    /// Size of the region (presumably in bytes — confirm against the transport code).
    pub size: u64,
}
471
472#[derive(Copy, Clone, Debug, Default)]
473pub struct DeviceTraits {
474 pub device_id: u16,
475 pub device_features: u64,
476 pub max_queues: u16,
477 pub device_register_length: u32,
478 pub shared_memory: DeviceTraitsSharedMemory,
479}
480
481pub trait LegacyVirtioDevice: Send {
482 fn traits(&self) -> DeviceTraits;
483 fn read_registers_u32(&self, offset: u16) -> u32;
484 fn write_registers_u32(&mut self, offset: u16, val: u32);
485 fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send>;
486 fn state_change(&mut self, state: &VirtioState);
487}
488
489pub trait VirtioDevice: Send {
490 fn traits(&self) -> DeviceTraits;
491 fn read_registers_u32(&self, offset: u16) -> u32;
492 fn write_registers_u32(&mut self, offset: u16, val: u32);
493 fn enable(&mut self, resources: Resources);
494 fn disable(&mut self);
495}
496
497pub struct QueueResources {
498 pub params: QueueParams,
499 pub notify: Interrupt,
500 pub event: Event,
501}
502
503pub struct Resources {
504 pub features: u64,
505 pub queues: Vec<QueueResources>,
506 pub shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
507 pub shared_memory_size: u64,
508}
509
510pub struct LegacyWrapper<T: LegacyVirtioDevice> {
512 device: T,
513 driver: VmTaskDriver,
514 mem: GuestMemory,
515 workers: Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>,
516 exit_event: event_listener::Event,
517}
518
519impl<T: LegacyVirtioDevice> LegacyWrapper<T> {
520 pub fn new(driver_source: &VmTaskDriverSource, device: T, mem: &GuestMemory) -> Self {
521 Self {
522 device,
523 driver: driver_source.simple(),
524 mem: mem.clone(),
525 workers: Vec::new(),
526 exit_event: event_listener::Event::new(),
527 }
528 }
529}
530
531impl<T: LegacyVirtioDevice> VirtioDevice for LegacyWrapper<T> {
532 fn traits(&self) -> DeviceTraits {
533 self.device.traits()
534 }
535
536 fn read_registers_u32(&self, offset: u16) -> u32 {
537 self.device.read_registers_u32(offset)
538 }
539
540 fn write_registers_u32(&mut self, offset: u16, val: u32) {
541 self.device.write_registers_u32(offset, val)
542 }
543
544 fn enable(&mut self, resources: Resources) {
545 let running_state = VirtioRunningState {
546 features: resources.features,
547 enabled_queues: resources
548 .queues
549 .iter()
550 .map(|QueueResources { params, .. }| params.enable)
551 .collect(),
552 };
553
554 self.device
555 .state_change(&VirtioState::Running(running_state));
556 self.workers = resources
557 .queues
558 .into_iter()
559 .enumerate()
560 .filter_map(|(i, queue_resources)| {
561 if !queue_resources.params.enable {
562 return None;
563 }
564 let worker = VirtioQueueWorker::new(
565 self.driver.clone(),
566 self.device.get_work_callback(i as u16),
567 );
568 Some(worker.into_running_task(
569 "virtio-queue".to_string(),
570 self.mem.clone(),
571 resources.features,
572 queue_resources,
573 self.exit_event.listen(),
574 ))
575 })
576 .collect();
577 }
578
579 fn disable(&mut self) {
580 if self.workers.is_empty() {
581 return;
582 }
583 self.exit_event.notify(usize::MAX);
584 self.device.state_change(&VirtioState::Stopped);
585 let mut workers = self.workers.drain(..).collect::<Vec<_>>();
586 self.driver
587 .spawn("shutdown-legacy-virtio-queues".to_owned(), async move {
588 futures::future::join_all(workers.iter_mut().map(async |worker| {
589 worker.stop().await;
590 if let Some(VirtioQueueStateInner::Running { queue, .. }) =
591 worker.state_mut().map(|s| &s.inner)
592 {
593 queue.wait_for_outstanding_descriptors().await;
594 }
595 }))
596 .await;
597 })
598 .detach();
599 }
600}
601
602impl<T: LegacyVirtioDevice> Drop for LegacyWrapper<T> {
603 fn drop(&mut self) {
604 self.disable();
605 }
606}
607
608#[expect(unsafe_code)]
611#[cfg(test)]
612mod tests {
613 use super::*;
614 use crate::PciInterruptModel;
615 use crate::spec::pci::*;
616 use crate::spec::queue::*;
617 use crate::spec::*;
618 use crate::transport::VirtioMmioDevice;
619 use crate::transport::VirtioPciDevice;
620 use chipset_device::mmio::ExternallyManagedMmioIntercepts;
621 use chipset_device::mmio::MmioIntercept;
622 use chipset_device::pci::PciConfigSpace;
623 use futures::StreamExt;
624 use guestmem::GuestMemoryAccess;
625 use guestmem::GuestMemoryBackingError;
626 use pal_async::DefaultDriver;
627 use pal_async::async_test;
628 use pal_async::timer::PolledTimer;
629 use pci_core::msi::MsiInterruptSet;
630 use pci_core::spec::caps::CapabilityId;
631 use pci_core::spec::cfg_space;
632 use pci_core::test_helpers::TestPciInterruptController;
633 use std::collections::BTreeMap;
634 use std::future::poll_fn;
635 use std::io;
636 use std::ptr::NonNull;
637 use std::time::Duration;
638 use test_with_tracing::test;
639 use vmcore::line_interrupt::LineInterrupt;
640 use vmcore::line_interrupt::test_helpers::TestLineInterruptTarget;
641 use vmcore::vm_task::SingleDriverBackend;
642 use vmcore::vm_task::VmTaskDriverSource;
643
644 async fn must_recv_in_timeout<T: 'static + Send>(
645 recv: &mut mesh::Receiver<T>,
646 timeout: Duration,
647 ) -> T {
648 mesh::CancelContext::new()
649 .with_timeout(timeout)
650 .until_cancelled(recv.next())
651 .await
652 .unwrap()
653 .unwrap()
654 }
655
656 #[derive(Default)]
657 struct VirtioTestMemoryAccess {
658 memory_map: Mutex<MemoryMap>,
659 }
660
661 #[derive(Default)]
662 struct MemoryMap {
663 map: BTreeMap<u64, (bool, Vec<u8>)>,
664 }
665
666 impl MemoryMap {
667 fn get(&mut self, address: u64, len: usize) -> Option<(bool, &mut [u8])> {
668 let (&base, &mut (writable, ref mut data)) = self.map.range_mut(..=address).last()?;
669 let data = data
670 .get_mut(usize::try_from(address - base).ok()?..)?
671 .get_mut(..len)?;
672
673 Some((writable, data))
674 }
675
676 fn insert(&mut self, address: u64, data: &[u8], writable: bool) {
677 if let Some((is_writable, v)) = self.get(address, data.len()) {
678 assert_eq!(writable, is_writable);
679 v.copy_from_slice(data);
680 return;
681 }
682
683 let end = address + data.len() as u64;
684 let mut data = data.to_vec();
685 if let Some((&next, &(next_writable, ref next_data))) = self.map.range(address..).next()
686 {
687 if end > next {
688 let next_end = next + next_data.len() as u64;
689 panic!(
690 "overlapping memory map: {address:#x}..{end:#x} > {next:#x}..={next_end:#x}"
691 );
692 }
693 if end == next && next_writable == writable {
694 data.extend(next_data.as_slice());
695 self.map.remove(&next).unwrap();
696 }
697 }
698
699 if let Some((&prev, &mut (prev_writable, ref mut prev_data))) =
700 self.map.range_mut(..address).last()
701 {
702 let prev_end = prev + prev_data.len() as u64;
703 if prev_end > address {
704 panic!(
705 "overlapping memory map: {prev:#x}..{prev_end:#x} > {address:#x}..={end:#x}"
706 );
707 }
708 if prev_end == address && prev_writable == writable {
709 prev_data.extend_from_slice(&data);
710 return;
711 }
712 }
713
714 self.map.insert(address, (writable, data));
715 }
716 }
717
718 impl VirtioTestMemoryAccess {
719 fn new() -> Arc<Self> {
720 Default::default()
721 }
722
723 fn modify_memory_map(&self, address: u64, data: &[u8], writeable: bool) {
724 self.memory_map.lock().insert(address, data, writeable);
725 }
726
727 fn memory_map_get_u16(&self, address: u64) -> u16 {
728 let mut map = self.memory_map.lock();
729 let (_, data) = map.get(address, 2).unwrap();
730 u16::from_le_bytes(data.try_into().unwrap())
731 }
732
733 fn memory_map_get_u32(&self, address: u64) -> u32 {
734 let mut map = self.memory_map.lock();
735 let (_, data) = map.get(address, 4).unwrap();
736 u32::from_le_bytes(data.try_into().unwrap())
737 }
738 }
739
740 unsafe impl GuestMemoryAccess for VirtioTestMemoryAccess {
742 fn mapping(&self) -> Option<NonNull<u8>> {
743 None
744 }
745
746 fn max_address(&self) -> u64 {
747 1 << 52
750 }
751
752 unsafe fn read_fallback(
753 &self,
754 address: u64,
755 dest: *mut u8,
756 len: usize,
757 ) -> Result<(), GuestMemoryBackingError> {
758 match self.memory_map.lock().get(address, len) {
759 Some((_, value)) => {
760 unsafe {
762 std::ptr::copy(value.as_ptr(), dest, len);
763 }
764 }
765 None => panic!("Unexpected read request at address {:x}", address),
766 }
767 Ok(())
768 }
769
770 unsafe fn write_fallback(
771 &self,
772 address: u64,
773 src: *const u8,
774 len: usize,
775 ) -> Result<(), GuestMemoryBackingError> {
776 match self.memory_map.lock().get(address, len) {
777 Some((true, value)) => {
778 unsafe {
780 std::ptr::copy(src, value.as_mut_ptr(), len);
781 }
782 }
783 _ => panic!("Unexpected write request at address {:x}", address),
784 }
785 Ok(())
786 }
787
788 fn fill_fallback(
789 &self,
790 address: u64,
791 val: u8,
792 len: usize,
793 ) -> Result<(), GuestMemoryBackingError> {
794 match self.memory_map.lock().get(address, len) {
795 Some((true, value)) => value.fill(val),
796 _ => panic!("Unexpected write request at address {:x}", address),
797 };
798 Ok(())
799 }
800 }
801
802 struct DoorbellEntry;
803 impl Drop for DoorbellEntry {
804 fn drop(&mut self) {}
805 }
806
807 impl DoorbellRegistration for VirtioTestMemoryAccess {
808 fn register_doorbell(
809 &self,
810 _: u64,
811 _: Option<u64>,
812 _: Option<u32>,
813 _: &Event,
814 ) -> io::Result<Box<dyn Send + Sync>> {
815 Ok(Box::new(DoorbellEntry))
816 }
817 }
818
819 type VirtioTestWorkCallback =
820 Box<dyn Fn(anyhow::Result<VirtioQueueCallbackWork>) -> bool + Sync + Send>;
821 struct CreateDirectQueueParams {
822 process_work: VirtioTestWorkCallback,
823 notify: Interrupt,
824 event: Event,
825 }
826
827 struct VirtioTestGuest {
828 test_mem: Arc<VirtioTestMemoryAccess>,
829 driver: DefaultDriver,
830 num_queues: u16,
831 queue_size: u16,
832 use_ring_event_index: bool,
833 last_avail_index: Vec<u16>,
834 last_used_index: Vec<u16>,
835 avail_descriptors: Vec<Vec<bool>>,
836 exit_event: event_listener::Event,
837 }
838
839 impl VirtioTestGuest {
840 fn new(
841 driver: &DefaultDriver,
842 test_mem: &Arc<VirtioTestMemoryAccess>,
843 num_queues: u16,
844 queue_size: u16,
845 use_ring_event_index: bool,
846 ) -> Self {
847 let last_avail_index: Vec<u16> = vec![0; num_queues as usize];
848 let last_used_index: Vec<u16> = vec![0; num_queues as usize];
849 let avail_descriptors: Vec<Vec<bool>> =
850 vec![vec![true; queue_size as usize]; num_queues as usize];
851 let test_guest = Self {
852 test_mem: test_mem.clone(),
853 driver: driver.clone(),
854 num_queues,
855 queue_size,
856 use_ring_event_index,
857 last_avail_index,
858 last_used_index,
859 avail_descriptors,
860 exit_event: event_listener::Event::new(),
861 };
862 for i in 0..num_queues {
863 test_guest.add_queue_memory(i);
864 }
865 test_guest
866 }
867
868 fn mem(&self) -> GuestMemory {
869 GuestMemory::new("test", self.test_mem.clone())
870 }
871
872 fn create_direct_queues<F>(
873 &self,
874 f: F,
875 ) -> Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>
876 where
877 F: Fn(u16) -> CreateDirectQueueParams,
878 {
879 (0..self.num_queues)
880 .map(|i| {
881 let params = f(i);
882 let worker = VirtioQueueWorker::new(
883 self.driver.clone(),
884 Box::new(VirtioTestWork {
885 callback: params.process_work,
886 }),
887 );
888 worker.into_running_task(
889 "virtio-test-queue".to_string(),
890 self.mem(),
891 self.queue_features(),
892 QueueResources {
893 params: self.queue_params(i),
894 notify: params.notify,
895 event: params.event,
896 },
897 self.exit_event.listen(),
898 )
899 })
900 .collect::<Vec<_>>()
901 }
902
903 fn queue_features(&self) -> u64 {
904 if self.use_ring_event_index {
905 VIRTIO_F_RING_EVENT_IDX as u64
906 } else {
907 0
908 }
909 }
910
911 fn queue_params(&self, i: u16) -> QueueParams {
912 QueueParams {
913 size: self.queue_size,
914 enable: true,
915 desc_addr: self.get_queue_descriptor_base_address(i),
916 avail_addr: self.get_queue_available_base_address(i),
917 used_addr: self.get_queue_used_base_address(i),
918 }
919 }
920
921 fn get_queue_base_address(&self, index: u16) -> u64 {
922 0x10000000 * index as u64
923 }
924
925 fn get_queue_descriptor_base_address(&self, index: u16) -> u64 {
926 self.get_queue_base_address(index) + 0x1000
927 }
928
929 fn get_queue_available_base_address(&self, index: u16) -> u64 {
930 self.get_queue_base_address(index) + 0x2000
931 }
932
933 fn get_queue_used_base_address(&self, index: u16) -> u64 {
934 self.get_queue_base_address(index) + 0x3000
935 }
936
937 fn get_queue_descriptor_backing_memory_address(&self, index: u16) -> u64 {
938 self.get_queue_base_address(index) + 0x4000
939 }
940
941 fn setup_chipset_device(&self, dev: &mut VirtioMmioDevice, driver_features: u64) {
942 dev.write_u32(112, VIRTIO_ACKNOWLEDGE);
943 dev.write_u32(112, VIRTIO_DRIVER);
944 dev.write_u32(36, 0);
945 dev.write_u32(32, driver_features as u32);
946 dev.write_u32(36, 1);
947 dev.write_u32(32, (driver_features >> 32) as u32);
948 dev.write_u32(112, VIRTIO_FEATURES_OK);
949 for i in 0..self.num_queues {
950 let queue_index = i;
951 dev.write_u32(48, i as u32);
952 dev.write_u32(56, self.queue_size as u32);
953 let desc_addr = self.get_queue_descriptor_base_address(queue_index);
954 dev.write_u32(128, desc_addr as u32);
955 dev.write_u32(132, (desc_addr >> 32) as u32);
956 let avail_addr = self.get_queue_available_base_address(queue_index);
957 dev.write_u32(144, avail_addr as u32);
958 dev.write_u32(148, (avail_addr >> 32) as u32);
959 let used_addr = self.get_queue_used_base_address(queue_index);
960 dev.write_u32(160, used_addr as u32);
961 dev.write_u32(164, (used_addr >> 32) as u32);
962 dev.write_u32(68, 1);
964 }
965 dev.write_u32(112, VIRTIO_DRIVER_OK);
966 assert_eq!(dev.read_u32(0xfc), 2);
967 }
968
969 fn setup_pci_device(&self, dev: &mut VirtioPciTestDevice, driver_features: u64) {
970 let bar_address1: u64 = 0x10000000000;
971 dev.pci_device
972 .pci_cfg_write(0x14, (bar_address1 >> 32) as u32)
973 .unwrap();
974 dev.pci_device
975 .pci_cfg_write(0x10, bar_address1 as u32)
976 .unwrap();
977
978 let bar_address2: u64 = 0x20000000000;
979 dev.pci_device
980 .pci_cfg_write(0x1c, (bar_address2 >> 32) as u32)
981 .unwrap();
982 dev.pci_device
983 .pci_cfg_write(0x18, bar_address2 as u32)
984 .unwrap();
985
986 dev.pci_device
987 .pci_cfg_write(
988 0x4,
989 cfg_space::Command::new()
990 .with_mmio_enabled(true)
991 .into_bits() as u32,
992 )
993 .unwrap();
994
995 let mut device_status = VIRTIO_ACKNOWLEDGE as u8;
996 dev.pci_device
997 .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
998 .unwrap();
999 device_status = VIRTIO_DRIVER as u8;
1000 dev.pci_device
1001 .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
1002 .unwrap();
1003 dev.write_u32(bar_address1 + 8, 0);
1004 dev.write_u32(bar_address1 + 12, driver_features as u32);
1005 dev.write_u32(bar_address1 + 8, 1);
1006 dev.write_u32(bar_address1 + 12, (driver_features >> 32) as u32);
1007 device_status = VIRTIO_FEATURES_OK as u8;
1008 dev.pci_device
1009 .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
1010 .unwrap();
1011 dev.pci_device
1013 .mmio_write(bar_address2, &0_u64.to_le_bytes())
1014 .unwrap(); dev.pci_device
1016 .mmio_write(bar_address2 + 8, &0_u32.to_le_bytes())
1017 .unwrap(); dev.pci_device
1019 .mmio_write(bar_address2 + 12, &0_u32.to_le_bytes())
1020 .unwrap();
1021 for i in 0..self.num_queues {
1022 let queue_index = i;
1023 dev.pci_device
1024 .mmio_write(bar_address1 + 22, &queue_index.to_le_bytes())
1025 .unwrap();
1026 dev.pci_device
1027 .mmio_write(bar_address1 + 24, &self.queue_size.to_le_bytes())
1028 .unwrap();
1029 let msix_vector = queue_index + 1;
1031 let address = bar_address2 + 0x10 * msix_vector as u64;
1032 dev.pci_device
1033 .mmio_write(address, &(msix_vector as u64).to_le_bytes())
1034 .unwrap();
1035 let address = bar_address2 + 0x10 * msix_vector as u64 + 8;
1036 dev.pci_device
1037 .mmio_write(address, &0_u32.to_le_bytes())
1038 .unwrap();
1039 let address = bar_address2 + 0x10 * msix_vector as u64 + 12;
1040 dev.pci_device
1041 .mmio_write(address, &0_u32.to_le_bytes())
1042 .unwrap();
1043 dev.pci_device
1044 .mmio_write(bar_address1 + 26, &msix_vector.to_le_bytes())
1045 .unwrap();
1046 let desc_addr = self.get_queue_descriptor_base_address(queue_index);
1048 dev.write_u32(bar_address1 + 32, desc_addr as u32);
1049 dev.write_u32(bar_address1 + 36, (desc_addr >> 32) as u32);
1050 let avail_addr = self.get_queue_available_base_address(queue_index);
1051 dev.write_u32(bar_address1 + 40, avail_addr as u32);
1052 dev.write_u32(bar_address1 + 44, (avail_addr >> 32) as u32);
1053 let used_addr = self.get_queue_used_base_address(queue_index);
1054 dev.write_u32(bar_address1 + 48, used_addr as u32);
1055 dev.write_u32(bar_address1 + 52, (used_addr >> 32) as u32);
1056 let enabled: u16 = 1;
1058 dev.pci_device
1059 .mmio_write(bar_address1 + 28, &enabled.to_le_bytes())
1060 .unwrap();
1061 }
1062 dev.pci_device.pci_cfg_write(0x40, 0x80000000).unwrap();
1064 device_status = VIRTIO_DRIVER_OK as u8;
1066 dev.pci_device
1067 .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
1068 .unwrap();
1069 let mut config_generation: [u8; 1] = [0];
1070 dev.pci_device
1071 .mmio_read(bar_address1 + 21, &mut config_generation)
1072 .unwrap();
1073 assert_eq!(config_generation[0], 2);
1074 }
1075
1076 fn get_queue_descriptor(&self, queue_index: u16, descriptor_index: u16) -> u64 {
1077 self.get_queue_descriptor_base_address(queue_index) + 0x10 * descriptor_index as u64
1078 }
1079
1080 fn add_queue_memory(&self, queue_index: u16) {
1081 for i in 0..self.queue_size {
1083 let base = self.get_queue_descriptor(queue_index, i);
1084 self.test_mem.modify_memory_map(
1086 base,
1087 &(self.get_queue_descriptor_backing_memory_address(queue_index)
1088 + 0x1000 * i as u64)
1089 .to_le_bytes(),
1090 false,
1091 );
1092 self.test_mem
1094 .modify_memory_map(base + 8, &0x1000u32.to_le_bytes(), false);
1095 self.test_mem
1097 .modify_memory_map(base + 12, &0u16.to_le_bytes(), false);
1098 self.test_mem
1100 .modify_memory_map(base + 14, &0u16.to_le_bytes(), false);
1101 }
1102
1103 let base = self.get_queue_available_base_address(queue_index);
1105 self.test_mem
1106 .modify_memory_map(base, &0u16.to_le_bytes(), false);
1107 self.test_mem
1108 .modify_memory_map(base + 2, &0u16.to_le_bytes(), false);
1109 for i in 0..self.queue_size {
1111 let base = base + 4 + 2 * i as u64;
1112 self.test_mem
1113 .modify_memory_map(base, &0u16.to_le_bytes(), false);
1114 }
1115 if self.use_ring_event_index {
1117 self.test_mem.modify_memory_map(
1118 base + 4 + 2 * self.queue_size as u64,
1119 &0u16.to_le_bytes(),
1120 false,
1121 );
1122 }
1123
1124 let base = self.get_queue_used_base_address(queue_index);
1126 self.test_mem
1127 .modify_memory_map(base, &0u16.to_le_bytes(), true);
1128 self.test_mem
1129 .modify_memory_map(base + 2, &0u16.to_le_bytes(), true);
1130 for i in 0..self.queue_size {
1131 let base = base + 4 + 8 * i as u64;
1132 self.test_mem
1134 .modify_memory_map(base, &0u32.to_le_bytes(), true);
1135 self.test_mem
1137 .modify_memory_map(base + 4, &0u32.to_le_bytes(), true);
1138 }
1139 if self.use_ring_event_index {
1141 self.test_mem.modify_memory_map(
1142 base + 4 + 8 * self.queue_size as u64,
1143 &0u16.to_le_bytes(),
1144 true,
1145 );
1146 }
1147 }
1148
1149 fn reserve_descriptor(&mut self, queue_index: u16) -> u16 {
1150 let avail_descriptors = &mut self.avail_descriptors[queue_index as usize];
1151 for (i, desc) in avail_descriptors.iter_mut().enumerate() {
1152 if *desc {
1153 *desc = false;
1154 return i as u16;
1155 }
1156 }
1157
1158 panic!("No descriptors are available!");
1159 }
1160
1161 fn free_descriptor(&mut self, queue_index: u16, desc_index: u16) {
1162 assert!(desc_index < self.queue_size);
1163 let desc_addr = self.get_queue_descriptor(queue_index, desc_index);
1164 let flags: DescriptorFlags = self.test_mem.memory_map_get_u16(desc_addr + 12).into();
1165 if flags.next() {
1166 let next = self.test_mem.memory_map_get_u16(desc_addr + 14);
1167 self.free_descriptor(queue_index, next);
1168 }
1169 let avail_descriptors = &mut self.avail_descriptors[queue_index as usize];
1170 assert_eq!(avail_descriptors[desc_index as usize], false);
1171 avail_descriptors[desc_index as usize] = true;
1172 }
1173
1174 fn queue_available_desc(&mut self, queue_index: u16, desc_index: u16) {
1175 let avail_base_addr = self.get_queue_available_base_address(queue_index);
1176 let last_avail_index = &mut self.last_avail_index[queue_index as usize];
1177 let next_index = *last_avail_index % self.queue_size;
1178 *last_avail_index = last_avail_index.wrapping_add(1);
1179 self.test_mem.modify_memory_map(
1180 avail_base_addr + 4 + 2 * next_index as u64,
1181 &desc_index.to_le_bytes(),
1182 false,
1183 );
1184 self.test_mem.modify_memory_map(
1185 avail_base_addr + 2,
1186 &last_avail_index.to_le_bytes(),
1187 false,
1188 );
1189 }
1190
1191 fn add_to_avail_queue(&mut self, queue_index: u16) {
1192 let next_descriptor = self.reserve_descriptor(queue_index);
1193 self.test_mem.modify_memory_map(
1195 self.get_queue_descriptor(queue_index, next_descriptor) + 12,
1196 &0u16.to_le_bytes(),
1197 false,
1198 );
1199 self.queue_available_desc(queue_index, next_descriptor);
1200 }
1201
1202 fn add_indirect_to_avail_queue(&mut self, queue_index: u16) {
1203 let next_descriptor = self.reserve_descriptor(queue_index);
1204 self.test_mem.modify_memory_map(
1206 self.get_queue_descriptor(queue_index, next_descriptor) + 12,
1207 &u16::from(DescriptorFlags::new().with_indirect(true)).to_le_bytes(),
1208 false,
1209 );
1210 let buffer_addr = self.get_queue_descriptor_backing_memory_address(queue_index);
1212 self.test_mem.modify_memory_map(
1214 buffer_addr,
1215 &0xffffffff00000000u64.to_le_bytes(),
1216 false,
1217 );
1218 self.test_mem
1220 .modify_memory_map(buffer_addr + 8, &0x1000u32.to_le_bytes(), false);
1221 self.test_mem
1223 .modify_memory_map(buffer_addr + 12, &0u16.to_le_bytes(), false);
1224 self.test_mem
1226 .modify_memory_map(buffer_addr + 14, &0u16.to_le_bytes(), false);
1227 self.queue_available_desc(queue_index, next_descriptor);
1228 }
1229
1230 fn add_linked_to_avail_queue(&mut self, queue_index: u16, desc_count: u16) {
1231 let mut descriptors = Vec::with_capacity(desc_count as usize);
1232 for _ in 0..desc_count {
1233 descriptors.push(self.reserve_descriptor(queue_index));
1234 }
1235
1236 for i in 0..descriptors.len() {
1237 let base = self.get_queue_descriptor(queue_index, descriptors[i]);
1238 let flags = if i < descriptors.len() - 1 {
1239 u16::from(DescriptorFlags::new().with_next(true))
1240 } else {
1241 0
1242 };
1243 self.test_mem
1244 .modify_memory_map(base + 12, &flags.to_le_bytes(), false);
1245 let next = if i < descriptors.len() - 1 {
1246 descriptors[i + 1]
1247 } else {
1248 0
1249 };
1250 self.test_mem
1251 .modify_memory_map(base + 14, &next.to_le_bytes(), false);
1252 }
1253 self.queue_available_desc(queue_index, descriptors[0]);
1254 }
1255
1256 fn add_indirect_linked_to_avail_queue(&mut self, queue_index: u16, desc_count: u16) {
1257 let next_descriptor = self.reserve_descriptor(queue_index);
1258 self.test_mem.modify_memory_map(
1260 self.get_queue_descriptor(queue_index, next_descriptor) + 12,
1261 &u16::from(DescriptorFlags::new().with_indirect(true)).to_le_bytes(),
1262 false,
1263 );
1264 let buffer_addr = self.get_queue_descriptor_backing_memory_address(queue_index);
1266 for i in 0..desc_count {
1267 let base = buffer_addr + 0x10 * i as u64;
1268 let indirect_buffer_addr = 0xffffffff00000000u64 + 0x1000 * i as u64;
1269 self.test_mem
1271 .modify_memory_map(base, &indirect_buffer_addr.to_le_bytes(), false);
1272 self.test_mem
1274 .modify_memory_map(base + 8, &0x1000u32.to_le_bytes(), false);
1275 let flags = if i < desc_count - 1 {
1277 u16::from(DescriptorFlags::new().with_next(true))
1278 } else {
1279 0
1280 };
1281 self.test_mem
1282 .modify_memory_map(base + 12, &flags.to_le_bytes(), false);
1283 let next = if i < desc_count - 1 { i + 1 } else { 0 };
1285 self.test_mem
1286 .modify_memory_map(base + 14, &next.to_le_bytes(), false);
1287 }
1288 self.queue_available_desc(queue_index, next_descriptor);
1289 }
1290
1291 fn get_next_completed(&mut self, queue_index: u16) -> Option<(u16, u32)> {
1292 let avail_base_addr = self.get_queue_available_base_address(queue_index);
1293 let used_base_addr = self.get_queue_used_base_address(queue_index);
1294 let cur_used_index = self.test_mem.memory_map_get_u16(used_base_addr + 2);
1295 let last_used_index = &mut self.last_used_index[queue_index as usize];
1296 if *last_used_index == cur_used_index {
1297 return None;
1298 }
1299
1300 if self.use_ring_event_index {
1301 self.test_mem.modify_memory_map(
1302 avail_base_addr + 4 + 2 * self.queue_size as u64,
1303 &cur_used_index.to_le_bytes(),
1304 false,
1305 );
1306 }
1307
1308 let next_index = *last_used_index % self.queue_size;
1309 *last_used_index = last_used_index.wrapping_add(1);
1310 let desc_index = self
1311 .test_mem
1312 .memory_map_get_u32(used_base_addr + 4 + 8 * next_index as u64);
1313 let desc_index = desc_index as u16;
1314 let bytes_written = self
1315 .test_mem
1316 .memory_map_get_u32(used_base_addr + 8 + 8 * next_index as u64);
1317 self.free_descriptor(queue_index, desc_index);
1318 Some((desc_index, bytes_written))
1319 }
1320 }
1321
    /// Queue worker context for tests that forwards each piece of queue work to a
    /// caller-supplied callback.
    struct VirtioTestWork {
        /// Invoked once per work item; its return value becomes the result of
        /// `process_work`.
        callback: VirtioTestWorkCallback,
    }
1325
    #[async_trait]
    impl VirtioQueueWorkerContext for VirtioTestWork {
        /// Hands `work` (including any queue error) straight to the stored callback and
        /// returns whatever the callback returns.
        async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
            (self.callback)(work)
        }
    }
    /// A virtio PCI device under test, bundled with the interrupt controller its MSI-X
    /// interrupt set is connected to.
    struct VirtioPciTestDevice {
        /// The device being exercised through PCI config space and MMIO accesses.
        pci_device: VirtioPciDevice,
        /// Test interrupt controller the device's MSI interrupt set was connected to.
        test_intc: Arc<TestPciInterruptController>,
    }
1336
    /// Shared callback invoked by [`TestDeviceWorker`] with the queue index and the work
    /// item taken from that queue.
    type TestDeviceQueueWorkFn = Arc<dyn Fn(u16, VirtioQueueCallbackWork) + Send + Sync>;
1338
    /// Minimal legacy virtio device used by the tests: it reports fixed traits and
    /// optionally forwards queue work to a callback.
    struct TestDevice {
        /// Traits returned verbatim from `traits()`.
        traits: DeviceTraits,
        /// Optional callback cloned into every queue worker; `None` means work items are
        /// accepted and dropped without being looked at.
        queue_work: Option<TestDeviceQueueWorkFn>,
    }
1343
    impl TestDevice {
        /// Creates a test device that reports `traits` and forwards queue work to
        /// `queue_work` when one is supplied.
        fn new(traits: DeviceTraits, queue_work: Option<TestDeviceQueueWorkFn>) -> Self {
            Self { traits, queue_work }
        }
    }
1349
    impl LegacyVirtioDevice for TestDevice {
        /// Returns the traits this device was constructed with.
        fn traits(&self) -> DeviceTraits {
            self.traits
        }

        /// Device-specific registers always read as zero.
        fn read_registers_u32(&self, _offset: u16) -> u32 {
            0
        }

        /// Writes to device-specific registers are ignored.
        fn write_registers_u32(&mut self, _offset: u16, _val: u32) {}

        /// Creates the worker for queue `index`; it shares this device's optional
        /// queue-work callback.
        fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send> {
            Box::new(TestDeviceWorker {
                index,
                queue_work: self.queue_work.clone(),
            })
        }

        /// State changes are ignored by the test device.
        fn state_change(&mut self, _state: &VirtioState) {}
    }
1370
    /// Per-queue worker created by [`TestDevice::get_work_callback`].
    struct TestDeviceWorker {
        /// Index of the queue this worker serves; passed to the callback and used in the
        /// panic message on queue errors.
        index: u16,
        /// Optional callback that receives each successfully read work item.
        queue_work: Option<TestDeviceQueueWorkFn>,
    }
1375
1376 #[async_trait]
1377 impl VirtioQueueWorkerContext for TestDeviceWorker {
1378 async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
1379 if let Err(err) = work {
1380 panic!(
1381 "Invalid virtio queue state index {} error {}",
1382 self.index,
1383 err.as_ref() as &dyn std::error::Error
1384 );
1385 }
1386 if let Some(ref func) = self.queue_work {
1387 (func)(self.index, work.unwrap());
1388 }
1389 true
1390 }
1391 }
1392
1393 impl VirtioPciTestDevice {
1394 fn new(
1395 driver: &DefaultDriver,
1396 num_queues: u16,
1397 test_mem: &Arc<VirtioTestMemoryAccess>,
1398 queue_work: Option<TestDeviceQueueWorkFn>,
1399 ) -> Self {
1400 let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
1401 let mem = GuestMemory::new("test", test_mem.clone());
1402 let mut msi_set = MsiInterruptSet::new();
1403
1404 let dev = VirtioPciDevice::new(
1405 Box::new(LegacyWrapper::new(
1406 &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())),
1407 TestDevice::new(
1408 DeviceTraits {
1409 device_id: 3,
1410 device_features: 2,
1411 max_queues: num_queues,
1412 device_register_length: 12,
1413 ..Default::default()
1414 },
1415 queue_work,
1416 ),
1417 &mem,
1418 )),
1419 PciInterruptModel::Msix(&mut msi_set),
1420 Some(doorbell_registration),
1421 &mut ExternallyManagedMmioIntercepts,
1422 None,
1423 )
1424 .unwrap();
1425
1426 let test_intc = Arc::new(TestPciInterruptController::new());
1427 msi_set.connect(test_intc.as_ref());
1428
1429 Self {
1430 pci_device: dev,
1431 test_intc,
1432 }
1433 }
1434
1435 fn read_u32(&mut self, address: u64) -> u32 {
1436 let mut value = [0; 4];
1437 self.pci_device.mmio_read(address, &mut value).unwrap();
1438 u32::from_ne_bytes(value)
1439 }
1440
1441 fn write_u32(&mut self, address: u64, value: u32) {
1442 self.pci_device
1443 .mmio_write(address, &value.to_ne_bytes())
1444 .unwrap();
1445 }
1446 }
1447
    // Walks the register window of a virtio MMIO device that exposes one queue and checks
    // the reset values and read-back behaviour of each register.
    //
    // NOTE(review): the register names in the comments below follow the virtio-mmio
    // register layout from the virtio specification; the code itself only uses numeric
    // offsets, so confirm the names against the device implementation.
    #[async_test]
    async fn verify_chipset_config(driver: DefaultDriver) {
        let mem = VirtioTestMemoryAccess::new();
        let doorbell_registration: Arc<dyn DoorbellRegistration> = mem.clone();
        let mem = GuestMemory::new("test", mem);
        let interrupt = LineInterrupt::detached();

        let mut dev = VirtioMmioDevice::new(
            Box::new(LegacyWrapper::new(
                &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
                TestDevice::new(
                    DeviceTraits {
                        device_id: 3,
                        device_features: 2,
                        max_queues: 1,
                        device_register_length: 0,
                        ..Default::default()
                    },
                    None,
                ),
                &mem,
            )),
            interrupt,
            Some(doorbell_registration),
            0,
            1,
        );
        // Magic value ("virt").
        assert_eq!(dev.read_u32(0), u32::from_le_bytes(*b"virt"));
        // Version.
        assert_eq!(dev.read_u32(4), 2);
        // Device id, matching `device_id` in the traits above.
        assert_eq!(dev.read_u32(8), 3);
        // Vendor id.
        assert_eq!(dev.read_u32(12), 0x1af4);
        // Device features, bank 0: the device's own bit plus the ring features.
        assert_eq!(
            dev.read_u32(16),
            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
        );
        // Device feature bank select: bank 1 reports VERSION_1, bank 2 reports nothing.
        assert_eq!(dev.read_u32(20), 0);
        dev.write_u32(20, 1);
        assert_eq!(dev.read_u32(20), 1);
        assert_eq!(dev.read_u32(16), VIRTIO_F_VERSION_1);
        dev.write_u32(20, 2);
        assert_eq!(dev.read_u32(16), 0);
        // Driver features, bank 0: writes are masked to the bits the device offers.
        assert_eq!(dev.read_u32(32), 0);
        dev.write_u32(32, 2);
        assert_eq!(dev.read_u32(32), 2);
        dev.write_u32(32, 0xffffffff);
        assert_eq!(
            dev.read_u32(32),
            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
        );
        // Driver feature bank select: bank 1 accepts only VERSION_1, bank 2 accepts nothing.
        assert_eq!(dev.read_u32(36), 0);
        dev.write_u32(36, 1);
        assert_eq!(dev.read_u32(36), 1);
        assert_eq!(dev.read_u32(32), 0);
        dev.write_u32(32, 0xffffffff);
        assert_eq!(dev.read_u32(32), VIRTIO_F_VERSION_1);
        dev.write_u32(36, 2);
        assert_eq!(dev.read_u32(32), 0);
        dev.write_u32(32, 0xffffffff);
        assert_eq!(dev.read_u32(32), 0);
        // Queue notify, interrupt status, interrupt acknowledge, device status and config
        // generation all read as zero on a freshly created device.
        assert_eq!(dev.read_u32(80), 0);
        assert_eq!(dev.read_u32(96), 0);
        assert_eq!(dev.read_u32(100), 0);
        assert_eq!(dev.read_u32(112), 0);
        assert_eq!(dev.read_u32(0xfc), 0);

        // Queue 0 (selected by default): the per-queue registers are readable and writable.
        // Queue select.
        assert_eq!(dev.read_u32(48), 0);
        // Maximum queue size.
        assert_eq!(dev.read_u32(52), 0x40);
        // Current queue size.
        assert_eq!(dev.read_u32(56), 0x40);
        dev.write_u32(56, 0x20);
        assert_eq!(dev.read_u32(56), 0x20);
        // Queue ready: only bit 0 is retained.
        assert_eq!(dev.read_u32(68), 0);
        dev.write_u32(68, 1);
        assert_eq!(dev.read_u32(68), 1);
        dev.write_u32(68, 0xffffffff);
        assert_eq!(dev.read_u32(68), 1);
        dev.write_u32(68, 0);
        assert_eq!(dev.read_u32(68), 0);
        // Descriptor area address, low then high dword.
        assert_eq!(dev.read_u32(128), 0);
        dev.write_u32(128, 0xffff);
        assert_eq!(dev.read_u32(128), 0xffff);
        assert_eq!(dev.read_u32(132), 0);
        dev.write_u32(132, 1);
        assert_eq!(dev.read_u32(132), 1);
        // Available (driver) area address, low then high dword.
        assert_eq!(dev.read_u32(144), 0);
        dev.write_u32(144, 0xeeee);
        assert_eq!(dev.read_u32(144), 0xeeee);
        assert_eq!(dev.read_u32(148), 0);
        dev.write_u32(148, 2);
        assert_eq!(dev.read_u32(148), 2);
        // Used (device) area address, low then high dword.
        assert_eq!(dev.read_u32(160), 0);
        dev.write_u32(160, 0xdddd);
        assert_eq!(dev.read_u32(160), 0xdddd);
        assert_eq!(dev.read_u32(164), 0);
        dev.write_u32(164, 3);
        assert_eq!(dev.read_u32(164), 3);

        // Queue 1 does not exist (`max_queues` is 1): after selecting it every per-queue
        // register reads as zero and ignores writes.
        dev.write_u32(48, 1);
        assert_eq!(dev.read_u32(48), 1);
        assert_eq!(dev.read_u32(52), 0);
        assert_eq!(dev.read_u32(56), 0);
        dev.write_u32(56, 2);
        assert_eq!(dev.read_u32(56), 0);
        assert_eq!(dev.read_u32(68), 0);
        dev.write_u32(68, 1);
        assert_eq!(dev.read_u32(68), 0);
        assert_eq!(dev.read_u32(128), 0);
        dev.write_u32(128, 1);
        assert_eq!(dev.read_u32(128), 0);
        assert_eq!(dev.read_u32(132), 0);
        dev.write_u32(132, 1);
        assert_eq!(dev.read_u32(132), 0);
        assert_eq!(dev.read_u32(144), 0);
        dev.write_u32(144, 1);
        assert_eq!(dev.read_u32(144), 0);
        assert_eq!(dev.read_u32(148), 0);
        dev.write_u32(148, 1);
        assert_eq!(dev.read_u32(148), 0);
        assert_eq!(dev.read_u32(160), 0);
        dev.write_u32(160, 1);
        assert_eq!(dev.read_u32(160), 0);
        assert_eq!(dev.read_u32(164), 0);
        dev.write_u32(164, 1);
        assert_eq!(dev.read_u32(164), 0);
    }
1609
1610 #[async_test]
1611 async fn verify_pci_config(driver: DefaultDriver) {
1612 let mut pci_test_device =
1613 VirtioPciTestDevice::new(&driver, 1, &VirtioTestMemoryAccess::new(), None);
1614 let mut capabilities = 0;
1615 pci_test_device
1616 .pci_device
1617 .pci_cfg_read(4, &mut capabilities)
1618 .unwrap();
1619 assert_eq!(
1620 capabilities,
1621 (cfg_space::Status::new()
1622 .with_capabilities_list(true)
1623 .into_bits() as u32)
1624 << 16
1625 );
1626 let mut next_cap_offset = 0;
1627 pci_test_device
1628 .pci_device
1629 .pci_cfg_read(0x34, &mut next_cap_offset)
1630 .unwrap();
1631 assert_ne!(next_cap_offset, 0);
1632
1633 let mut header = 0;
1634 pci_test_device
1635 .pci_device
1636 .pci_cfg_read(next_cap_offset as u16, &mut header)
1637 .unwrap();
1638 let header = header.to_le_bytes();
1639 assert_eq!(header[0], CapabilityId::MSIX.0);
1640 next_cap_offset = header[1] as u32;
1641 assert_ne!(next_cap_offset, 0);
1642
1643 let mut header = 0;
1644 pci_test_device
1645 .pci_device
1646 .pci_cfg_read(next_cap_offset as u16, &mut header)
1647 .unwrap();
1648 let header = header.to_le_bytes();
1649 assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1650 assert_eq!(header[3], VIRTIO_PCI_CAP_COMMON_CFG);
1651 assert_eq!(header[2], 16);
1652 let mut buf = 0;
1653
1654 pci_test_device
1655 .pci_device
1656 .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1657 .unwrap();
1658 assert_eq!(buf, 0);
1659 pci_test_device
1660 .pci_device
1661 .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1662 .unwrap();
1663 assert_eq!(buf, 0);
1664 pci_test_device
1665 .pci_device
1666 .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1667 .unwrap();
1668 assert_eq!(buf, 0x38);
1669 next_cap_offset = header[1] as u32;
1670 assert_ne!(next_cap_offset, 0);
1671
1672 let mut header = 0;
1673 pci_test_device
1674 .pci_device
1675 .pci_cfg_read(next_cap_offset as u16, &mut header)
1676 .unwrap();
1677 let header = header.to_le_bytes();
1678 assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1679 assert_eq!(header[3], VIRTIO_PCI_CAP_NOTIFY_CFG);
1680 assert_eq!(header[2], 20);
1681 pci_test_device
1682 .pci_device
1683 .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1684 .unwrap();
1685 assert_eq!(buf, 0);
1686 pci_test_device
1687 .pci_device
1688 .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1689 .unwrap();
1690 assert_eq!(buf, 0x38);
1691 pci_test_device
1692 .pci_device
1693 .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1694 .unwrap();
1695 assert_eq!(buf, 4);
1696 next_cap_offset = header[1] as u32;
1697 assert_ne!(next_cap_offset, 0);
1698
1699 let mut header = 0;
1700 pci_test_device
1701 .pci_device
1702 .pci_cfg_read(next_cap_offset as u16, &mut header)
1703 .unwrap();
1704 let header = header.to_le_bytes();
1705 assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1706 assert_eq!(header[3], VIRTIO_PCI_CAP_ISR_CFG);
1707 assert_eq!(header[2], 16);
1708 pci_test_device
1709 .pci_device
1710 .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1711 .unwrap();
1712 assert_eq!(buf, 0);
1713 pci_test_device
1714 .pci_device
1715 .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1716 .unwrap();
1717 assert_eq!(buf, 0x3c);
1718 pci_test_device
1719 .pci_device
1720 .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1721 .unwrap();
1722 assert_eq!(buf, 4);
1723 next_cap_offset = header[1] as u32;
1724 assert_ne!(next_cap_offset, 0);
1725
1726 let mut header = 0;
1727 pci_test_device
1728 .pci_device
1729 .pci_cfg_read(next_cap_offset as u16, &mut header)
1730 .unwrap();
1731 let header = header.to_le_bytes();
1732 assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1733 assert_eq!(header[3], VIRTIO_PCI_CAP_DEVICE_CFG);
1734 assert_eq!(header[2], 16);
1735 pci_test_device
1736 .pci_device
1737 .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1738 .unwrap();
1739 assert_eq!(buf, 0);
1740 pci_test_device
1741 .pci_device
1742 .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1743 .unwrap();
1744 assert_eq!(buf, 0x40);
1745 pci_test_device
1746 .pci_device
1747 .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1748 .unwrap();
1749 assert_eq!(buf, 12);
1750 next_cap_offset = header[1] as u32;
1751 assert_eq!(next_cap_offset, 0);
1752 }
1753
    // Programs the BARs of a one-queue virtio PCI device, enables MMIO decoding and then
    // checks the reset values and read-back behaviour of the registers in the first BAR.
    //
    // NOTE(review): the register names in the comments below follow the virtio PCI common
    // configuration layout from the virtio specification; the code itself only uses
    // numeric offsets, so confirm the names against the device implementation.
    #[async_test]
    async fn verify_pci_registers(driver: DefaultDriver) {
        let mut pci_test_device =
            VirtioPciTestDevice::new(&driver, 1, &VirtioTestMemoryAccess::new(), None);
        // 64-bit BAR at config offsets 0x10 (low dword) and 0x14 (high dword).
        let bar_address1: u64 = 0x2000000000;
        pci_test_device
            .pci_device
            .pci_cfg_write(0x14, (bar_address1 >> 32) as u32)
            .unwrap();
        pci_test_device
            .pci_device
            .pci_cfg_write(0x10, bar_address1 as u32)
            .unwrap();

        // Second 64-bit BAR at config offsets 0x18 and 0x1c; it is programmed but not
        // accessed again in this test.
        let bar_address2: u64 = 0x4000;
        pci_test_device
            .pci_device
            .pci_cfg_write(0x1c, (bar_address2 >> 32) as u32)
            .unwrap();
        pci_test_device
            .pci_device
            .pci_cfg_write(0x18, bar_address2 as u32)
            .unwrap();

        // Command register: enable MMIO so that accesses to the BARs reach the device.
        pci_test_device
            .pci_device
            .pci_cfg_write(
                0x4,
                cfg_space::Command::new()
                    .with_mmio_enabled(true)
                    .into_bits() as u32,
            )
            .unwrap();

        // +0 device feature bank select, +4 device features of the selected bank.
        assert_eq!(pci_test_device.read_u32(bar_address1), 0);
        assert_eq!(
            pci_test_device.read_u32(bar_address1 + 4),
            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
        );
        pci_test_device.write_u32(bar_address1, 1);
        assert_eq!(pci_test_device.read_u32(bar_address1), 1);
        assert_eq!(
            pci_test_device.read_u32(bar_address1 + 4),
            VIRTIO_F_VERSION_1
        );
        pci_test_device.write_u32(bar_address1, 2);
        assert_eq!(pci_test_device.read_u32(bar_address1), 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 4), 0);
        // +8 driver feature bank select, +12 driver features of the selected bank; writes
        // are masked to the bits the device offers in that bank.
        assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 0);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
        pci_test_device.write_u32(bar_address1 + 12, 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 2);
        pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
        assert_eq!(
            pci_test_device.read_u32(bar_address1 + 12),
            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
        );
        pci_test_device.write_u32(bar_address1 + 8, 1);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 1);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
        pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
        assert_eq!(
            pci_test_device.read_u32(bar_address1 + 12),
            VIRTIO_F_VERSION_1
        );
        pci_test_device.write_u32(bar_address1 + 8, 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
        pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
        // +16: MSI-X config vector (low half) and queue count (high half, one queue).
        assert_eq!(pci_test_device.read_u32(bar_address1 + 16), 1 << 16);
        // +20: device status, config generation and queue select.
        assert_eq!(pci_test_device.read_u32(bar_address1 + 20), 0);
        // +24: queue size (and queue MSI-X vector in the high half).
        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0x40);
        pci_test_device.write_u32(bar_address1 + 24, 0x20);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0x20);
        // +28: queue enable; only bit 0 is retained.
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
        pci_test_device.write_u32(bar_address1 + 28, 1);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 1);
        pci_test_device.write_u32(bar_address1 + 28, 0xffff);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 1);
        pci_test_device.write_u32(bar_address1 + 28, 0);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
        // +32/+36: descriptor area address, low then high dword.
        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
        pci_test_device.write_u32(bar_address1 + 32, 0xffff);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0xffff);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
        pci_test_device.write_u32(bar_address1 + 36, 1);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 1);
        // +40/+44: available (driver) area address, low then high dword.
        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
        pci_test_device.write_u32(bar_address1 + 40, 0xeeee);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0xeeee);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
        pci_test_device.write_u32(bar_address1 + 44, 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 2);
        // +48/+52: used (device) area address, low then high dword.
        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
        pci_test_device.write_u32(bar_address1 + 48, 0xdddd);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0xdddd);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
        pci_test_device.write_u32(bar_address1 + 52, 3);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 3);
        // +56 (0x38) and +60 (0x3c) are the notify and ISR regions advertised by the
        // capabilities checked in `verify_pci_config`; both read as zero.
        assert_eq!(pci_test_device.read_u32(bar_address1 + 56), 0);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 60), 0);

        // Select queue 1, which does not exist on a one-queue device: every per-queue
        // register must read as zero and ignore writes.
        let queue_index: u16 = 1;
        pci_test_device
            .pci_device
            .mmio_write(bar_address1 + 22, &queue_index.to_le_bytes())
            .unwrap();
        // NOTE(review): a little-endian u16 of 1 at +22 would normally appear as 1 << 16
        // in the dword at +20; the test expects 1 << 24 — confirm against the device's
        // dword read implementation.
        assert_eq!(pci_test_device.read_u32(bar_address1 + 20), 1 << 24);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0);
        pci_test_device.write_u32(bar_address1 + 24, 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
        pci_test_device.write_u32(bar_address1 + 28, 1);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
        pci_test_device.write_u32(bar_address1 + 32, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
        pci_test_device.write_u32(bar_address1 + 36, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
        pci_test_device.write_u32(bar_address1 + 40, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
        pci_test_device.write_u32(bar_address1 + 44, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
        pci_test_device.write_u32(bar_address1 + 48, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
        pci_test_device.write_u32(bar_address1 + 52, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
    }
1917
1918 #[async_test]
1919 async fn verify_queue_simple(driver: DefaultDriver) {
1920 let test_mem = VirtioTestMemoryAccess::new();
1921 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
1922 let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
1923 let (tx, mut rx) = mesh::mpsc_channel();
1924 let event = Event::new();
1925 let mut queues = guest.create_direct_queues(|i| {
1926 let tx = tx.clone();
1927 CreateDirectQueueParams {
1928 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
1929 let mut work = work.expect("Queue failure");
1930 assert_eq!(work.payload.len(), 1);
1931 assert_eq!(work.payload[0].address, base_addr);
1932 assert_eq!(work.payload[0].length, 0x1000);
1933 work.complete(123);
1934 true
1935 }),
1936 notify: Interrupt::from_fn(move || {
1937 tx.send(i as usize);
1938 }),
1939 event: event.clone(),
1940 }
1941 });
1942
1943 guest.add_to_avail_queue(0);
1944 event.signal();
1945 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
1946 let (desc, len) = guest.get_next_completed(0).unwrap();
1947 assert_eq!(desc, 0u16);
1948 assert_eq!(len, 123);
1949 assert_eq!(guest.get_next_completed(0).is_none(), true);
1950 queues[0].stop().await;
1951 }
1952
1953 #[async_test]
1954 async fn verify_queue_indirect(driver: DefaultDriver) {
1955 let test_mem = VirtioTestMemoryAccess::new();
1956 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
1957 let (tx, mut rx) = mesh::mpsc_channel();
1958 let event = Event::new();
1959 let mut queues = guest.create_direct_queues(|i| {
1960 let tx = tx.clone();
1961 CreateDirectQueueParams {
1962 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
1963 let mut work = work.expect("Queue failure");
1964 assert_eq!(work.payload.len(), 1);
1965 assert_eq!(work.payload[0].address, 0xffffffff00000000u64);
1966 assert_eq!(work.payload[0].length, 0x1000);
1967 work.complete(123);
1968 true
1969 }),
1970 notify: Interrupt::from_fn(move || {
1971 tx.send(i as usize);
1972 }),
1973 event: event.clone(),
1974 }
1975 });
1976
1977 guest.add_indirect_to_avail_queue(0);
1978 event.signal();
1979 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
1980 let (desc, len) = guest.get_next_completed(0).unwrap();
1981 assert_eq!(desc, 0u16);
1982 assert_eq!(len, 123);
1983 assert_eq!(guest.get_next_completed(0).is_none(), true);
1984 queues[0].stop().await;
1985 }
1986
1987 #[async_test]
1988 async fn verify_queue_linked(driver: DefaultDriver) {
1989 let test_mem = VirtioTestMemoryAccess::new();
1990 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 5, true);
1991 let (tx, mut rx) = mesh::mpsc_channel();
1992 let base_address = guest.get_queue_descriptor_backing_memory_address(0);
1993 let event = Event::new();
1994 let mut queues = guest.create_direct_queues(|i| {
1995 let tx = tx.clone();
1996 CreateDirectQueueParams {
1997 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
1998 let mut work = work.expect("Queue failure");
1999 assert_eq!(work.payload.len(), 3);
2000 for i in 0..work.payload.len() {
2001 assert_eq!(work.payload[i].address, base_address + 0x1000 * i as u64);
2002 assert_eq!(work.payload[i].length, 0x1000);
2003 }
2004 work.complete(123 * 3);
2005 true
2006 }),
2007 notify: Interrupt::from_fn(move || {
2008 tx.send(i as usize);
2009 }),
2010 event: event.clone(),
2011 }
2012 });
2013
2014 guest.add_linked_to_avail_queue(0, 3);
2015 event.signal();
2016 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2017 let (desc, len) = guest.get_next_completed(0).unwrap();
2018 assert_eq!(desc, 0u16);
2019 assert_eq!(len, 123 * 3);
2020 assert_eq!(guest.get_next_completed(0).is_none(), true);
2021 queues[0].stop().await;
2022 }
2023
2024 #[async_test]
2025 async fn verify_queue_indirect_linked(driver: DefaultDriver) {
2026 let test_mem = VirtioTestMemoryAccess::new();
2027 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 5, true);
2028 let (tx, mut rx) = mesh::mpsc_channel();
2029 let event = Event::new();
2030 let mut queues = guest.create_direct_queues(|i| {
2031 let tx = tx.clone();
2032 CreateDirectQueueParams {
2033 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
2034 let mut work = work.expect("Queue failure");
2035 assert_eq!(work.payload.len(), 3);
2036 for i in 0..work.payload.len() {
2037 assert_eq!(
2038 work.payload[i].address,
2039 0xffffffff00000000u64 + 0x1000 * i as u64
2040 );
2041 assert_eq!(work.payload[i].length, 0x1000);
2042 }
2043 work.complete(123 * 3);
2044 true
2045 }),
2046 notify: Interrupt::from_fn(move || {
2047 tx.send(i as usize);
2048 }),
2049 event: event.clone(),
2050 }
2051 });
2052
2053 guest.add_indirect_linked_to_avail_queue(0, 3);
2054 event.signal();
2055 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2056 let (desc, len) = guest.get_next_completed(0).unwrap();
2057 assert_eq!(desc, 0u16);
2058 assert_eq!(len, 123 * 3);
2059 assert_eq!(guest.get_next_completed(0).is_none(), true);
2060 queues[0].stop().await;
2061 }
2062
2063 #[async_test]
2064 async fn verify_queue_avail_rollover(driver: DefaultDriver) {
2065 let test_mem = VirtioTestMemoryAccess::new();
2066 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
2067 let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
2068 let (tx, mut rx) = mesh::mpsc_channel();
2069 let event = Event::new();
2070 let mut queues = guest.create_direct_queues(|i| {
2071 let tx = tx.clone();
2072 CreateDirectQueueParams {
2073 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
2074 let mut work = work.expect("Queue failure");
2075 assert_eq!(work.payload.len(), 1);
2076 assert_eq!(work.payload[0].address, base_addr);
2077 assert_eq!(work.payload[0].length, 0x1000);
2078 work.complete(123);
2079 true
2080 }),
2081 notify: Interrupt::from_fn(move || {
2082 tx.send(i as usize);
2083 }),
2084 event: event.clone(),
2085 }
2086 });
2087
2088 for _ in 0..3 {
2089 guest.add_to_avail_queue(0);
2090 event.signal();
2091 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2092 let (desc, len) = guest.get_next_completed(0).unwrap();
2093 assert_eq!(desc, 0u16);
2094 assert_eq!(len, 123);
2095 assert_eq!(guest.get_next_completed(0).is_none(), true);
2096 }
2097
2098 queues[0].stop().await;
2099 }
2100
2101 #[async_test]
2102 async fn verify_multi_queue(driver: DefaultDriver) {
2103 let test_mem = VirtioTestMemoryAccess::new();
2104 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 5, 2, true);
2105 let (tx, mut rx) = mesh::mpsc_channel();
2106 let events = (0..guest.num_queues)
2107 .map(|_| Event::new())
2108 .collect::<Vec<_>>();
2109 let mut queues = guest.create_direct_queues(|queue_index| {
2110 let tx = tx.clone();
2111 let base_addr = guest.get_queue_descriptor_backing_memory_address(queue_index);
2112 CreateDirectQueueParams {
2113 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
2114 let mut work = work.expect("Queue failure");
2115 assert_eq!(work.payload.len(), 1);
2116 assert_eq!(work.payload[0].address, base_addr);
2117 assert_eq!(work.payload[0].length, 0x1000);
2118 work.complete(123 * queue_index as u32);
2119 true
2120 }),
2121 notify: Interrupt::from_fn(move || {
2122 tx.send(queue_index as usize);
2123 }),
2124 event: events[queue_index as usize].clone(),
2125 }
2126 });
2127
2128 for (i, event) in events.iter().enumerate() {
2129 let queue_index = i as u16;
2130 guest.add_to_avail_queue(queue_index);
2131 event.signal();
2132 }
2133 for _ in 0..guest.num_queues {
2135 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2136 }
2137 for queue_index in 0..guest.num_queues {
2139 let (desc, len) = guest.get_next_completed(queue_index).unwrap();
2140 assert_eq!(desc, 0u16);
2141 assert_eq!(len, 123 * queue_index as u32);
2142 }
2143 for (i, queue) in queues.iter_mut().enumerate() {
2145 let queue_index = i as u16;
2146 assert_eq!(guest.get_next_completed(queue_index).is_none(), true);
2147 queue.stop().await;
2148 }
2149 }
2150
2151 fn take_mmio_interrupt_status(dev: &mut VirtioMmioDevice, mask: u32) -> u32 {
2152 let mut v = [0; 4];
2153 dev.mmio_read(96, &mut v).unwrap();
2154 dev.mmio_write(100, &mask.to_ne_bytes()).unwrap();
2155 u32::from_ne_bytes(v)
2156 }
2157
2158 async fn expect_mmio_interrupt(
2159 dev: &mut VirtioMmioDevice,
2160 target: &TestLineInterruptTarget,
2161 mask: u32,
2162 multiple_expected: bool,
2163 ) {
2164 poll_fn(|cx| target.poll_high(cx, 0)).await;
2165 let v = take_mmio_interrupt_status(dev, mask);
2166 assert_eq!(v & mask, mask);
2167 assert!(multiple_expected || !target.is_high(0));
2168 }
2169
2170 #[async_test]
2171 async fn verify_device_queue_simple(driver: DefaultDriver) {
2172 let test_mem = VirtioTestMemoryAccess::new();
2173 let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
2174 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
2175 let mem = guest.mem();
2176 let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
2177 let target = TestLineInterruptTarget::new_arc();
2178 let interrupt = LineInterrupt::new_with_target("test", target.clone(), 0);
2179 let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
2180 let queue_work = Arc::new(move |_: u16, mut work: VirtioQueueCallbackWork| {
2181 assert_eq!(work.payload.len(), 1);
2182 assert_eq!(work.payload[0].address, base_addr);
2183 assert_eq!(work.payload[0].length, 0x1000);
2184 work.complete(123);
2185 });
2186 let mut dev = VirtioMmioDevice::new(
2187 Box::new(LegacyWrapper::new(
2188 &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
2189 TestDevice::new(
2190 DeviceTraits {
2191 device_id: 3,
2192 device_features: features,
2193 max_queues: 1,
2194 device_register_length: 0,
2195 ..Default::default()
2196 },
2197 Some(queue_work),
2198 ),
2199 &mem,
2200 )),
2201 interrupt,
2202 Some(doorbell_registration),
2203 0,
2204 1,
2205 );
2206
2207 guest.setup_chipset_device(&mut dev, features);
2208 expect_mmio_interrupt(
2209 &mut dev,
2210 &target,
2211 VIRTIO_MMIO_INTERRUPT_STATUS_CONFIG_CHANGE,
2212 false,
2213 )
2214 .await;
2215 guest.add_to_avail_queue(0);
2216 dev.write_u32(80, 0);
2218 expect_mmio_interrupt(
2219 &mut dev,
2220 &target,
2221 VIRTIO_MMIO_INTERRUPT_STATUS_USED_BUFFER,
2222 false,
2223 )
2224 .await;
2225 let (desc, len) = guest.get_next_completed(0).unwrap();
2226 assert_eq!(desc, 0u16);
2227 assert_eq!(len, 123);
2228 assert_eq!(guest.get_next_completed(0).is_none(), true);
2229 dev.write_u32(112, 0);
2231 drop(dev);
2232 }
2233
2234 #[async_test]
2235 async fn verify_device_multi_queue(driver: DefaultDriver) {
2236 let num_queues = 5;
2237 let test_mem = VirtioTestMemoryAccess::new();
2238 let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
2239 let mut guest = VirtioTestGuest::new(&driver, &test_mem, num_queues, 2, true);
2240 let mem = guest.mem();
2241 let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
2242 let target = TestLineInterruptTarget::new_arc();
2243 let interrupt = LineInterrupt::new_with_target("test", target.clone(), 0);
2244 let base_addr: Vec<_> = (0..num_queues)
2245 .map(|i| guest.get_queue_descriptor_backing_memory_address(i))
2246 .collect();
2247 let queue_work = Arc::new(move |i: u16, mut work: VirtioQueueCallbackWork| {
2248 assert_eq!(work.payload.len(), 1);
2249 assert_eq!(work.payload[0].address, base_addr[i as usize]);
2250 assert_eq!(work.payload[0].length, 0x1000);
2251 work.complete(123 * i as u32);
2252 });
2253 let mut dev = VirtioMmioDevice::new(
2254 Box::new(LegacyWrapper::new(
2255 &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
2256 TestDevice::new(
2257 DeviceTraits {
2258 device_id: 3,
2259 device_features: features,
2260 max_queues: num_queues + 1,
2261 device_register_length: 0,
2262 ..Default::default()
2263 },
2264 Some(queue_work),
2265 ),
2266 &mem,
2267 )),
2268 interrupt,
2269 Some(doorbell_registration),
2270 0,
2271 1,
2272 );
2273 guest.setup_chipset_device(&mut dev, features);
2274 expect_mmio_interrupt(
2275 &mut dev,
2276 &target,
2277 VIRTIO_MMIO_INTERRUPT_STATUS_CONFIG_CHANGE,
2278 false,
2279 )
2280 .await;
2281 for i in 0..num_queues {
2282 guest.add_to_avail_queue(i);
2283 dev.write_u32(80, i as u32);
2285 }
2286 for i in 0..num_queues {
2288 let (desc, len) = loop {
2289 if let Some(x) = guest.get_next_completed(i) {
2290 break x;
2291 }
2292 expect_mmio_interrupt(
2293 &mut dev,
2294 &target,
2295 VIRTIO_MMIO_INTERRUPT_STATUS_USED_BUFFER,
2296 i < (num_queues - 1),
2297 )
2298 .await;
2299 };
2300 assert_eq!(desc, 0u16);
2301 assert_eq!(len, 123 * i as u32);
2302 }
2303 for i in 0..num_queues {
2305 assert_eq!(guest.get_next_completed(i).is_none(), true);
2306 }
2307 dev.write_u32(112, 0);
2309 drop(dev);
2310 }
2311
2312 #[async_test]
2313 async fn verify_device_multi_queue_pci(driver: DefaultDriver) {
2314 let num_queues = 5;
2315 let test_mem = VirtioTestMemoryAccess::new();
2316 let mut guest = VirtioTestGuest::new(&driver, &test_mem, num_queues, 2, true);
2317 let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
2318 let base_addr: Vec<_> = (0..num_queues)
2319 .map(|i| guest.get_queue_descriptor_backing_memory_address(i))
2320 .collect();
2321 let mut dev = VirtioPciTestDevice::new(
2322 &driver,
2323 num_queues + 1,
2324 &test_mem,
2325 Some(Arc::new(move |i, mut work| {
2326 assert_eq!(work.payload.len(), 1);
2327 assert_eq!(work.payload[0].address, base_addr[i as usize]);
2328 assert_eq!(work.payload[0].length, 0x1000);
2329 work.complete(123 * i as u32);
2330 })),
2331 );
2332
2333 guest.setup_pci_device(&mut dev, features);
2334
2335 let mut timer = PolledTimer::new(&driver);
2336
2337 timer.sleep(Duration::from_millis(100)).await;
2339 let delivered = dev.test_intc.get_next_interrupt().unwrap();
2340 assert_eq!(delivered.0, 0);
2341 assert!(dev.test_intc.get_next_interrupt().is_none());
2342
2343 for i in 0..num_queues {
2344 guest.add_to_avail_queue(i);
2345 dev.write_u32(0x10000000000 + 0x38, i as u32);
2347 }
2348 timer.sleep(Duration::from_millis(100)).await;
2350 for _ in 0..num_queues {
2351 let delivered = dev.test_intc.get_next_interrupt();
2352 assert!(delivered.is_some());
2353 }
2354 for i in 0..num_queues {
2356 let (desc, len) = guest.get_next_completed(i).unwrap();
2357 assert_eq!(desc, 0u16);
2358 assert_eq!(len, 123 * i as u32);
2359 }
2360 for i in 0..num_queues {
2362 assert_eq!(guest.get_next_completed(i).is_none(), true);
2363 }
2364 let device_status: u8 = 0;
2366 dev.pci_device
2367 .mmio_write(0x10000000000 + 20, &device_status.to_le_bytes())
2368 .unwrap();
2369 drop(dev);
2370 }
2371}