1use crate::queue::QueueCore;
5use crate::queue::QueueError;
6use crate::queue::QueueParams;
7use crate::queue::VirtioQueuePayload;
8use async_trait::async_trait;
9use futures::FutureExt;
10use futures::Stream;
11use futures::StreamExt;
12use guestmem::DoorbellRegistration;
13use guestmem::GuestMemory;
14use guestmem::GuestMemoryError;
15use guestmem::MappedMemoryRegion;
16use pal_async::DefaultPool;
17use pal_async::driver::Driver;
18use pal_async::task::Spawn;
19use pal_async::wait::PolledWait;
20use pal_event::Event;
21use parking_lot::Mutex;
22use std::io::Error;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::Context;
26use std::task::Poll;
27use std::task::ready;
28use task_control::AsyncRun;
29use task_control::StopTask;
30use task_control::TaskControl;
31use thiserror::Error;
32use vmcore::interrupt::Interrupt;
33use vmcore::vm_task::VmTaskDriver;
34use vmcore::vm_task::VmTaskDriverSource;
35
/// Per-queue callback driven by a [`VirtioQueueWorker`]: receives every item
/// (or fetch error) pulled from the worker's virtio queue.
#[async_trait]
pub trait VirtioQueueWorkerContext {
    /// Handles one unit of queue work, or the error produced while fetching it.
    ///
    /// Returning `true` keeps the worker loop running; returning `false` ends it.
    async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool;
}
40
41#[derive(Debug)]
42pub struct VirtioQueueUsedHandler {
43 core: QueueCore,
44 last_used_index: u16,
45 outstanding_desc_count: Arc<Mutex<(u16, event_listener::Event)>>,
46 notify_guest: Interrupt,
47}
48
49impl VirtioQueueUsedHandler {
50 fn new(core: QueueCore, notify_guest: Interrupt) -> Self {
51 Self {
52 core,
53 last_used_index: 0,
54 outstanding_desc_count: Arc::new(Mutex::new((0, event_listener::Event::new()))),
55 notify_guest,
56 }
57 }
58
59 pub fn add_outstanding_descriptor(&self) {
60 let (count, _) = &mut *self.outstanding_desc_count.lock();
61 *count += 1;
62 }
63
64 pub fn await_outstanding_descriptors(&self) -> event_listener::EventListener {
65 let (count, event) = &*self.outstanding_desc_count.lock();
66 let listener = event.listen();
67 if *count == 0 {
68 event.notify(usize::MAX);
69 }
70 listener
71 }
72
73 pub fn complete_descriptor(&mut self, descriptor_index: u16, bytes_written: u32) {
74 match self.core.complete_descriptor(
75 &mut self.last_used_index,
76 descriptor_index,
77 bytes_written,
78 ) {
79 Ok(true) => {
80 self.notify_guest.deliver();
81 }
82 Ok(false) => {}
83 Err(err) => {
84 tracelimit::error_ratelimited!(
85 error = &err as &dyn std::error::Error,
86 "failed to complete descriptor"
87 );
88 }
89 }
90 {
91 let (count, event) = &mut *self.outstanding_desc_count.lock();
92 *count -= 1;
93 if *count == 0 {
94 event.notify(usize::MAX);
95 }
96 }
97 }
98}
99
/// One descriptor chain pulled from a virtio queue and handed to a worker.
///
/// The descriptor is returned to the used ring by `complete`. The `Drop` impl
/// completes it with 0 bytes written if the worker never did.
pub struct VirtioQueueCallbackWork {
    /// Buffers of the descriptor chain, as produced by `QueueCore::reader`.
    pub payload: Vec<VirtioQueuePayload>,
    /// Shared used-ring handler that this work item completes against.
    used_queue_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    /// Head index of the descriptor chain.
    descriptor_index: u16,
    /// Set once the descriptor has been written to the used ring.
    completed: bool,
}
106
107impl VirtioQueueCallbackWork {
108 pub fn new(
109 payload: Vec<VirtioQueuePayload>,
110 used_queue_handler: &Arc<Mutex<VirtioQueueUsedHandler>>,
111 descriptor_index: u16,
112 ) -> Self {
113 let used_queue_handler = used_queue_handler.clone();
114 used_queue_handler.lock().add_outstanding_descriptor();
115 Self {
116 payload,
117 used_queue_handler,
118 descriptor_index,
119 completed: false,
120 }
121 }
122
123 pub fn complete(&mut self, bytes_written: u32) {
124 assert!(!self.completed);
125 self.used_queue_handler
126 .lock()
127 .complete_descriptor(self.descriptor_index, bytes_written);
128 self.completed = true;
129 }
130
131 pub fn descriptor_index(&self) -> u16 {
132 self.descriptor_index
133 }
134
135 pub fn get_payload_length(&self, writeable: bool) -> u64 {
137 self.payload
138 .iter()
139 .filter(|x| x.writeable == writeable)
140 .fold(0, |acc, x| acc + x.length as u64)
141 }
142
143 pub fn read(&self, mem: &GuestMemory, target: &mut [u8]) -> Result<usize, GuestMemoryError> {
145 let mut remaining = target;
146 let mut read_bytes: usize = 0;
147 for payload in &self.payload {
148 if payload.writeable {
149 continue;
150 }
151
152 let size = std::cmp::min(payload.length as usize, remaining.len());
153 let (current, next) = remaining.split_at_mut(size);
154 mem.read_at(payload.address, current)?;
155 read_bytes += size;
156 if next.is_empty() {
157 break;
158 }
159
160 remaining = next;
161 }
162
163 Ok(read_bytes)
164 }
165
166 pub fn write_at_offset(
168 &self,
169 offset: u64,
170 mem: &GuestMemory,
171 source: &[u8],
172 ) -> Result<(), VirtioWriteError> {
173 let mut skip_bytes = offset;
174 let mut remaining = source;
175 for payload in &self.payload {
176 if !payload.writeable {
177 continue;
178 }
179
180 let payload_length = payload.length as u64;
181 if skip_bytes >= payload_length {
182 skip_bytes -= payload_length;
183 continue;
184 }
185
186 let size = std::cmp::min(
187 payload_length as usize - skip_bytes as usize,
188 remaining.len(),
189 );
190 let (current, next) = remaining.split_at(size);
191 mem.write_at(payload.address + skip_bytes, current)?;
192 remaining = next;
193 if remaining.is_empty() {
194 break;
195 }
196 skip_bytes = 0;
197 }
198
199 if !remaining.is_empty() {
200 return Err(VirtioWriteError::NotAllWritten(source.len()));
201 }
202
203 Ok(())
204 }
205
206 pub fn write(&self, mem: &GuestMemory, source: &[u8]) -> Result<(), VirtioWriteError> {
207 self.write_at_offset(0, mem, source)
208 }
209}
210
/// Error returned by `VirtioQueueCallbackWork::write` and
/// `VirtioQueueCallbackWork::write_at_offset`.
#[derive(Debug, Error)]
pub enum VirtioWriteError {
    /// Writing guest memory failed.
    #[error(transparent)]
    Memory(#[from] GuestMemoryError),
    /// The chain's writeable buffers were too small to hold all of the data.
    #[error("{0:#x} bytes not written")]
    NotAllWritten(usize),
}
218
219impl Drop for VirtioQueueCallbackWork {
220 fn drop(&mut self) {
221 if !self.completed {
222 self.complete(0);
223 }
224 }
225}
226
/// Device-side handle for one virtio queue. It pulls available descriptor
/// chains and hands them out as [`VirtioQueueCallbackWork`].
#[derive(Debug)]
pub struct VirtioQueue {
    /// Queue accessor for the available ring and descriptor table.
    core: QueueCore,
    /// Available-ring index of the next entry to consume.
    last_avail_index: u16,
    /// Shared with every outstanding work item so each can complete its descriptor.
    used_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    /// Waited on while the available ring is empty.
    queue_event: PolledWait<Event>,
}
234
235impl VirtioQueue {
236 pub fn new(
237 features: u64,
238 params: QueueParams,
239 mem: GuestMemory,
240 notify: Interrupt,
241 queue_event: PolledWait<Event>,
242 ) -> Result<Self, QueueError> {
243 let core = QueueCore::new(features, mem, params)?;
244 let used_handler = Arc::new(Mutex::new(VirtioQueueUsedHandler::new(
245 core.clone(),
246 notify,
247 )));
248 Ok(Self {
249 core,
250 last_avail_index: 0,
251 used_handler,
252 queue_event,
253 })
254 }
255
256 async fn wait_for_outstanding_descriptors(&self) {
257 let wait_for_descriptors = self.used_handler.lock().await_outstanding_descriptors();
258 wait_for_descriptors.await;
259 }
260
261 fn poll_next_buffer(
262 &mut self,
263 cx: &mut Context<'_>,
264 ) -> Poll<Result<Option<VirtioQueueCallbackWork>, QueueError>> {
265 let descriptor_index = loop {
266 if let Some(descriptor_index) = self.core.descriptor_index(self.last_avail_index)? {
267 break descriptor_index;
268 };
269 ready!(self.queue_event.wait().poll_unpin(cx)).expect("waits on Event cannot fail");
270 };
271 let payload = self
272 .core
273 .reader(descriptor_index)
274 .collect::<Result<Vec<_>, _>>()?;
275
276 self.last_avail_index = self.last_avail_index.wrapping_add(1);
277 Poll::Ready(Ok(Some(VirtioQueueCallbackWork::new(
278 payload,
279 &self.used_handler,
280 descriptor_index,
281 ))))
282 }
283}
284
285impl Drop for VirtioQueue {
286 fn drop(&mut self) {
287 if Arc::get_mut(&mut self.used_handler).is_none() {
288 tracing::error!("Virtio queue dropped with outstanding work pending")
289 }
290 }
291}
292
293impl Stream for VirtioQueue {
294 type Item = Result<VirtioQueueCallbackWork, Error>;
295
296 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
297 let Some(r) = ready!(self.get_mut().poll_next_buffer(cx)).transpose() else {
298 return Poll::Ready(None);
299 };
300
301 Poll::Ready(Some(r.map_err(Error::other)))
302 }
303}
304
/// Lifecycle of a queue worker's state.
enum VirtioQueueStateInner {
    /// Queue resources have been received, but the queue has not been created yet.
    Initializing {
        mem: GuestMemory,
        features: u64,
        params: QueueParams,
        event: Event,
        notify: Interrupt,
        /// Fires when the owning device asks all workers to exit.
        exit_event: event_listener::EventListener,
    },
    /// Placeholder while the `Initializing` fields are moved out. `run_queue`
    /// treats observing it as unreachable.
    InitializationInProgress,
    /// The queue exists and is being serviced.
    Running {
        queue: VirtioQueue,
        exit_event: event_listener::EventListener,
    },
}
320
/// Task state for a [`VirtioQueueWorker`], held by its `TaskControl`.
pub struct VirtioQueueState {
    inner: VirtioQueueStateInner,
}

/// Services one virtio queue, forwarding each work item to a per-queue context.
pub struct VirtioQueueWorker {
    /// Driver used to create the queue's event wait.
    driver: Box<dyn Driver>,
    /// Receives every work item pulled from the queue.
    context: Box<dyn VirtioQueueWorkerContext + Send>,
}
329
330impl VirtioQueueWorker {
331 pub fn new(driver: impl Driver, context: Box<dyn VirtioQueueWorkerContext + Send>) -> Self {
332 Self {
333 driver: Box::new(driver),
334 context,
335 }
336 }
337
338 pub fn into_running_task(
339 self,
340 name: impl Into<String>,
341 mem: GuestMemory,
342 features: u64,
343 queue_resources: QueueResources,
344 exit_event: event_listener::EventListener,
345 ) -> TaskControl<VirtioQueueWorker, VirtioQueueState> {
346 let name = name.into();
347 let (_, driver) = DefaultPool::spawn_on_thread(&name);
348
349 let mut task = TaskControl::new(self);
350 task.insert(
351 driver,
352 name,
353 VirtioQueueState {
354 inner: VirtioQueueStateInner::Initializing {
355 mem,
356 features,
357 params: queue_resources.params,
358 event: queue_resources.event,
359 notify: queue_resources.notify,
360 exit_event,
361 },
362 },
363 );
364 task.start();
365 task
366 }
367
368 async fn run_queue(&mut self, state: &mut VirtioQueueState) -> bool {
369 match &mut state.inner {
370 VirtioQueueStateInner::InitializationInProgress => unreachable!(),
371 VirtioQueueStateInner::Initializing { .. } => {
372 let VirtioQueueStateInner::Initializing {
373 mem,
374 features,
375 params,
376 event,
377 notify,
378 exit_event,
379 } = std::mem::replace(
380 &mut state.inner,
381 VirtioQueueStateInner::InitializationInProgress,
382 )
383 else {
384 unreachable!()
385 };
386 let queue_event = PolledWait::new(&self.driver, event).unwrap();
387 let queue = VirtioQueue::new(features, params, mem, notify, queue_event);
388 if let Err(err) = queue {
389 tracing::error!(
390 err = &err as &dyn std::error::Error,
391 "Failed to start queue"
392 );
393 false
394 } else {
395 state.inner = VirtioQueueStateInner::Running {
396 queue: queue.unwrap(),
397 exit_event,
398 };
399 true
400 }
401 }
402 VirtioQueueStateInner::Running { queue, exit_event } => {
403 let mut exit = exit_event.fuse();
404 let mut queue_ready = queue.next().fuse();
405 let work = futures::select_biased! {
406 _ = exit => return false,
407 work = queue_ready => work.expect("queue will never complete").map_err(anyhow::Error::from),
408 };
409 self.context.process_work(work).await
410 }
411 }
412 }
413}
414
415impl AsyncRun<VirtioQueueState> for VirtioQueueWorker {
416 async fn run(
417 &mut self,
418 stop: &mut StopTask<'_>,
419 state: &mut VirtioQueueState,
420 ) -> Result<(), task_control::Cancelled> {
421 while stop.until_stopped(self.run_queue(state)).await? {}
422 Ok(())
423 }
424}
425
/// Configuration passed to `LegacyVirtioDevice::state_change` when the device
/// starts running.
pub struct VirtioRunningState {
    /// Feature bits taken from `Resources::features`.
    pub features: u64,
    /// Enable flag of each queue, indexed by queue number.
    pub enabled_queues: Vec<bool>,
}

/// Device state reported through `LegacyVirtioDevice::state_change`.
pub enum VirtioState {
    Unknown,
    Running(VirtioRunningState),
    Stopped,
}
436
/// Set of doorbell registrations kept alive until cleared.
pub(crate) struct VirtioDoorbells {
    /// Registration backend. When `None`, no doorbells are registered.
    registration: Option<Arc<dyn DoorbellRegistration>>,
    /// Handles returned by the backend for each registered doorbell.
    doorbells: Vec<Box<dyn Send + Sync>>,
}
441
442impl VirtioDoorbells {
443 pub fn new(registration: Option<Arc<dyn DoorbellRegistration>>) -> Self {
444 Self {
445 registration,
446 doorbells: Vec::new(),
447 }
448 }
449
450 pub fn add(&mut self, address: u64, value: Option<u64>, length: Option<u32>, event: &Event) {
451 if let Some(registration) = &mut self.registration {
452 let doorbell = registration.register_doorbell(address, value, length, event);
453 if let Ok(doorbell) = doorbell {
454 self.doorbells.push(doorbell);
455 }
456 }
457 }
458
459 pub fn clear(&mut self) {
460 self.doorbells.clear();
461 }
462}
463
/// Shared-memory region advertised by a device.
#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraitsSharedMemory {
    /// Region identifier.
    pub id: u8,
    /// Region size (presumably in bytes; confirm against the transport).
    pub size: u64,
}

/// Static properties a virtio device reports to its transport.
#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraits {
    /// Virtio device type ID.
    pub device_id: u16,
    /// Feature bits offered by the device.
    pub device_features: u64,
    /// Maximum number of queues the device supports.
    pub max_queues: u16,
    /// Length of the device-specific register space (presumably bytes; confirm).
    pub device_register_length: u32,
    /// Optional shared-memory region description.
    pub shared_memory: DeviceTraitsSharedMemory,
}
478
/// Device interface in which the device supplies a per-queue callback and the
/// caller (see `LegacyWrapper`) runs the queue workers.
pub trait LegacyVirtioDevice: Send {
    /// Returns the device's static properties.
    fn traits(&self) -> DeviceTraits;
    /// Reads a 32-bit value at `offset` in the device-specific register space.
    fn read_registers_u32(&self, offset: u16) -> u32;
    /// Writes a 32-bit value at `offset` in the device-specific register space.
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    /// Returns the work callback for queue `index`. It is called once per enabled
    /// queue each time the device is enabled.
    fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send>;
    /// Notifies the device that it is starting (`Running`) or stopping (`Stopped`).
    fn state_change(&mut self, state: &VirtioState);
}

/// Device interface in which the device receives the queue resources directly
/// on enable.
pub trait VirtioDevice: Send {
    /// Returns the device's static properties.
    fn traits(&self) -> DeviceTraits;
    /// Reads a 32-bit value at `offset` in the device-specific register space.
    fn read_registers_u32(&self, offset: u16) -> u32;
    /// Writes a 32-bit value at `offset` in the device-specific register space.
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    /// Starts the device with the negotiated features and queue resources.
    fn enable(&mut self, resources: Resources);
    /// Stops the device.
    fn disable(&mut self);
}
494
/// Everything needed to run one queue.
pub struct QueueResources {
    /// Ring layout and enable flag.
    pub params: QueueParams,
    /// Interrupt raised to notify the guest of completions.
    pub notify: Interrupt,
    /// Event signaled when the guest makes descriptors available.
    pub event: Event,
}

/// Resources handed to `VirtioDevice::enable`.
pub struct Resources {
    /// Negotiated feature bits.
    pub features: u64,
    /// Per-queue resources, indexed by queue number.
    pub queues: Vec<QueueResources>,
    /// Optional mapped shared-memory region.
    pub shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
    /// Size of the shared-memory region.
    pub shared_memory_size: u64,
}
507
/// Adapts a [`LegacyVirtioDevice`] to [`VirtioDevice`] by running one
/// [`VirtioQueueWorker`] task per enabled queue.
pub struct LegacyWrapper<T: LegacyVirtioDevice> {
    device: T,
    /// Driver for queue event waits and for the shutdown task spawned by `disable`.
    driver: VmTaskDriver,
    mem: GuestMemory,
    /// Worker tasks for the currently enabled queues; empty after `disable`.
    workers: Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>,
    /// Notified by `disable` so every running worker exits.
    exit_event: event_listener::Event,
}
516
517impl<T: LegacyVirtioDevice> LegacyWrapper<T> {
518 pub fn new(driver_source: &VmTaskDriverSource, device: T, mem: &GuestMemory) -> Self {
519 Self {
520 device,
521 driver: driver_source.simple(),
522 mem: mem.clone(),
523 workers: Vec::new(),
524 exit_event: event_listener::Event::new(),
525 }
526 }
527}
528
529impl<T: LegacyVirtioDevice> VirtioDevice for LegacyWrapper<T> {
530 fn traits(&self) -> DeviceTraits {
531 self.device.traits()
532 }
533
534 fn read_registers_u32(&self, offset: u16) -> u32 {
535 self.device.read_registers_u32(offset)
536 }
537
538 fn write_registers_u32(&mut self, offset: u16, val: u32) {
539 self.device.write_registers_u32(offset, val)
540 }
541
542 fn enable(&mut self, resources: Resources) {
543 let running_state = VirtioRunningState {
544 features: resources.features,
545 enabled_queues: resources
546 .queues
547 .iter()
548 .map(|QueueResources { params, .. }| params.enable)
549 .collect(),
550 };
551
552 self.device
553 .state_change(&VirtioState::Running(running_state));
554 self.workers = resources
555 .queues
556 .into_iter()
557 .enumerate()
558 .filter_map(|(i, queue_resources)| {
559 if !queue_resources.params.enable {
560 return None;
561 }
562 let worker = VirtioQueueWorker::new(
563 self.driver.clone(),
564 self.device.get_work_callback(i as u16),
565 );
566 Some(worker.into_running_task(
567 "virtio-queue".to_string(),
568 self.mem.clone(),
569 resources.features,
570 queue_resources,
571 self.exit_event.listen(),
572 ))
573 })
574 .collect();
575 }
576
577 fn disable(&mut self) {
578 if self.workers.is_empty() {
579 return;
580 }
581 self.exit_event.notify(usize::MAX);
582 self.device.state_change(&VirtioState::Stopped);
583 let mut workers = self.workers.drain(..).collect::<Vec<_>>();
584 self.driver
585 .spawn("shutdown-legacy-virtio-queues".to_owned(), async move {
586 futures::future::join_all(workers.iter_mut().map(async |worker| {
587 worker.stop().await;
588 if let Some(VirtioQueueStateInner::Running { queue, .. }) =
589 worker.state_mut().map(|s| &s.inner)
590 {
591 queue.wait_for_outstanding_descriptors().await;
592 }
593 }))
594 .await;
595 })
596 .detach();
597 }
598}
599
impl<T: LegacyVirtioDevice> Drop for LegacyWrapper<T> {
    /// Shuts down any running queue workers via `disable`.
    fn drop(&mut self) {
        self.disable();
    }
}
605
606#[expect(unsafe_code)]
609#[cfg(test)]
610mod tests {
611 use super::*;
612 use crate::PciInterruptModel;
613 use crate::spec::pci::*;
614 use crate::spec::queue::*;
615 use crate::spec::*;
616 use crate::transport::VirtioMmioDevice;
617 use crate::transport::VirtioPciDevice;
618 use chipset_device::mmio::ExternallyManagedMmioIntercepts;
619 use chipset_device::mmio::MmioIntercept;
620 use chipset_device::pci::PciConfigSpace;
621 use futures::StreamExt;
622 use guestmem::GuestMemoryAccess;
623 use guestmem::GuestMemoryBackingError;
624 use pal_async::DefaultDriver;
625 use pal_async::async_test;
626 use pal_async::timer::PolledTimer;
627 use pci_core::msi::MsiInterruptSet;
628 use pci_core::spec::caps::CapabilityId;
629 use pci_core::spec::cfg_space;
630 use pci_core::test_helpers::TestPciInterruptController;
631 use std::collections::BTreeMap;
632 use std::future::poll_fn;
633 use std::io;
634 use std::ptr::NonNull;
635 use std::time::Duration;
636 use test_with_tracing::test;
637 use vmcore::line_interrupt::LineInterrupt;
638 use vmcore::line_interrupt::test_helpers::TestLineInterruptTarget;
639 use vmcore::vm_task::SingleDriverBackend;
640 use vmcore::vm_task::VmTaskDriverSource;
641
/// Receives the next message from `recv`.
///
/// Panics if nothing arrives within `timeout` (first `unwrap`) or if the
/// channel is closed (second `unwrap`).
async fn must_recv_in_timeout<T: 'static + Send>(
    recv: &mut mesh::Receiver<T>,
    timeout: Duration,
) -> T {
    mesh::CancelContext::new()
        .with_timeout(timeout)
        .until_cancelled(recv.next())
        .await
        .unwrap()
        .unwrap()
}
653
/// Sparse, fully software-backed guest memory for tests.
#[derive(Default)]
struct VirtioTestMemoryAccess {
    memory_map: Mutex<MemoryMap>,
}

/// Mapped regions keyed by start address. Each value is
/// `(writable, bytes)`; adjacent regions with equal writability are merged.
#[derive(Default)]
struct MemoryMap {
    map: BTreeMap<u64, (bool, Vec<u8>)>,
}
663
impl MemoryMap {
    /// Returns `len` bytes at `address` together with the containing region's
    /// writability. Returns `None` unless a single region covers the whole range.
    fn get(&mut self, address: u64, len: usize) -> Option<(bool, &mut [u8])> {
        // The candidate region is the last one starting at or below `address`.
        let (&base, &mut (writable, ref mut data)) = self.map.range_mut(..=address).last()?;
        let data = data
            .get_mut(usize::try_from(address - base).ok()?..)?
            .get_mut(..len)?;

        Some((writable, data))
    }

    /// Maps `data` at `address`.
    ///
    /// If the range is already mapped, the bytes are overwritten in place (the
    /// writability must match). Otherwise a new region is added and merged with
    /// directly adjacent neighbors of the same writability.
    ///
    /// Panics if the new range partially overlaps an existing region.
    fn insert(&mut self, address: u64, data: &[u8], writable: bool) {
        if let Some((is_writable, v)) = self.get(address, data.len()) {
            assert_eq!(writable, is_writable);
            v.copy_from_slice(data);
            return;
        }

        let end = address + data.len() as u64;
        let mut data = data.to_vec();
        // Check the following region: reject overlap, absorb it if it touches `end`.
        if let Some((&next, &(next_writable, ref next_data))) = self.map.range(address..).next()
        {
            if end > next {
                let next_end = next + next_data.len() as u64;
                panic!(
                    "overlapping memory map: {address:#x}..{end:#x} > {next:#x}..={next_end:#x}"
                );
            }
            if end == next && next_writable == writable {
                data.extend(next_data.as_slice());
                self.map.remove(&next).unwrap();
            }
        }

        // Check the preceding region: reject overlap, append to it if it ends at `address`.
        if let Some((&prev, &mut (prev_writable, ref mut prev_data))) =
            self.map.range_mut(..address).last()
        {
            let prev_end = prev + prev_data.len() as u64;
            if prev_end > address {
                panic!(
                    "overlapping memory map: {prev:#x}..{prev_end:#x} > {address:#x}..={end:#x}"
                );
            }
            if prev_end == address && prev_writable == writable {
                prev_data.extend_from_slice(&data);
                return;
            }
        }

        self.map.insert(address, (writable, data));
    }
}
715
716 impl VirtioTestMemoryAccess {
717 fn new() -> Arc<Self> {
718 Default::default()
719 }
720
721 fn modify_memory_map(&self, address: u64, data: &[u8], writeable: bool) {
722 self.memory_map.lock().insert(address, data, writeable);
723 }
724
725 fn memory_map_get_u16(&self, address: u64) -> u16 {
726 let mut map = self.memory_map.lock();
727 let (_, data) = map.get(address, 2).unwrap();
728 u16::from_le_bytes(data.try_into().unwrap())
729 }
730
731 fn memory_map_get_u32(&self, address: u64) -> u32 {
732 let mut map = self.memory_map.lock();
733 let (_, data) = map.get(address, 4).unwrap();
734 u32::from_le_bytes(data.try_into().unwrap())
735 }
736 }
737
// SAFETY: this backing never exposes a raw mapping (`mapping` returns `None`),
// so every access goes through the `*_fallback` methods below. Those methods
// only touch bytes returned by `MemoryMap::get`, which bounds-checks the
// requested range against a single mapped region.
unsafe impl GuestMemoryAccess for VirtioTestMemoryAccess {
    fn mapping(&self) -> Option<NonNull<u8>> {
        None
    }

    fn max_address(&self) -> u64 {
        // Advertise a 52-bit guest physical address space.
        1 << 52
    }

    /// Copies `len` mapped bytes at `address` into `dest`; panics on unmapped reads.
    unsafe fn read_fallback(
        &self,
        address: u64,
        dest: *mut u8,
        len: usize,
    ) -> Result<(), GuestMemoryBackingError> {
        match self.memory_map.lock().get(address, len) {
            Some((_, value)) => {
                // SAFETY: `value` is exactly `len` readable bytes, and the
                // `read_fallback` contract requires `dest` to be valid for `len`
                // writes. `ptr::copy` tolerates overlap.
                unsafe {
                    std::ptr::copy(value.as_ptr(), dest, len);
                }
            }
            None => panic!("Unexpected read request at address {:x}", address),
        }
        Ok(())
    }

    /// Copies `len` bytes from `src` into writable mapped memory at `address`;
    /// panics if the range is unmapped or read-only.
    unsafe fn write_fallback(
        &self,
        address: u64,
        src: *const u8,
        len: usize,
    ) -> Result<(), GuestMemoryBackingError> {
        match self.memory_map.lock().get(address, len) {
            Some((true, value)) => {
                // SAFETY: `value` is exactly `len` writable bytes, and the
                // `write_fallback` contract requires `src` to be valid for `len`
                // reads. `ptr::copy` tolerates overlap.
                unsafe {
                    std::ptr::copy(src, value.as_mut_ptr(), len);
                }
            }
            _ => panic!("Unexpected write request at address {:x}", address),
        }
        Ok(())
    }

    /// Fills `len` writable mapped bytes at `address` with `val`; panics if the
    /// range is unmapped or read-only.
    fn fill_fallback(
        &self,
        address: u64,
        val: u8,
        len: usize,
    ) -> Result<(), GuestMemoryBackingError> {
        match self.memory_map.lock().get(address, len) {
            Some((true, value)) => value.fill(val),
            _ => panic!("Unexpected write request at address {:x}", address),
        };
        Ok(())
    }
}
799
/// Placeholder handle returned by the test doorbell registration; dropping it
/// does nothing.
struct DoorbellEntry;
impl Drop for DoorbellEntry {
    fn drop(&mut self) {}
}

impl DoorbellRegistration for VirtioTestMemoryAccess {
    /// Accepts every registration without wiring up a doorbell.
    fn register_doorbell(
        &self,
        _: u64,
        _: Option<u64>,
        _: Option<u32>,
        _: &Event,
    ) -> io::Result<Box<dyn Send + Sync>> {
        Ok(Box::new(DoorbellEntry))
    }
}
816
/// Test callback invoked for each queue work item. Its return value becomes the
/// worker's keep-running flag.
type VirtioTestWorkCallback =
    Box<dyn Fn(anyhow::Result<VirtioQueueCallbackWork>) -> bool + Sync + Send>;
/// Per-queue inputs for `VirtioTestGuest::create_direct_queues`.
struct CreateDirectQueueParams {
    process_work: VirtioTestWorkCallback,
    notify: Interrupt,
    event: Event,
}
824
/// Simulated guest driver. It lays out split virtqueues in
/// [`VirtioTestMemoryAccess`] and feeds descriptors to the device under test.
struct VirtioTestGuest {
    test_mem: Arc<VirtioTestMemoryAccess>,
    driver: DefaultDriver,
    num_queues: u16,
    queue_size: u16,
    /// Whether `VIRTIO_F_RING_EVENT_IDX` is used (adds the event-index ring fields).
    use_ring_event_index: bool,
    /// Per-queue count of entries published to the available ring.
    last_avail_index: Vec<u16>,
    /// Per-queue used-ring index (presumably advanced by helpers outside this view).
    last_used_index: Vec<u16>,
    /// Per-queue free map of descriptor slots: `true` means free.
    avail_descriptors: Vec<Vec<bool>>,
    /// Listeners are handed to workers from `create_direct_queues`; notifying it
    /// makes those workers exit.
    exit_event: event_listener::Event,
}
836
837 impl VirtioTestGuest {
838 fn new(
839 driver: &DefaultDriver,
840 test_mem: &Arc<VirtioTestMemoryAccess>,
841 num_queues: u16,
842 queue_size: u16,
843 use_ring_event_index: bool,
844 ) -> Self {
845 let last_avail_index: Vec<u16> = vec![0; num_queues as usize];
846 let last_used_index: Vec<u16> = vec![0; num_queues as usize];
847 let avail_descriptors: Vec<Vec<bool>> =
848 vec![vec![true; queue_size as usize]; num_queues as usize];
849 let test_guest = Self {
850 test_mem: test_mem.clone(),
851 driver: driver.clone(),
852 num_queues,
853 queue_size,
854 use_ring_event_index,
855 last_avail_index,
856 last_used_index,
857 avail_descriptors,
858 exit_event: event_listener::Event::new(),
859 };
860 for i in 0..num_queues {
861 test_guest.add_queue_memory(i);
862 }
863 test_guest
864 }
865
/// Returns a `GuestMemory` backed by this guest's test memory.
fn mem(&self) -> GuestMemory {
    GuestMemory::new("test", self.test_mem.clone())
}
869
870 fn create_direct_queues<F>(
871 &self,
872 f: F,
873 ) -> Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>
874 where
875 F: Fn(u16) -> CreateDirectQueueParams,
876 {
877 (0..self.num_queues)
878 .map(|i| {
879 let params = f(i);
880 let worker = VirtioQueueWorker::new(
881 self.driver.clone(),
882 Box::new(VirtioTestWork {
883 callback: params.process_work,
884 }),
885 );
886 worker.into_running_task(
887 "virtio-test-queue".to_string(),
888 self.mem(),
889 self.queue_features(),
890 QueueResources {
891 params: self.queue_params(i),
892 notify: params.notify,
893 event: params.event,
894 },
895 self.exit_event.listen(),
896 )
897 })
898 .collect::<Vec<_>>()
899 }
900
901 fn queue_features(&self) -> u64 {
902 if self.use_ring_event_index {
903 VIRTIO_F_RING_EVENT_IDX as u64
904 } else {
905 0
906 }
907 }
908
/// Enabled `QueueParams` for queue `i`, pointing at this guest's ring layout.
fn queue_params(&self, i: u16) -> QueueParams {
    QueueParams {
        size: self.queue_size,
        enable: true,
        desc_addr: self.get_queue_descriptor_base_address(i),
        avail_addr: self.get_queue_available_base_address(i),
        used_addr: self.get_queue_used_base_address(i),
    }
}
918
919 fn get_queue_base_address(&self, index: u16) -> u64 {
920 0x10000000 * index as u64
921 }
922
923 fn get_queue_descriptor_base_address(&self, index: u16) -> u64 {
924 self.get_queue_base_address(index) + 0x1000
925 }
926
927 fn get_queue_available_base_address(&self, index: u16) -> u64 {
928 self.get_queue_base_address(index) + 0x2000
929 }
930
931 fn get_queue_used_base_address(&self, index: u16) -> u64 {
932 self.get_queue_base_address(index) + 0x3000
933 }
934
935 fn get_queue_descriptor_backing_memory_address(&self, index: u16) -> u64 {
936 self.get_queue_base_address(index) + 0x4000
937 }
938
/// Runs the virtio-mmio driver initialization on `dev`: status handshake,
/// feature negotiation, and per-queue ring setup. It then checks the config
/// generation.
///
/// Register offsets follow the virtio 1.x MMIO layout.
fn setup_chipset_device(&self, dev: &mut VirtioMmioDevice, driver_features: u64) {
    // 0x70 Status.
    dev.write_u32(112, VIRTIO_ACKNOWLEDGE);
    dev.write_u32(112, VIRTIO_DRIVER);
    // 0x24 DriverFeaturesSel / 0x20 DriverFeatures: low word, then high word.
    dev.write_u32(36, 0);
    dev.write_u32(32, driver_features as u32);
    dev.write_u32(36, 1);
    dev.write_u32(32, (driver_features >> 32) as u32);
    dev.write_u32(112, VIRTIO_FEATURES_OK);
    for i in 0..self.num_queues {
        let queue_index = i;
        // 0x30 QueueSel, 0x38 QueueNum.
        dev.write_u32(48, i as u32);
        dev.write_u32(56, self.queue_size as u32);
        // 0x80/0x84 QueueDesc low/high.
        let desc_addr = self.get_queue_descriptor_base_address(queue_index);
        dev.write_u32(128, desc_addr as u32);
        dev.write_u32(132, (desc_addr >> 32) as u32);
        // 0x90/0x94 QueueDriver (available ring) low/high.
        let avail_addr = self.get_queue_available_base_address(queue_index);
        dev.write_u32(144, avail_addr as u32);
        dev.write_u32(148, (avail_addr >> 32) as u32);
        // 0xa0/0xa4 QueueDevice (used ring) low/high.
        let used_addr = self.get_queue_used_base_address(queue_index);
        dev.write_u32(160, used_addr as u32);
        dev.write_u32(164, (used_addr >> 32) as u32);
        // 0x44 QueueReady.
        dev.write_u32(68, 1);
    }
    dev.write_u32(112, VIRTIO_DRIVER_OK);
    // 0xfc ConfigGeneration.
    assert_eq!(dev.read_u32(0xfc), 2);
}
966
/// Runs the virtio-pci driver initialization on `dev`. It programs the BARs,
/// performs the status handshake and feature negotiation, and configures each
/// queue with its own MSI-X vector. It then checks the config generation.
///
/// BAR0 offsets match the virtio 1.x `virtio_pci_common_cfg` layout, which
/// assumes that structure sits at offset 0 of BAR0. BAR2 is written as an MSI-X
/// table of 16-byte entries.
fn setup_pci_device(&self, dev: &mut VirtioPciTestDevice, driver_features: u64) {
    // 64-bit BAR0 (config 0x10/0x14).
    let bar_address1: u64 = 0x10000000000;
    dev.pci_device
        .pci_cfg_write(0x14, (bar_address1 >> 32) as u32)
        .unwrap();
    dev.pci_device
        .pci_cfg_write(0x10, bar_address1 as u32)
        .unwrap();

    // 64-bit BAR2 (config 0x18/0x1c).
    let bar_address2: u64 = 0x20000000000;
    dev.pci_device
        .pci_cfg_write(0x1c, (bar_address2 >> 32) as u32)
        .unwrap();
    dev.pci_device
        .pci_cfg_write(0x18, bar_address2 as u32)
        .unwrap();

    // Command register: enable MMIO decoding.
    dev.pci_device
        .pci_cfg_write(
            0x4,
            cfg_space::Command::new()
                .with_mmio_enabled(true)
                .into_bits() as u32,
        )
        .unwrap();

    // Common cfg +20: device_status.
    let mut device_status = VIRTIO_ACKNOWLEDGE as u8;
    dev.pci_device
        .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
        .unwrap();
    device_status = VIRTIO_DRIVER as u8;
    dev.pci_device
        .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
        .unwrap();
    // +8 driver_feature_select / +12 driver_feature: low word, then high word.
    dev.write_u32(bar_address1 + 8, 0);
    dev.write_u32(bar_address1 + 12, driver_features as u32);
    dev.write_u32(bar_address1 + 8, 1);
    dev.write_u32(bar_address1 + 12, (driver_features >> 32) as u32);
    device_status = VIRTIO_FEATURES_OK as u8;
    dev.pci_device
        .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
        .unwrap();
    // MSI-X table entry 0: address (8 bytes), data (4), vector control (4), all zero.
    dev.pci_device
        .mmio_write(bar_address2, &0_u64.to_le_bytes())
        .unwrap();
    dev.pci_device
        .mmio_write(bar_address2 + 8, &0_u32.to_le_bytes())
        .unwrap();
    dev.pci_device
        .mmio_write(bar_address2 + 12, &0_u32.to_le_bytes())
        .unwrap();
    for i in 0..self.num_queues {
        let queue_index = i;
        // +22 queue_select, +24 queue_size.
        dev.pci_device
            .mmio_write(bar_address1 + 22, &queue_index.to_le_bytes())
            .unwrap();
        dev.pci_device
            .mmio_write(bar_address1 + 24, &self.queue_size.to_le_bytes())
            .unwrap();
        // Queue `i` uses MSI-X entry `i + 1`. The entry's message address is set
        // to the vector number; data and vector control are zero.
        let msix_vector = queue_index + 1;
        let address = bar_address2 + 0x10 * msix_vector as u64;
        dev.pci_device
            .mmio_write(address, &(msix_vector as u64).to_le_bytes())
            .unwrap();
        let address = bar_address2 + 0x10 * msix_vector as u64 + 8;
        dev.pci_device
            .mmio_write(address, &0_u32.to_le_bytes())
            .unwrap();
        let address = bar_address2 + 0x10 * msix_vector as u64 + 12;
        dev.pci_device
            .mmio_write(address, &0_u32.to_le_bytes())
            .unwrap();
        // +26 queue_msix_vector.
        dev.pci_device
            .mmio_write(bar_address1 + 26, &msix_vector.to_le_bytes())
            .unwrap();
        // +32 queue_desc, +40 queue_driver (available ring), +48 queue_device (used ring).
        let desc_addr = self.get_queue_descriptor_base_address(queue_index);
        dev.write_u32(bar_address1 + 32, desc_addr as u32);
        dev.write_u32(bar_address1 + 36, (desc_addr >> 32) as u32);
        let avail_addr = self.get_queue_available_base_address(queue_index);
        dev.write_u32(bar_address1 + 40, avail_addr as u32);
        dev.write_u32(bar_address1 + 44, (avail_addr >> 32) as u32);
        let used_addr = self.get_queue_used_base_address(queue_index);
        dev.write_u32(bar_address1 + 48, used_addr as u32);
        dev.write_u32(bar_address1 + 52, (used_addr >> 32) as u32);
        // +28 queue_enable.
        let enabled: u16 = 1;
        dev.pci_device
            .mmio_write(bar_address1 + 28, &enabled.to_le_bytes())
            .unwrap();
    }
    // NOTE(review): presumably the MSI-X capability header at config 0x40;
    // bit 31 would be MSI-X Enable in Message Control. Confirm the capability offset.
    dev.pci_device.pci_cfg_write(0x40, 0x80000000).unwrap();
    device_status = VIRTIO_DRIVER_OK as u8;
    dev.pci_device
        .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
        .unwrap();
    // +21 config_generation.
    let mut config_generation: [u8; 1] = [0];
    dev.pci_device
        .mmio_read(bar_address1 + 21, &mut config_generation)
        .unwrap();
    assert_eq!(config_generation[0], 2);
}
1073
1074 fn get_queue_descriptor(&self, queue_index: u16, descriptor_index: u16) -> u64 {
1075 self.get_queue_descriptor_base_address(queue_index) + 0x10 * descriptor_index as u64
1076 }
1077
/// Maps the split-ring structures for queue `queue_index` into test memory.
///
/// - Descriptor table (read-only): descriptor `i` points at
///   `backing + 0x1000 * i` with length 0x1000, and its flags and next are zero.
///   The backing buffers themselves are not mapped here.
/// - Available ring (read-only): zeroed flags, idx, and ring entries, plus
///   `used_event` when the event index is enabled.
/// - Used ring (writable): zeroed flags, idx, and `(id, len)` entries, plus
///   `avail_event` when the event index is enabled.
fn add_queue_memory(&self, queue_index: u16) {
    for i in 0..self.queue_size {
        let base = self.get_queue_descriptor(queue_index, i);
        // Descriptor layout: addr (u64), len (u32), flags (u16), next (u16).
        self.test_mem.modify_memory_map(
            base,
            &(self.get_queue_descriptor_backing_memory_address(queue_index)
                + 0x1000 * i as u64)
                .to_le_bytes(),
            false,
        );
        self.test_mem
            .modify_memory_map(base + 8, &0x1000u32.to_le_bytes(), false);
        self.test_mem
            .modify_memory_map(base + 12, &0u16.to_le_bytes(), false);
        self.test_mem
            .modify_memory_map(base + 14, &0u16.to_le_bytes(), false);
    }

    // Available ring header: flags (u16), idx (u16).
    let base = self.get_queue_available_base_address(queue_index);
    self.test_mem
        .modify_memory_map(base, &0u16.to_le_bytes(), false);
    self.test_mem
        .modify_memory_map(base + 2, &0u16.to_le_bytes(), false);
    for i in 0..self.queue_size {
        let base = base + 4 + 2 * i as u64;
        self.test_mem
            .modify_memory_map(base, &0u16.to_le_bytes(), false);
    }
    // `base` here is the available ring base again (the loop's `base` was
    // shadowed only inside the loop), so this is the trailing `used_event` field.
    if self.use_ring_event_index {
        self.test_mem.modify_memory_map(
            base + 4 + 2 * self.queue_size as u64,
            &0u16.to_le_bytes(),
            false,
        );
    }

    // Used ring header: flags (u16), idx (u16), then (id: u32, len: u32) elements.
    let base = self.get_queue_used_base_address(queue_index);
    self.test_mem
        .modify_memory_map(base, &0u16.to_le_bytes(), true);
    self.test_mem
        .modify_memory_map(base + 2, &0u16.to_le_bytes(), true);
    for i in 0..self.queue_size {
        let base = base + 4 + 8 * i as u64;
        self.test_mem
            .modify_memory_map(base, &0u32.to_le_bytes(), true);
        self.test_mem
            .modify_memory_map(base + 4, &0u32.to_le_bytes(), true);
    }
    // Trailing `avail_event` field of the used ring.
    if self.use_ring_event_index {
        self.test_mem.modify_memory_map(
            base + 4 + 8 * self.queue_size as u64,
            &0u16.to_le_bytes(),
            true,
        );
    }
}
1146
1147 fn reserve_descriptor(&mut self, queue_index: u16) -> u16 {
1148 let avail_descriptors = &mut self.avail_descriptors[queue_index as usize];
1149 for (i, desc) in avail_descriptors.iter_mut().enumerate() {
1150 if *desc {
1151 *desc = false;
1152 return i as u16;
1153 }
1154 }
1155
1156 panic!("No descriptors are available!");
1157 }
1158
/// Returns `desc_index` to the free pool of `queue_index`. It first follows the
/// descriptor's NEXT link, if set, so the rest of the chain is freed tail-first.
///
/// Panics if `desc_index` is out of range or any descriptor in the chain is
/// already free.
fn free_descriptor(&mut self, queue_index: u16, desc_index: u16) {
    assert!(desc_index < self.queue_size);
    let desc_addr = self.get_queue_descriptor(queue_index, desc_index);
    // Descriptor flags are at +12 and the next index at +14.
    let flags: DescriptorFlags = self.test_mem.memory_map_get_u16(desc_addr + 12).into();
    if flags.next() {
        let next = self.test_mem.memory_map_get_u16(desc_addr + 14);
        self.free_descriptor(queue_index, next);
    }
    let avail_descriptors = &mut self.avail_descriptors[queue_index as usize];
    assert_eq!(avail_descriptors[desc_index as usize], false);
    avail_descriptors[desc_index as usize] = true;
}
1171
1172 fn queue_available_desc(&mut self, queue_index: u16, desc_index: u16) {
1173 let avail_base_addr = self.get_queue_available_base_address(queue_index);
1174 let last_avail_index = &mut self.last_avail_index[queue_index as usize];
1175 let next_index = *last_avail_index % self.queue_size;
1176 *last_avail_index = last_avail_index.wrapping_add(1);
1177 self.test_mem.modify_memory_map(
1178 avail_base_addr + 4 + 2 * next_index as u64,
1179 &desc_index.to_le_bytes(),
1180 false,
1181 );
1182 self.test_mem.modify_memory_map(
1183 avail_base_addr + 2,
1184 &last_avail_index.to_le_bytes(),
1185 false,
1186 );
1187 }
1188
1189 fn add_to_avail_queue(&mut self, queue_index: u16) {
1190 let next_descriptor = self.reserve_descriptor(queue_index);
1191 self.test_mem.modify_memory_map(
1193 self.get_queue_descriptor(queue_index, next_descriptor) + 12,
1194 &0u16.to_le_bytes(),
1195 false,
1196 );
1197 self.queue_available_desc(queue_index, next_descriptor);
1198 }
1199
1200 fn add_indirect_to_avail_queue(&mut self, queue_index: u16) {
1201 let next_descriptor = self.reserve_descriptor(queue_index);
1202 self.test_mem.modify_memory_map(
1204 self.get_queue_descriptor(queue_index, next_descriptor) + 12,
1205 &u16::from(DescriptorFlags::new().with_indirect(true)).to_le_bytes(),
1206 false,
1207 );
1208 let buffer_addr = self.get_queue_descriptor_backing_memory_address(queue_index);
1210 self.test_mem.modify_memory_map(
1212 buffer_addr,
1213 &0xffffffff00000000u64.to_le_bytes(),
1214 false,
1215 );
1216 self.test_mem
1218 .modify_memory_map(buffer_addr + 8, &0x1000u32.to_le_bytes(), false);
1219 self.test_mem
1221 .modify_memory_map(buffer_addr + 12, &0u16.to_le_bytes(), false);
1222 self.test_mem
1224 .modify_memory_map(buffer_addr + 14, &0u16.to_le_bytes(), false);
1225 self.queue_available_desc(queue_index, next_descriptor);
1226 }
1227
1228 fn add_linked_to_avail_queue(&mut self, queue_index: u16, desc_count: u16) {
1229 let mut descriptors = Vec::with_capacity(desc_count as usize);
1230 for _ in 0..desc_count {
1231 descriptors.push(self.reserve_descriptor(queue_index));
1232 }
1233
1234 for i in 0..descriptors.len() {
1235 let base = self.get_queue_descriptor(queue_index, descriptors[i]);
1236 let flags = if i < descriptors.len() - 1 {
1237 u16::from(DescriptorFlags::new().with_next(true))
1238 } else {
1239 0
1240 };
1241 self.test_mem
1242 .modify_memory_map(base + 12, &flags.to_le_bytes(), false);
1243 let next = if i < descriptors.len() - 1 {
1244 descriptors[i + 1]
1245 } else {
1246 0
1247 };
1248 self.test_mem
1249 .modify_memory_map(base + 14, &next.to_le_bytes(), false);
1250 }
1251 self.queue_available_desc(queue_index, descriptors[0]);
1252 }
1253
1254 fn add_indirect_linked_to_avail_queue(&mut self, queue_index: u16, desc_count: u16) {
1255 let next_descriptor = self.reserve_descriptor(queue_index);
1256 self.test_mem.modify_memory_map(
1258 self.get_queue_descriptor(queue_index, next_descriptor) + 12,
1259 &u16::from(DescriptorFlags::new().with_indirect(true)).to_le_bytes(),
1260 false,
1261 );
1262 let buffer_addr = self.get_queue_descriptor_backing_memory_address(queue_index);
1264 for i in 0..desc_count {
1265 let base = buffer_addr + 0x10 * i as u64;
1266 let indirect_buffer_addr = 0xffffffff00000000u64 + 0x1000 * i as u64;
1267 self.test_mem
1269 .modify_memory_map(base, &indirect_buffer_addr.to_le_bytes(), false);
1270 self.test_mem
1272 .modify_memory_map(base + 8, &0x1000u32.to_le_bytes(), false);
1273 let flags = if i < desc_count - 1 {
1275 u16::from(DescriptorFlags::new().with_next(true))
1276 } else {
1277 0
1278 };
1279 self.test_mem
1280 .modify_memory_map(base + 12, &flags.to_le_bytes(), false);
1281 let next = if i < desc_count - 1 { i + 1 } else { 0 };
1283 self.test_mem
1284 .modify_memory_map(base + 14, &next.to_le_bytes(), false);
1285 }
1286 self.queue_available_desc(queue_index, next_descriptor);
1287 }
1288
1289 fn get_next_completed(&mut self, queue_index: u16) -> Option<(u16, u32)> {
1290 let avail_base_addr = self.get_queue_available_base_address(queue_index);
1291 let used_base_addr = self.get_queue_used_base_address(queue_index);
1292 let cur_used_index = self.test_mem.memory_map_get_u16(used_base_addr + 2);
1293 let last_used_index = &mut self.last_used_index[queue_index as usize];
1294 if *last_used_index == cur_used_index {
1295 return None;
1296 }
1297
1298 if self.use_ring_event_index {
1299 self.test_mem.modify_memory_map(
1300 avail_base_addr + 4 + 2 * self.queue_size as u64,
1301 &cur_used_index.to_le_bytes(),
1302 false,
1303 );
1304 }
1305
1306 let next_index = *last_used_index % self.queue_size;
1307 *last_used_index = last_used_index.wrapping_add(1);
1308 let desc_index = self
1309 .test_mem
1310 .memory_map_get_u32(used_base_addr + 4 + 8 * next_index as u64);
1311 let desc_index = desc_index as u16;
1312 let bytes_written = self
1313 .test_mem
1314 .memory_map_get_u32(used_base_addr + 8 + 8 * next_index as u64);
1315 self.free_descriptor(queue_index, desc_index);
1316 Some((desc_index, bytes_written))
1317 }
1318 }
1319
/// Adapts a plain callback into a [`VirtioQueueWorkerContext`].
struct VirtioTestWork {
    /// Invoked for every work item. Its return value is returned unchanged from
    /// `process_work`.
    callback: VirtioTestWorkCallback,
}

#[async_trait]
impl VirtioQueueWorkerContext for VirtioTestWork {
    /// Forwards `work`, including any error, to the wrapped callback and returns the
    /// callback's result.
    async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
        (self.callback)(work)
    }
}
/// A [`VirtioPciDevice`] wrapping a [`TestDevice`], together with the test interrupt
/// controller its MSI-X interrupt set is connected to.
struct VirtioPciTestDevice {
    /// The device under test.
    pci_device: VirtioPciDevice,
    /// Receives the device's MSI-X interrupts.
    test_intc: Arc<TestPciInterruptController>,
}
1334
/// Callback run by [`TestDeviceWorker`] for each successfully received work item.
/// It is passed the queue index and the work.
type TestDeviceQueueWorkFn = Arc<dyn Fn(u16, VirtioQueueCallbackWork) + Send + Sync>;
1336
/// Minimal [`LegacyVirtioDevice`] used by these tests. It reports fixed traits, has no
/// meaningful registers, and hands each queue to an optional shared callback.
struct TestDevice {
    /// Returned verbatim from `traits()`.
    traits: DeviceTraits,
    /// Shared by every queue worker created by this device; `None` means no callback.
    queue_work: Option<TestDeviceQueueWorkFn>,
}

impl TestDevice {
    /// Creates a device reporting `traits` and dispatching queue work to `queue_work`.
    fn new(traits: DeviceTraits, queue_work: Option<TestDeviceQueueWorkFn>) -> Self {
        Self { traits, queue_work }
    }
}
1347
1348 impl LegacyVirtioDevice for TestDevice {
1349 fn traits(&self) -> DeviceTraits {
1350 self.traits
1351 }
1352
1353 fn read_registers_u32(&self, _offset: u16) -> u32 {
1354 0
1355 }
1356
1357 fn write_registers_u32(&mut self, _offset: u16, _val: u32) {}
1358
1359 fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send> {
1360 Box::new(TestDeviceWorker {
1361 index,
1362 queue_work: self.queue_work.clone(),
1363 })
1364 }
1365
1366 fn state_change(&mut self, _state: &VirtioState) {}
1367 }
1368
1369 struct TestDeviceWorker {
1370 index: u16,
1371 queue_work: Option<TestDeviceQueueWorkFn>,
1372 }
1373
1374 #[async_trait]
1375 impl VirtioQueueWorkerContext for TestDeviceWorker {
1376 async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
1377 if let Err(err) = work {
1378 panic!(
1379 "Invalid virtio queue state index {} error {}",
1380 self.index,
1381 err.as_ref() as &dyn std::error::Error
1382 );
1383 }
1384 if let Some(ref func) = self.queue_work {
1385 (func)(self.index, work.unwrap());
1386 }
1387 true
1388 }
1389 }
1390
1391 impl VirtioPciTestDevice {
1392 fn new(
1393 driver: &DefaultDriver,
1394 num_queues: u16,
1395 test_mem: &Arc<VirtioTestMemoryAccess>,
1396 queue_work: Option<TestDeviceQueueWorkFn>,
1397 ) -> Self {
1398 let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
1399 let mem = GuestMemory::new("test", test_mem.clone());
1400 let mut msi_set = MsiInterruptSet::new();
1401
1402 let dev = VirtioPciDevice::new(
1403 Box::new(LegacyWrapper::new(
1404 &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())),
1405 TestDevice::new(
1406 DeviceTraits {
1407 device_id: 3,
1408 device_features: 2,
1409 max_queues: num_queues,
1410 device_register_length: 12,
1411 ..Default::default()
1412 },
1413 queue_work,
1414 ),
1415 &mem,
1416 )),
1417 PciInterruptModel::Msix(&mut msi_set),
1418 Some(doorbell_registration),
1419 &mut ExternallyManagedMmioIntercepts,
1420 None,
1421 )
1422 .unwrap();
1423
1424 let test_intc = Arc::new(TestPciInterruptController::new());
1425 msi_set.connect(test_intc.as_ref());
1426
1427 Self {
1428 pci_device: dev,
1429 test_intc,
1430 }
1431 }
1432
1433 fn read_u32(&mut self, address: u64) -> u32 {
1434 let mut value = [0; 4];
1435 self.pci_device.mmio_read(address, &mut value).unwrap();
1436 u32::from_ne_bytes(value)
1437 }
1438
1439 fn write_u32(&mut self, address: u64, value: u32) {
1440 self.pci_device
1441 .mmio_write(address, &value.to_ne_bytes())
1442 .unwrap();
1443 }
1444 }
1445
/// Exercises the MMIO register file of a `VirtioMmioDevice` wrapping a one-queue
/// `TestDevice`. It covers identification registers, the device/driver feature banks,
/// and per-queue registers for a valid queue (0) and an out-of-range queue (1).
///
/// The order of operations matters: each check depends on the writes before it.
/// NOTE(review): register names in the comments are inferred from offsets and the
/// virtio-mmio layout in the virtio specification — confirm against the transport.
#[async_test]
async fn verify_chipset_config(driver: DefaultDriver) {
    let mem = VirtioTestMemoryAccess::new();
    let doorbell_registration: Arc<dyn DoorbellRegistration> = mem.clone();
    let mem = GuestMemory::new("test", mem);
    let interrupt = LineInterrupt::detached();

    let mut dev = VirtioMmioDevice::new(
        Box::new(LegacyWrapper::new(
            &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
            TestDevice::new(
                DeviceTraits {
                    device_id: 3,
                    device_features: 2,
                    max_queues: 1,
                    device_register_length: 0,
                    ..Default::default()
                },
                None,
            ),
            &mem,
        )),
        interrupt,
        Some(doorbell_registration),
        0,
        1,
    );
    // Magic value: the ASCII bytes "virt".
    assert_eq!(dev.read_u32(0), u32::from_le_bytes(*b"virt"));
    // Version.
    assert_eq!(dev.read_u32(4), 2);
    // Device id, matching `device_id: 3` above.
    assert_eq!(dev.read_u32(8), 3);
    // Vendor id.
    assert_eq!(dev.read_u32(12), 0x1af4);
    // Device features (16), bank 0: the device's bit 2 plus the indirect-descriptor
    // and event-index ring features.
    assert_eq!(
        dev.read_u32(16),
        VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
    );
    // Device feature bank select (20) starts at 0. Bank 1 holds VIRTIO_F_VERSION_1;
    // bank 2 is empty.
    assert_eq!(dev.read_u32(20), 0);
    dev.write_u32(20, 1);
    assert_eq!(dev.read_u32(20), 1);
    assert_eq!(dev.read_u32(16), VIRTIO_F_VERSION_1);
    dev.write_u32(20, 2);
    assert_eq!(dev.read_u32(16), 0);
    // Driver features (32), bank 0: values are stored but masked to what was offered.
    assert_eq!(dev.read_u32(32), 0);
    dev.write_u32(32, 2);
    assert_eq!(dev.read_u32(32), 2);
    dev.write_u32(32, 0xffffffff);
    assert_eq!(
        dev.read_u32(32),
        VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
    );
    // Driver feature bank select (36): bank 1 accepts only VIRTIO_F_VERSION_1.
    assert_eq!(dev.read_u32(36), 0);
    dev.write_u32(36, 1);
    assert_eq!(dev.read_u32(36), 1);
    assert_eq!(dev.read_u32(32), 0);
    dev.write_u32(32, 0xffffffff);
    assert_eq!(dev.read_u32(32), VIRTIO_F_VERSION_1);
    // Bank 2 accepts nothing.
    dev.write_u32(36, 2);
    assert_eq!(dev.read_u32(32), 0);
    dev.write_u32(32, 0xffffffff);
    assert_eq!(dev.read_u32(32), 0);
    // Queue notify (80), interrupt status (96), interrupt acknowledge (100),
    // status (112) and config generation (0xfc) all read as zero initially.
    assert_eq!(dev.read_u32(80), 0);
    assert_eq!(dev.read_u32(96), 0);
    assert_eq!(dev.read_u32(100), 0);
    assert_eq!(dev.read_u32(112), 0);
    assert_eq!(dev.read_u32(0xfc), 0);

    // Queue select (48) = 0, the only queue (`max_queues: 1`).
    assert_eq!(dev.read_u32(48), 0);
    // Maximum queue size (52).
    assert_eq!(dev.read_u32(52), 0x40);
    // Queue size (56) starts at the maximum and is writable.
    assert_eq!(dev.read_u32(56), 0x40);
    dev.write_u32(56, 0x20);
    assert_eq!(dev.read_u32(56), 0x20);
    // Queue ready (68): writing 0xffffffff reads back as 1.
    assert_eq!(dev.read_u32(68), 0);
    dev.write_u32(68, 1);
    assert_eq!(dev.read_u32(68), 1);
    dev.write_u32(68, 0xffffffff);
    assert_eq!(dev.read_u32(68), 1);
    dev.write_u32(68, 0);
    assert_eq!(dev.read_u32(68), 0);
    // Descriptor table address, low (128) / high (132).
    assert_eq!(dev.read_u32(128), 0);
    dev.write_u32(128, 0xffff);
    assert_eq!(dev.read_u32(128), 0xffff);
    assert_eq!(dev.read_u32(132), 0);
    dev.write_u32(132, 1);
    assert_eq!(dev.read_u32(132), 1);
    // Available (driver) ring address, low (144) / high (148).
    assert_eq!(dev.read_u32(144), 0);
    dev.write_u32(144, 0xeeee);
    assert_eq!(dev.read_u32(144), 0xeeee);
    assert_eq!(dev.read_u32(148), 0);
    dev.write_u32(148, 2);
    assert_eq!(dev.read_u32(148), 2);
    // Used (device) ring address, low (160) / high (164).
    assert_eq!(dev.read_u32(160), 0);
    dev.write_u32(160, 0xdddd);
    assert_eq!(dev.read_u32(160), 0xdddd);
    assert_eq!(dev.read_u32(164), 0);
    dev.write_u32(164, 3);
    assert_eq!(dev.read_u32(164), 3);

    // Queue select = 1 is beyond `max_queues`: every per-queue register reads 0 and
    // ignores writes.
    dev.write_u32(48, 1);
    assert_eq!(dev.read_u32(48), 1);
    assert_eq!(dev.read_u32(52), 0);
    assert_eq!(dev.read_u32(56), 0);
    dev.write_u32(56, 2);
    assert_eq!(dev.read_u32(56), 0);
    assert_eq!(dev.read_u32(68), 0);
    dev.write_u32(68, 1);
    assert_eq!(dev.read_u32(68), 0);
    assert_eq!(dev.read_u32(128), 0);
    dev.write_u32(128, 1);
    assert_eq!(dev.read_u32(128), 0);
    assert_eq!(dev.read_u32(132), 0);
    dev.write_u32(132, 1);
    assert_eq!(dev.read_u32(132), 0);
    assert_eq!(dev.read_u32(144), 0);
    dev.write_u32(144, 1);
    assert_eq!(dev.read_u32(144), 0);
    assert_eq!(dev.read_u32(148), 0);
    dev.write_u32(148, 1);
    assert_eq!(dev.read_u32(148), 0);
    assert_eq!(dev.read_u32(160), 0);
    dev.write_u32(160, 1);
    assert_eq!(dev.read_u32(160), 0);
    assert_eq!(dev.read_u32(164), 0);
    dev.write_u32(164, 1);
    assert_eq!(dev.read_u32(164), 0);
}
1607
1608 #[async_test]
1609 async fn verify_pci_config(driver: DefaultDriver) {
1610 let mut pci_test_device =
1611 VirtioPciTestDevice::new(&driver, 1, &VirtioTestMemoryAccess::new(), None);
1612 let mut capabilities = 0;
1613 pci_test_device
1614 .pci_device
1615 .pci_cfg_read(4, &mut capabilities)
1616 .unwrap();
1617 assert_eq!(
1618 capabilities,
1619 (cfg_space::Status::new()
1620 .with_capabilities_list(true)
1621 .into_bits() as u32)
1622 << 16
1623 );
1624 let mut next_cap_offset = 0;
1625 pci_test_device
1626 .pci_device
1627 .pci_cfg_read(0x34, &mut next_cap_offset)
1628 .unwrap();
1629 assert_ne!(next_cap_offset, 0);
1630
1631 let mut header = 0;
1632 pci_test_device
1633 .pci_device
1634 .pci_cfg_read(next_cap_offset as u16, &mut header)
1635 .unwrap();
1636 let header = header.to_le_bytes();
1637 assert_eq!(header[0], CapabilityId::MSIX.0);
1638 next_cap_offset = header[1] as u32;
1639 assert_ne!(next_cap_offset, 0);
1640
1641 let mut header = 0;
1642 pci_test_device
1643 .pci_device
1644 .pci_cfg_read(next_cap_offset as u16, &mut header)
1645 .unwrap();
1646 let header = header.to_le_bytes();
1647 assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1648 assert_eq!(header[3], VIRTIO_PCI_CAP_COMMON_CFG);
1649 assert_eq!(header[2], 16);
1650 let mut buf = 0;
1651
1652 pci_test_device
1653 .pci_device
1654 .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1655 .unwrap();
1656 assert_eq!(buf, 0);
1657 pci_test_device
1658 .pci_device
1659 .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1660 .unwrap();
1661 assert_eq!(buf, 0);
1662 pci_test_device
1663 .pci_device
1664 .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1665 .unwrap();
1666 assert_eq!(buf, 0x38);
1667 next_cap_offset = header[1] as u32;
1668 assert_ne!(next_cap_offset, 0);
1669
1670 let mut header = 0;
1671 pci_test_device
1672 .pci_device
1673 .pci_cfg_read(next_cap_offset as u16, &mut header)
1674 .unwrap();
1675 let header = header.to_le_bytes();
1676 assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1677 assert_eq!(header[3], VIRTIO_PCI_CAP_NOTIFY_CFG);
1678 assert_eq!(header[2], 20);
1679 pci_test_device
1680 .pci_device
1681 .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1682 .unwrap();
1683 assert_eq!(buf, 0);
1684 pci_test_device
1685 .pci_device
1686 .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1687 .unwrap();
1688 assert_eq!(buf, 0x38);
1689 pci_test_device
1690 .pci_device
1691 .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1692 .unwrap();
1693 assert_eq!(buf, 4);
1694 next_cap_offset = header[1] as u32;
1695 assert_ne!(next_cap_offset, 0);
1696
1697 let mut header = 0;
1698 pci_test_device
1699 .pci_device
1700 .pci_cfg_read(next_cap_offset as u16, &mut header)
1701 .unwrap();
1702 let header = header.to_le_bytes();
1703 assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1704 assert_eq!(header[3], VIRTIO_PCI_CAP_ISR_CFG);
1705 assert_eq!(header[2], 16);
1706 pci_test_device
1707 .pci_device
1708 .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1709 .unwrap();
1710 assert_eq!(buf, 0);
1711 pci_test_device
1712 .pci_device
1713 .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1714 .unwrap();
1715 assert_eq!(buf, 0x3c);
1716 pci_test_device
1717 .pci_device
1718 .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1719 .unwrap();
1720 assert_eq!(buf, 4);
1721 next_cap_offset = header[1] as u32;
1722 assert_ne!(next_cap_offset, 0);
1723
1724 let mut header = 0;
1725 pci_test_device
1726 .pci_device
1727 .pci_cfg_read(next_cap_offset as u16, &mut header)
1728 .unwrap();
1729 let header = header.to_le_bytes();
1730 assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1731 assert_eq!(header[3], VIRTIO_PCI_CAP_DEVICE_CFG);
1732 assert_eq!(header[2], 16);
1733 pci_test_device
1734 .pci_device
1735 .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1736 .unwrap();
1737 assert_eq!(buf, 0);
1738 pci_test_device
1739 .pci_device
1740 .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1741 .unwrap();
1742 assert_eq!(buf, 0x40);
1743 pci_test_device
1744 .pci_device
1745 .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1746 .unwrap();
1747 assert_eq!(buf, 12);
1748 next_cap_offset = header[1] as u32;
1749 assert_eq!(next_cap_offset, 0);
1750 }
1751
/// Maps the BARs of a one-queue `VirtioPciTestDevice` and exercises the virtio common
/// configuration structure at the start of the first BAR.
///
/// The first BAR is placed at 0x20_0000_0000 and the second at 0x4000, then MMIO
/// decoding is enabled. The test covers the feature banks, the queue count, and
/// per-queue registers for a valid queue (0) and an out-of-range queue (1).
///
/// The order of operations matters: each check depends on the writes before it.
/// NOTE(review): field names in the comments are inferred from offsets and the
/// `virtio_pci_common_cfg` layout in the virtio specification — confirm.
#[async_test]
async fn verify_pci_registers(driver: DefaultDriver) {
    let mut pci_test_device =
        VirtioPciTestDevice::new(&driver, 1, &VirtioTestMemoryAccess::new(), None);
    // BAR at config offset 0x10, programmed as 64 bits: high dword (0x14) first.
    let bar_address1: u64 = 0x2000000000;
    pci_test_device
        .pci_device
        .pci_cfg_write(0x14, (bar_address1 >> 32) as u32)
        .unwrap();
    pci_test_device
        .pci_device
        .pci_cfg_write(0x10, bar_address1 as u32)
        .unwrap();

    // BAR at config offset 0x18, likewise programmed as 64 bits.
    let bar_address2: u64 = 0x4000;
    pci_test_device
        .pci_device
        .pci_cfg_write(0x1c, (bar_address2 >> 32) as u32)
        .unwrap();
    pci_test_device
        .pci_device
        .pci_cfg_write(0x18, bar_address2 as u32)
        .unwrap();

    // Enable MMIO decoding in the command register so BAR accesses reach the device.
    pci_test_device
        .pci_device
        .pci_cfg_write(
            0x4,
            cfg_space::Command::new()
                .with_mmio_enabled(true)
                .into_bits() as u32,
        )
        .unwrap();

    // Device feature select (+0) starts at bank 0. Device features (+4), bank 0.
    assert_eq!(pci_test_device.read_u32(bar_address1), 0);
    assert_eq!(
        pci_test_device.read_u32(bar_address1 + 4),
        VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
    );
    // Bank 1 holds VIRTIO_F_VERSION_1; bank 2 is empty.
    pci_test_device.write_u32(bar_address1, 1);
    assert_eq!(pci_test_device.read_u32(bar_address1), 1);
    assert_eq!(
        pci_test_device.read_u32(bar_address1 + 4),
        VIRTIO_F_VERSION_1
    );
    pci_test_device.write_u32(bar_address1, 2);
    assert_eq!(pci_test_device.read_u32(bar_address1), 2);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 4), 0);
    // Driver feature select (+8) / driver features (+12): writes are masked to the
    // features offered in the selected bank.
    assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 0);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
    pci_test_device.write_u32(bar_address1 + 12, 2);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 2);
    pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
    assert_eq!(
        pci_test_device.read_u32(bar_address1 + 12),
        VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
    );
    pci_test_device.write_u32(bar_address1 + 8, 1);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 1);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
    pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
    assert_eq!(
        pci_test_device.read_u32(bar_address1 + 12),
        VIRTIO_F_VERSION_1
    );
    pci_test_device.write_u32(bar_address1 + 8, 2);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 2);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
    pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
    // +16: the low half is 0 and the high half reports the queue count (1).
    assert_eq!(pci_test_device.read_u32(bar_address1 + 16), 1 << 16);
    // +20: status, config generation and queue select all start at 0.
    assert_eq!(pci_test_device.read_u32(bar_address1 + 20), 0);
    // Queue size (+24) starts at the maximum, 0x40, and is writable.
    assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0x40);
    pci_test_device.write_u32(bar_address1 + 24, 0x20);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0x20);
    // Queue enable (+28): writing 0xffff reads back as 1.
    assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
    pci_test_device.write_u32(bar_address1 + 28, 1);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 1);
    pci_test_device.write_u32(bar_address1 + 28, 0xffff);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 1);
    pci_test_device.write_u32(bar_address1 + 28, 0);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
    // Queue descriptor table (+32/+36), driver ring (+40/+44) and device ring
    // (+48/+52) addresses, low/high.
    assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
    pci_test_device.write_u32(bar_address1 + 32, 0xffff);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0xffff);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
    pci_test_device.write_u32(bar_address1 + 36, 1);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 1);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
    pci_test_device.write_u32(bar_address1 + 40, 0xeeee);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0xeeee);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
    pci_test_device.write_u32(bar_address1 + 44, 2);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 2);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
    pci_test_device.write_u32(bar_address1 + 48, 0xdddd);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0xdddd);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
    pci_test_device.write_u32(bar_address1 + 52, 3);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 3);
    // +56 and +60 lie past the 0x38-byte common configuration, in the notify and ISR
    // windows advertised by the capability list; both read as 0.
    assert_eq!(pci_test_device.read_u32(bar_address1 + 56), 0);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 60), 0);

    // Select queue 1, which is beyond the single configured queue, with a 16-bit
    // write at +22.
    let queue_index: u16 = 1;
    pci_test_device
        .pci_device
        .mmio_write(bar_address1 + 22, &queue_index.to_le_bytes())
        .unwrap();
    // NOTE(review): a 16-bit queue_select at byte 22 would appear at bit 16 of the
    // dword at +20, but this expects bit 24. Confirm whether the transport's packing
    // of this dword is intended.
    assert_eq!(pci_test_device.read_u32(bar_address1 + 20), 1 << 24);
    // For the non-existent queue, every per-queue register reads 0 and ignores writes.
    assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0);
    pci_test_device.write_u32(bar_address1 + 24, 2);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
    pci_test_device.write_u32(bar_address1 + 28, 1);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
    pci_test_device.write_u32(bar_address1 + 32, 0x10);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
    pci_test_device.write_u32(bar_address1 + 36, 0x10);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
    pci_test_device.write_u32(bar_address1 + 40, 0x10);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
    pci_test_device.write_u32(bar_address1 + 44, 0x10);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
    pci_test_device.write_u32(bar_address1 + 48, 0x10);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
    pci_test_device.write_u32(bar_address1 + 52, 0x10);
    assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
}
1915
1916 #[async_test]
1917 async fn verify_queue_simple(driver: DefaultDriver) {
1918 let test_mem = VirtioTestMemoryAccess::new();
1919 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
1920 let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
1921 let (tx, mut rx) = mesh::mpsc_channel();
1922 let event = Event::new();
1923 let mut queues = guest.create_direct_queues(|i| {
1924 let tx = tx.clone();
1925 CreateDirectQueueParams {
1926 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
1927 let mut work = work.expect("Queue failure");
1928 assert_eq!(work.payload.len(), 1);
1929 assert_eq!(work.payload[0].address, base_addr);
1930 assert_eq!(work.payload[0].length, 0x1000);
1931 work.complete(123);
1932 true
1933 }),
1934 notify: Interrupt::from_fn(move || {
1935 tx.send(i as usize);
1936 }),
1937 event: event.clone(),
1938 }
1939 });
1940
1941 guest.add_to_avail_queue(0);
1942 event.signal();
1943 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
1944 let (desc, len) = guest.get_next_completed(0).unwrap();
1945 assert_eq!(desc, 0u16);
1946 assert_eq!(len, 123);
1947 assert_eq!(guest.get_next_completed(0).is_none(), true);
1948 queues[0].stop().await;
1949 }
1950
1951 #[async_test]
1952 async fn verify_queue_indirect(driver: DefaultDriver) {
1953 let test_mem = VirtioTestMemoryAccess::new();
1954 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
1955 let (tx, mut rx) = mesh::mpsc_channel();
1956 let event = Event::new();
1957 let mut queues = guest.create_direct_queues(|i| {
1958 let tx = tx.clone();
1959 CreateDirectQueueParams {
1960 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
1961 let mut work = work.expect("Queue failure");
1962 assert_eq!(work.payload.len(), 1);
1963 assert_eq!(work.payload[0].address, 0xffffffff00000000u64);
1964 assert_eq!(work.payload[0].length, 0x1000);
1965 work.complete(123);
1966 true
1967 }),
1968 notify: Interrupt::from_fn(move || {
1969 tx.send(i as usize);
1970 }),
1971 event: event.clone(),
1972 }
1973 });
1974
1975 guest.add_indirect_to_avail_queue(0);
1976 event.signal();
1977 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
1978 let (desc, len) = guest.get_next_completed(0).unwrap();
1979 assert_eq!(desc, 0u16);
1980 assert_eq!(len, 123);
1981 assert_eq!(guest.get_next_completed(0).is_none(), true);
1982 queues[0].stop().await;
1983 }
1984
1985 #[async_test]
1986 async fn verify_queue_linked(driver: DefaultDriver) {
1987 let test_mem = VirtioTestMemoryAccess::new();
1988 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 5, true);
1989 let (tx, mut rx) = mesh::mpsc_channel();
1990 let base_address = guest.get_queue_descriptor_backing_memory_address(0);
1991 let event = Event::new();
1992 let mut queues = guest.create_direct_queues(|i| {
1993 let tx = tx.clone();
1994 CreateDirectQueueParams {
1995 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
1996 let mut work = work.expect("Queue failure");
1997 assert_eq!(work.payload.len(), 3);
1998 for i in 0..work.payload.len() {
1999 assert_eq!(work.payload[i].address, base_address + 0x1000 * i as u64);
2000 assert_eq!(work.payload[i].length, 0x1000);
2001 }
2002 work.complete(123 * 3);
2003 true
2004 }),
2005 notify: Interrupt::from_fn(move || {
2006 tx.send(i as usize);
2007 }),
2008 event: event.clone(),
2009 }
2010 });
2011
2012 guest.add_linked_to_avail_queue(0, 3);
2013 event.signal();
2014 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2015 let (desc, len) = guest.get_next_completed(0).unwrap();
2016 assert_eq!(desc, 0u16);
2017 assert_eq!(len, 123 * 3);
2018 assert_eq!(guest.get_next_completed(0).is_none(), true);
2019 queues[0].stop().await;
2020 }
2021
2022 #[async_test]
2023 async fn verify_queue_indirect_linked(driver: DefaultDriver) {
2024 let test_mem = VirtioTestMemoryAccess::new();
2025 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 5, true);
2026 let (tx, mut rx) = mesh::mpsc_channel();
2027 let event = Event::new();
2028 let mut queues = guest.create_direct_queues(|i| {
2029 let tx = tx.clone();
2030 CreateDirectQueueParams {
2031 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
2032 let mut work = work.expect("Queue failure");
2033 assert_eq!(work.payload.len(), 3);
2034 for i in 0..work.payload.len() {
2035 assert_eq!(
2036 work.payload[i].address,
2037 0xffffffff00000000u64 + 0x1000 * i as u64
2038 );
2039 assert_eq!(work.payload[i].length, 0x1000);
2040 }
2041 work.complete(123 * 3);
2042 true
2043 }),
2044 notify: Interrupt::from_fn(move || {
2045 tx.send(i as usize);
2046 }),
2047 event: event.clone(),
2048 }
2049 });
2050
2051 guest.add_indirect_linked_to_avail_queue(0, 3);
2052 event.signal();
2053 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2054 let (desc, len) = guest.get_next_completed(0).unwrap();
2055 assert_eq!(desc, 0u16);
2056 assert_eq!(len, 123 * 3);
2057 assert_eq!(guest.get_next_completed(0).is_none(), true);
2058 queues[0].stop().await;
2059 }
2060
2061 #[async_test]
2062 async fn verify_queue_avail_rollover(driver: DefaultDriver) {
2063 let test_mem = VirtioTestMemoryAccess::new();
2064 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
2065 let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
2066 let (tx, mut rx) = mesh::mpsc_channel();
2067 let event = Event::new();
2068 let mut queues = guest.create_direct_queues(|i| {
2069 let tx = tx.clone();
2070 CreateDirectQueueParams {
2071 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
2072 let mut work = work.expect("Queue failure");
2073 assert_eq!(work.payload.len(), 1);
2074 assert_eq!(work.payload[0].address, base_addr);
2075 assert_eq!(work.payload[0].length, 0x1000);
2076 work.complete(123);
2077 true
2078 }),
2079 notify: Interrupt::from_fn(move || {
2080 tx.send(i as usize);
2081 }),
2082 event: event.clone(),
2083 }
2084 });
2085
2086 for _ in 0..3 {
2087 guest.add_to_avail_queue(0);
2088 event.signal();
2089 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2090 let (desc, len) = guest.get_next_completed(0).unwrap();
2091 assert_eq!(desc, 0u16);
2092 assert_eq!(len, 123);
2093 assert_eq!(guest.get_next_completed(0).is_none(), true);
2094 }
2095
2096 queues[0].stop().await;
2097 }
2098
2099 #[async_test]
2100 async fn verify_multi_queue(driver: DefaultDriver) {
2101 let test_mem = VirtioTestMemoryAccess::new();
2102 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 5, 2, true);
2103 let (tx, mut rx) = mesh::mpsc_channel();
2104 let events = (0..guest.num_queues)
2105 .map(|_| Event::new())
2106 .collect::<Vec<_>>();
2107 let mut queues = guest.create_direct_queues(|queue_index| {
2108 let tx = tx.clone();
2109 let base_addr = guest.get_queue_descriptor_backing_memory_address(queue_index);
2110 CreateDirectQueueParams {
2111 process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
2112 let mut work = work.expect("Queue failure");
2113 assert_eq!(work.payload.len(), 1);
2114 assert_eq!(work.payload[0].address, base_addr);
2115 assert_eq!(work.payload[0].length, 0x1000);
2116 work.complete(123 * queue_index as u32);
2117 true
2118 }),
2119 notify: Interrupt::from_fn(move || {
2120 tx.send(queue_index as usize);
2121 }),
2122 event: events[queue_index as usize].clone(),
2123 }
2124 });
2125
2126 for (i, event) in events.iter().enumerate() {
2127 let queue_index = i as u16;
2128 guest.add_to_avail_queue(queue_index);
2129 event.signal();
2130 }
2131 for _ in 0..guest.num_queues {
2133 must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2134 }
2135 for queue_index in 0..guest.num_queues {
2137 let (desc, len) = guest.get_next_completed(queue_index).unwrap();
2138 assert_eq!(desc, 0u16);
2139 assert_eq!(len, 123 * queue_index as u32);
2140 }
2141 for (i, queue) in queues.iter_mut().enumerate() {
2143 let queue_index = i as u16;
2144 assert_eq!(guest.get_next_completed(queue_index).is_none(), true);
2145 queue.stop().await;
2146 }
2147 }
2148
2149 fn take_mmio_interrupt_status(dev: &mut VirtioMmioDevice, mask: u32) -> u32 {
2150 let mut v = [0; 4];
2151 dev.mmio_read(96, &mut v).unwrap();
2152 dev.mmio_write(100, &mask.to_ne_bytes()).unwrap();
2153 u32::from_ne_bytes(v)
2154 }
2155
2156 async fn expect_mmio_interrupt(
2157 dev: &mut VirtioMmioDevice,
2158 target: &TestLineInterruptTarget,
2159 mask: u32,
2160 multiple_expected: bool,
2161 ) {
2162 poll_fn(|cx| target.poll_high(cx, 0)).await;
2163 let v = take_mmio_interrupt_status(dev, mask);
2164 assert_eq!(v & mask, mask);
2165 assert!(multiple_expected || !target.is_high(0));
2166 }
2167
2168 #[async_test]
2169 async fn verify_device_queue_simple(driver: DefaultDriver) {
2170 let test_mem = VirtioTestMemoryAccess::new();
2171 let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
2172 let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
2173 let mem = guest.mem();
2174 let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
2175 let target = TestLineInterruptTarget::new_arc();
2176 let interrupt = LineInterrupt::new_with_target("test", target.clone(), 0);
2177 let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
2178 let queue_work = Arc::new(move |_: u16, mut work: VirtioQueueCallbackWork| {
2179 assert_eq!(work.payload.len(), 1);
2180 assert_eq!(work.payload[0].address, base_addr);
2181 assert_eq!(work.payload[0].length, 0x1000);
2182 work.complete(123);
2183 });
2184 let mut dev = VirtioMmioDevice::new(
2185 Box::new(LegacyWrapper::new(
2186 &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
2187 TestDevice::new(
2188 DeviceTraits {
2189 device_id: 3,
2190 device_features: features,
2191 max_queues: 1,
2192 device_register_length: 0,
2193 ..Default::default()
2194 },
2195 Some(queue_work),
2196 ),
2197 &mem,
2198 )),
2199 interrupt,
2200 Some(doorbell_registration),
2201 0,
2202 1,
2203 );
2204
2205 guest.setup_chipset_device(&mut dev, features);
2206 expect_mmio_interrupt(
2207 &mut dev,
2208 &target,
2209 VIRTIO_MMIO_INTERRUPT_STATUS_CONFIG_CHANGE,
2210 false,
2211 )
2212 .await;
2213 guest.add_to_avail_queue(0);
2214 dev.write_u32(80, 0);
2216 expect_mmio_interrupt(
2217 &mut dev,
2218 &target,
2219 VIRTIO_MMIO_INTERRUPT_STATUS_USED_BUFFER,
2220 false,
2221 )
2222 .await;
2223 let (desc, len) = guest.get_next_completed(0).unwrap();
2224 assert_eq!(desc, 0u16);
2225 assert_eq!(len, 123);
2226 assert_eq!(guest.get_next_completed(0).is_none(), true);
2227 dev.write_u32(112, 0);
2229 drop(dev);
2230 }
2231
2232 #[async_test]
2233 async fn verify_device_multi_queue(driver: DefaultDriver) {
2234 let num_queues = 5;
2235 let test_mem = VirtioTestMemoryAccess::new();
2236 let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
2237 let mut guest = VirtioTestGuest::new(&driver, &test_mem, num_queues, 2, true);
2238 let mem = guest.mem();
2239 let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
2240 let target = TestLineInterruptTarget::new_arc();
2241 let interrupt = LineInterrupt::new_with_target("test", target.clone(), 0);
2242 let base_addr: Vec<_> = (0..num_queues)
2243 .map(|i| guest.get_queue_descriptor_backing_memory_address(i))
2244 .collect();
2245 let queue_work = Arc::new(move |i: u16, mut work: VirtioQueueCallbackWork| {
2246 assert_eq!(work.payload.len(), 1);
2247 assert_eq!(work.payload[0].address, base_addr[i as usize]);
2248 assert_eq!(work.payload[0].length, 0x1000);
2249 work.complete(123 * i as u32);
2250 });
2251 let mut dev = VirtioMmioDevice::new(
2252 Box::new(LegacyWrapper::new(
2253 &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
2254 TestDevice::new(
2255 DeviceTraits {
2256 device_id: 3,
2257 device_features: features,
2258 max_queues: num_queues + 1,
2259 device_register_length: 0,
2260 ..Default::default()
2261 },
2262 Some(queue_work),
2263 ),
2264 &mem,
2265 )),
2266 interrupt,
2267 Some(doorbell_registration),
2268 0,
2269 1,
2270 );
2271 guest.setup_chipset_device(&mut dev, features);
2272 expect_mmio_interrupt(
2273 &mut dev,
2274 &target,
2275 VIRTIO_MMIO_INTERRUPT_STATUS_CONFIG_CHANGE,
2276 false,
2277 )
2278 .await;
2279 for i in 0..num_queues {
2280 guest.add_to_avail_queue(i);
2281 dev.write_u32(80, i as u32);
2283 }
2284 for i in 0..num_queues {
2286 let (desc, len) = loop {
2287 if let Some(x) = guest.get_next_completed(i) {
2288 break x;
2289 }
2290 expect_mmio_interrupt(
2291 &mut dev,
2292 &target,
2293 VIRTIO_MMIO_INTERRUPT_STATUS_USED_BUFFER,
2294 i < (num_queues - 1),
2295 )
2296 .await;
2297 };
2298 assert_eq!(desc, 0u16);
2299 assert_eq!(len, 123 * i as u32);
2300 }
2301 for i in 0..num_queues {
2303 assert_eq!(guest.get_next_completed(i).is_none(), true);
2304 }
2305 dev.write_u32(112, 0);
2307 drop(dev);
2308 }
2309
2310 #[async_test]
2311 async fn verify_device_multi_queue_pci(driver: DefaultDriver) {
2312 let num_queues = 5;
2313 let test_mem = VirtioTestMemoryAccess::new();
2314 let mut guest = VirtioTestGuest::new(&driver, &test_mem, num_queues, 2, true);
2315 let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
2316 let base_addr: Vec<_> = (0..num_queues)
2317 .map(|i| guest.get_queue_descriptor_backing_memory_address(i))
2318 .collect();
2319 let mut dev = VirtioPciTestDevice::new(
2320 &driver,
2321 num_queues + 1,
2322 &test_mem,
2323 Some(Arc::new(move |i, mut work| {
2324 assert_eq!(work.payload.len(), 1);
2325 assert_eq!(work.payload[0].address, base_addr[i as usize]);
2326 assert_eq!(work.payload[0].length, 0x1000);
2327 work.complete(123 * i as u32);
2328 })),
2329 );
2330
2331 guest.setup_pci_device(&mut dev, features);
2332
2333 let mut timer = PolledTimer::new(&driver);
2334
2335 timer.sleep(Duration::from_millis(100)).await;
2337 let delivered = dev.test_intc.get_next_interrupt().unwrap();
2338 assert_eq!(delivered.0, 0);
2339 assert!(dev.test_intc.get_next_interrupt().is_none());
2340
2341 for i in 0..num_queues {
2342 guest.add_to_avail_queue(i);
2343 dev.write_u32(0x10000000000 + 0x38, i as u32);
2345 }
2346 timer.sleep(Duration::from_millis(100)).await;
2348 for _ in 0..num_queues {
2349 let delivered = dev.test_intc.get_next_interrupt();
2350 assert!(delivered.is_some());
2351 }
2352 for i in 0..num_queues {
2354 let (desc, len) = guest.get_next_completed(i).unwrap();
2355 assert_eq!(desc, 0u16);
2356 assert_eq!(len, 123 * i as u32);
2357 }
2358 for i in 0..num_queues {
2360 assert_eq!(guest.get_next_completed(i).is_none(), true);
2361 }
2362 let device_status: u8 = 0;
2364 dev.pci_device
2365 .mmio_write(0x10000000000 + 20, &device_status.to_le_bytes())
2366 .unwrap();
2367 drop(dev);
2368 }
2369}