virtio/common.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use crate::queue::QueueCore;
use crate::queue::QueueError;
use crate::queue::QueueParams;
use crate::queue::VirtioQueuePayload;
use async_trait::async_trait;
use futures::FutureExt;
use futures::Stream;
use futures::StreamExt;
use guestmem::DoorbellRegistration;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use guestmem::MappedMemoryRegion;
use pal_async::DefaultPool;
use pal_async::driver::Driver;
use pal_async::task::Spawn;
use pal_async::wait::PolledWait;
use pal_event::Event;
use parking_lot::Mutex;
use std::io::Error;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::task::ready;
use task_control::AsyncRun;
use task_control::StopTask;
use task_control::TaskControl;
use thiserror::Error;
use vmcore::interrupt::Interrupt;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;

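/// Per-queue work handler driven by a [`VirtioQueueWorker`].
///
/// `process_work` receives the next descriptor chain (or the error that ended
/// the queue) and returns whether the worker should keep running.
///
/// Illustrative sketch of an implementation, assuming a hypothetical `Echo`
/// handler type:
///
/// ```ignore
/// #[async_trait]
/// impl VirtioQueueWorkerContext for Echo {
///     async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
///         match work {
///             Ok(mut work) => {
///                 // Complete the chain without writing anything back.
///                 work.complete(0);
///                 true // keep processing descriptors
///             }
///             Err(_) => false, // tear down the worker
///         }
///     }
/// }
/// ```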
#[async_trait]
pub trait VirtioQueueWorkerContext {
    async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool;
}

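/// Publishes completed descriptors to a queue's used ring, notifying the
/// guest when required, and tracks how many descriptors remain outstanding.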
#[derive(Debug)]
pub struct VirtioQueueUsedHandler {
    core: QueueCore,
    last_used_index: u16,
    outstanding_desc_count: Arc<Mutex<(u16, event_listener::Event)>>,
    notify_guest: Interrupt,
}

impl VirtioQueueUsedHandler {
    fn new(core: QueueCore, notify_guest: Interrupt) -> Self {
        Self {
            core,
            last_used_index: 0,
            outstanding_desc_count: Arc::new(Mutex::new((0, event_listener::Event::new()))),
            notify_guest,
        }
    }

    pub fn add_outstanding_descriptor(&self) {
        let (count, _) = &mut *self.outstanding_desc_count.lock();
        *count += 1;
    }

    pub fn await_outstanding_descriptors(&self) -> event_listener::EventListener {
        let (count, event) = &*self.outstanding_desc_count.lock();
        let listener = event.listen();
        if *count == 0 {
            event.notify(usize::MAX);
        }
        listener
    }

    pub fn complete_descriptor(&mut self, descriptor_index: u16, bytes_written: u32) {
        match self.core.complete_descriptor(
            &mut self.last_used_index,
            descriptor_index,
            bytes_written,
        ) {
            Ok(true) => {
                self.notify_guest.deliver();
            }
            Ok(false) => {}
            Err(err) => {
                tracelimit::error_ratelimited!(
                    error = &err as &dyn std::error::Error,
                    "failed to complete descriptor"
                );
            }
        }
        {
            let (count, event) = &mut *self.outstanding_desc_count.lock();
            *count -= 1;
            if *count == 0 {
                event.notify(usize::MAX);
            }
        }
    }
}

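/// A descriptor chain pulled from a queue's available ring.
///
/// The work must be completed (explicitly via [`Self::complete`], or
/// implicitly with zero bytes written on drop) to return the descriptors to
/// the guest through the used ring.
///
/// A minimal sketch of a response path, assuming `mem` is the relevant
/// `GuestMemory` and `response` is a device-specific byte buffer:
///
/// ```ignore
/// work.write(&mem, &response)?;
/// work.complete(response.len() as u32);
/// ```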
pub struct VirtioQueueCallbackWork {
    pub payload: Vec<VirtioQueuePayload>,
    used_queue_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    descriptor_index: u16,
    completed: bool,
}

impl VirtioQueueCallbackWork {
    pub fn new(
        payload: Vec<VirtioQueuePayload>,
        used_queue_handler: &Arc<Mutex<VirtioQueueUsedHandler>>,
        descriptor_index: u16,
    ) -> Self {
        let used_queue_handler = used_queue_handler.clone();
        used_queue_handler.lock().add_outstanding_descriptor();
        Self {
            payload,
            used_queue_handler,
            descriptor_index,
            completed: false,
        }
    }

    pub fn complete(&mut self, bytes_written: u32) {
        assert!(!self.completed);
        self.used_queue_handler
            .lock()
            .complete_descriptor(self.descriptor_index, bytes_written);
        self.completed = true;
    }

    pub fn descriptor_index(&self) -> u16 {
        self.descriptor_index
    }

    /// Returns the total size in bytes of all readable or all writeable
    /// payload buffers, depending on `writeable`.
    pub fn get_payload_length(&self, writeable: bool) -> u64 {
        self.payload
            .iter()
            .filter(|x| x.writeable == writeable)
            .fold(0, |acc, x| acc + x.length as u64)
    }

    /// Reads the readable payload buffers into `target`, returning the number
    /// of bytes read.
    pub fn read(&self, mem: &GuestMemory, target: &mut [u8]) -> Result<usize, GuestMemoryError> {
        let mut remaining = target;
        let mut read_bytes: usize = 0;
        for payload in &self.payload {
            if payload.writeable {
                continue;
            }

            let size = std::cmp::min(payload.length as usize, remaining.len());
            let (current, next) = remaining.split_at_mut(size);
            mem.read_at(payload.address, current)?;
            read_bytes += size;
            if next.is_empty() {
                break;
            }

            remaining = next;
        }

        Ok(read_bytes)
    }

    /// Writes `source` to the writeable payload buffers, starting at `offset`
    /// bytes into the writeable payload.
    pub fn write_at_offset(
        &self,
        offset: u64,
        mem: &GuestMemory,
        source: &[u8],
    ) -> Result<(), VirtioWriteError> {
        let mut skip_bytes = offset;
        let mut remaining = source;
        for payload in &self.payload {
            if !payload.writeable {
                continue;
            }

            let payload_length = payload.length as u64;
            if skip_bytes >= payload_length {
                skip_bytes -= payload_length;
                continue;
            }

            let size = std::cmp::min(
                payload_length as usize - skip_bytes as usize,
                remaining.len(),
            );
            let (current, next) = remaining.split_at(size);
            mem.write_at(payload.address + skip_bytes, current)?;
            remaining = next;
            if remaining.is_empty() {
                break;
            }
            skip_bytes = 0;
        }

        if !remaining.is_empty() {
            return Err(VirtioWriteError::NotAllWritten(source.len()));
        }

        Ok(())
    }

    pub fn write(&self, mem: &GuestMemory, source: &[u8]) -> Result<(), VirtioWriteError> {
        self.write_at_offset(0, mem, source)
    }
}

#[derive(Debug, Error)]
pub enum VirtioWriteError {
    #[error(transparent)]
    Memory(#[from] GuestMemoryError),
    #[error("{0:#x} bytes not written")]
    NotAllWritten(usize),
}

impl Drop for VirtioQueueCallbackWork {
    fn drop(&mut self) {
        if !self.completed {
            self.complete(0);
        }
    }
}

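/// A single virtio queue, exposed as a [`Stream`] of
/// [`VirtioQueueCallbackWork`] items. The stream yields an item each time the
/// guest publishes a descriptor chain on the available ring.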
#[derive(Debug)]
pub struct VirtioQueue {
    core: QueueCore,
    last_avail_index: u16,
    used_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    queue_event: PolledWait<Event>,
}

impl VirtioQueue {
    pub fn new(
        features: u64,
        params: QueueParams,
        mem: GuestMemory,
        notify: Interrupt,
        queue_event: PolledWait<Event>,
    ) -> Result<Self, QueueError> {
        let core = QueueCore::new(features, mem, params)?;
        let used_handler = Arc::new(Mutex::new(VirtioQueueUsedHandler::new(
            core.clone(),
            notify,
        )));
        Ok(Self {
            core,
            last_avail_index: 0,
            used_handler,
            queue_event,
        })
    }

    async fn wait_for_outstanding_descriptors(&self) {
        let wait_for_descriptors = self.used_handler.lock().await_outstanding_descriptors();
        wait_for_descriptors.await;
    }

    fn poll_next_buffer(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<VirtioQueueCallbackWork>, QueueError>> {
        let descriptor_index = loop {
            if let Some(descriptor_index) = self.core.descriptor_index(self.last_avail_index)? {
                break descriptor_index;
            };
            ready!(self.queue_event.wait().poll_unpin(cx)).expect("waits on Event cannot fail");
        };
        let payload = self
            .core
            .reader(descriptor_index)
            .collect::<Result<Vec<_>, _>>()?;

        self.last_avail_index = self.last_avail_index.wrapping_add(1);
        Poll::Ready(Ok(Some(VirtioQueueCallbackWork::new(
            payload,
            &self.used_handler,
            descriptor_index,
        ))))
    }
}

impl Drop for VirtioQueue {
    fn drop(&mut self) {
        if Arc::get_mut(&mut self.used_handler).is_none() {
            tracing::error!("Virtio queue dropped with outstanding work pending")
        }
    }
}

impl Stream for VirtioQueue {
    type Item = Result<VirtioQueueCallbackWork, Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Some(r) = ready!(self.get_mut().poll_next_buffer(cx)).transpose() else {
            return Poll::Ready(None);
        };

        Poll::Ready(Some(
            r.map_err(|err| Error::new(std::io::ErrorKind::Other, err)),
        ))
    }
}

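/// State for a queue worker task. The queue itself is constructed lazily, on
/// the worker task's first run.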
enum VirtioQueueStateInner {
    Initializing {
        mem: GuestMemory,
        features: u64,
        params: QueueParams,
        event: Event,
        notify: Interrupt,
        exit_event: event_listener::EventListener,
    },
    InitializationInProgress,
    Running {
        queue: VirtioQueue,
        exit_event: event_listener::EventListener,
    },
}

pub struct VirtioQueueState {
    inner: VirtioQueueStateInner,
}

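/// Drives a single virtio queue, forwarding each descriptor chain to a
/// [`VirtioQueueWorkerContext`].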
pub struct VirtioQueueWorker {
    driver: Box<dyn Driver>,
    context: Box<dyn VirtioQueueWorkerContext + Send>,
}

impl VirtioQueueWorker {
    pub fn new(driver: impl Driver, context: Box<dyn VirtioQueueWorkerContext + Send>) -> Self {
        Self {
            driver: Box::new(driver),
            context,
        }
    }

    pub fn into_running_task(
        self,
        name: impl Into<String>,
        mem: GuestMemory,
        features: u64,
        queue_resources: QueueResources,
        exit_event: event_listener::EventListener,
    ) -> TaskControl<VirtioQueueWorker, VirtioQueueState> {
        let name = name.into();
        let (_, driver) = DefaultPool::spawn_on_thread(&name);

        let mut task = TaskControl::new(self);
        task.insert(
            driver,
            name,
            VirtioQueueState {
                inner: VirtioQueueStateInner::Initializing {
                    mem,
                    features,
                    params: queue_resources.params,
                    event: queue_resources.event,
                    notify: queue_resources.notify,
                    exit_event,
                },
            },
        );
        task.start();
        task
    }

    async fn run_queue(&mut self, state: &mut VirtioQueueState) -> bool {
        match &mut state.inner {
            VirtioQueueStateInner::InitializationInProgress => unreachable!(),
            VirtioQueueStateInner::Initializing { .. } => {
                let VirtioQueueStateInner::Initializing {
                    mem,
                    features,
                    params,
                    event,
                    notify,
                    exit_event,
                } = std::mem::replace(
                    &mut state.inner,
                    VirtioQueueStateInner::InitializationInProgress,
                )
                else {
                    unreachable!()
                };
                let queue_event = PolledWait::new(&self.driver, event).unwrap();
                match VirtioQueue::new(features, params, mem, notify, queue_event) {
                    Ok(queue) => {
                        state.inner = VirtioQueueStateInner::Running { queue, exit_event };
                        true
                    }
                    Err(err) => {
                        tracing::error!(
                            err = &err as &dyn std::error::Error,
                            "Failed to start queue"
                        );
                        false
                    }
                }
            }
            VirtioQueueStateInner::Running { queue, exit_event } => {
                let mut exit = exit_event.fuse();
                let mut queue_ready = queue.next().fuse();
                let work = futures::select_biased! {
                    _ = exit => return false,
                    work = queue_ready => work.expect("queue will never complete").map_err(anyhow::Error::from),
                };
                self.context.process_work(work).await
            }
        }
    }
}

impl AsyncRun<VirtioQueueState> for VirtioQueueWorker {
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut VirtioQueueState,
    ) -> Result<(), task_control::Cancelled> {
        while stop.until_stopped(self.run_queue(state)).await? {}
        Ok(())
    }
}

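/// The negotiated features and per-queue enable state reported to a
/// [`LegacyVirtioDevice`] when the device starts running.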
pub struct VirtioRunningState {
    pub features: u64,
    pub enabled_queues: Vec<bool>,
}

pub enum VirtioState {
    Unknown,
    Running(VirtioRunningState),
    Stopped,
}

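/// Tracks registered doorbells; dropping or clearing releases the
/// registration handles.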
pub(crate) struct VirtioDoorbells {
    registration: Option<Arc<dyn DoorbellRegistration>>,
    doorbells: Vec<Box<dyn Send + Sync>>,
}

impl VirtioDoorbells {
    pub fn new(registration: Option<Arc<dyn DoorbellRegistration>>) -> Self {
        Self {
            registration,
            doorbells: Vec::new(),
        }
    }

    pub fn add(&mut self, address: u64, value: Option<u64>, length: Option<u32>, event: &Event) {
        if let Some(registration) = &mut self.registration {
            let doorbell = registration.register_doorbell(address, value, length, event);
            if let Ok(doorbell) = doorbell {
                self.doorbells.push(doorbell);
            }
        }
    }

    pub fn clear(&mut self) {
        self.doorbells.clear();
    }
}

#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraitsSharedMemory {
    pub id: u8,
    pub size: u64,
}

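/// Static properties a virtio device reports to its transport.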
#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraits {
    pub device_id: u16,
    pub device_features: u64,
    pub max_queues: u16,
    pub device_register_length: u32,
    pub shared_memory: DeviceTraitsSharedMemory,
}

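/// A virtio device that supplies one work handler per queue. Use
/// [`LegacyWrapper`] to adapt it to the [`VirtioDevice`] interface.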
pub trait LegacyVirtioDevice: Send {
    fn traits(&self) -> DeviceTraits;
    fn read_registers_u32(&self, offset: u16) -> u32;
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send>;
    fn state_change(&mut self, state: &VirtioState);
}

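/// The transport-facing virtio device interface: device register access plus
/// enable/disable with the negotiated [`Resources`].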
pub trait VirtioDevice: Send {
    fn traits(&self) -> DeviceTraits;
    fn read_registers_u32(&self, offset: u16) -> u32;
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    fn enable(&mut self, resources: Resources);
    fn disable(&mut self);
}

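/// The per-queue resources handed to a device when it is enabled.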
pub struct QueueResources {
    pub params: QueueParams,
    pub notify: Interrupt,
    pub event: Event,
}

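/// The negotiated features, queue resources, and optional shared memory
/// region handed to a device when it is enabled.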
pub struct Resources {
    pub features: u64,
    pub queues: Vec<QueueResources>,
    pub shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
    pub shared_memory_size: u64,
}

/// Wraps an object implementing [`LegacyVirtioDevice`] and implements [`VirtioDevice`].
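///
/// A minimal usage sketch, assuming a hypothetical `MyDevice` implementing
/// [`LegacyVirtioDevice`]:
///
/// ```ignore
/// let device = LegacyWrapper::new(&driver_source, MyDevice::new(), &mem);
/// ```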
pub struct LegacyWrapper<T: LegacyVirtioDevice> {
    device: T,
    driver: VmTaskDriver,
    mem: GuestMemory,
    workers: Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>,
    exit_event: event_listener::Event,
}

impl<T: LegacyVirtioDevice> LegacyWrapper<T> {
    pub fn new(driver_source: &VmTaskDriverSource, device: T, mem: &GuestMemory) -> Self {
        Self {
            device,
            driver: driver_source.simple(),
            mem: mem.clone(),
            workers: Vec::new(),
            exit_event: event_listener::Event::new(),
        }
    }
}

impl<T: LegacyVirtioDevice> VirtioDevice for LegacyWrapper<T> {
    fn traits(&self) -> DeviceTraits {
        self.device.traits()
    }

    fn read_registers_u32(&self, offset: u16) -> u32 {
        self.device.read_registers_u32(offset)
    }

    fn write_registers_u32(&mut self, offset: u16, val: u32) {
        self.device.write_registers_u32(offset, val)
    }

    fn enable(&mut self, resources: Resources) {
        let running_state = VirtioRunningState {
            features: resources.features,
            enabled_queues: resources
                .queues
                .iter()
                .map(|QueueResources { params, .. }| params.enable)
                .collect(),
        };

        self.device
            .state_change(&VirtioState::Running(running_state));
        self.workers = resources
            .queues
            .into_iter()
            .enumerate()
            .filter_map(|(i, queue_resources)| {
                if !queue_resources.params.enable {
                    return None;
                }
                let worker = VirtioQueueWorker::new(
                    self.driver.clone(),
                    self.device.get_work_callback(i as u16),
                );
                Some(worker.into_running_task(
                    "virtio-queue".to_string(),
                    self.mem.clone(),
                    resources.features,
                    queue_resources,
                    self.exit_event.listen(),
                ))
            })
            .collect();
    }

    fn disable(&mut self) {
        if self.workers.is_empty() {
            return;
        }
        self.exit_event.notify(usize::MAX);
        self.device.state_change(&VirtioState::Stopped);
        let mut workers = self.workers.drain(..).collect::<Vec<_>>();
        self.driver
            .spawn("shutdown-legacy-virtio-queues".to_owned(), async move {
                futures::future::join_all(workers.iter_mut().map(async |worker| {
                    worker.stop().await;
                    if let Some(VirtioQueueStateInner::Running { queue, .. }) =
                        worker.state_mut().map(|s| &s.inner)
                    {
                        queue.wait_for_outstanding_descriptors().await;
                    }
                }))
                .await;
            })
            .detach();
    }
}

impl<T: LegacyVirtioDevice> Drop for LegacyWrapper<T> {
    fn drop(&mut self) {
        self.disable();
    }
}

// UNSAFETY: test code implements a custom `GuestMemory` backing, which requires
// unsafe.
#[expect(unsafe_code)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PciInterruptModel;
    use crate::spec::pci::*;
    use crate::spec::queue::*;
    use crate::spec::*;
    use crate::transport::VirtioMmioDevice;
    use crate::transport::VirtioPciDevice;
    use chipset_device::mmio::ExternallyManagedMmioIntercepts;
    use chipset_device::mmio::MmioIntercept;
    use chipset_device::pci::PciConfigSpace;
    use futures::StreamExt;
    use guestmem::GuestMemoryAccess;
    use guestmem::GuestMemoryBackingError;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::timer::PolledTimer;
    use pci_core::msi::MsiInterruptSet;
    use pci_core::spec::caps::CapabilityId;
    use pci_core::spec::cfg_space;
    use pci_core::test_helpers::TestPciInterruptController;
    use std::collections::BTreeMap;
    use std::future::poll_fn;
    use std::io;
    use std::ptr::NonNull;
    use std::time::Duration;
    use test_with_tracing::test;
    use vmcore::line_interrupt::LineInterrupt;
    use vmcore::line_interrupt::test_helpers::TestLineInterruptTarget;
    use vmcore::vm_task::SingleDriverBackend;
    use vmcore::vm_task::VmTaskDriverSource;

    async fn must_recv_in_timeout<T: 'static + Send>(
        recv: &mut mesh::Receiver<T>,
        timeout: Duration,
    ) -> T {
        mesh::CancelContext::new()
            .with_timeout(timeout)
            .until_cancelled(recv.next())
            .await
            .unwrap()
            .unwrap()
    }

    #[derive(Default)]
    struct VirtioTestMemoryAccess {
        memory_map: Mutex<MemoryMap>,
    }

    #[derive(Default)]
    struct MemoryMap {
        map: BTreeMap<u64, (bool, Vec<u8>)>,
    }

    impl MemoryMap {
        fn get(&mut self, address: u64, len: usize) -> Option<(bool, &mut [u8])> {
            let (&base, &mut (writable, ref mut data)) = self.map.range_mut(..=address).last()?;
            let data = data
                .get_mut(usize::try_from(address - base).ok()?..)?
                .get_mut(..len)?;

            Some((writable, data))
        }

        fn insert(&mut self, address: u64, data: &[u8], writable: bool) {
            if let Some((is_writable, v)) = self.get(address, data.len()) {
                assert_eq!(writable, is_writable);
                v.copy_from_slice(data);
                return;
            }

            let end = address + data.len() as u64;
            let mut data = data.to_vec();
            // Coalesce with the following region when adjacent and equally
            // writable; overlaps are a test bug.
            if let Some((&next, &(next_writable, ref next_data))) = self.map.range(address..).next()
            {
                if end > next {
                    let next_end = next + next_data.len() as u64;
                    panic!(
                        "overlapping memory map: {address:#x}..{end:#x} > {next:#x}..={next_end:#x}"
                    );
                }
                if end == next && next_writable == writable {
                    data.extend(next_data.as_slice());
                    self.map.remove(&next).unwrap();
                }
            }

            // Likewise coalesce with the preceding region.
            if let Some((&prev, &mut (prev_writable, ref mut prev_data))) =
                self.map.range_mut(..address).last()
            {
                let prev_end = prev + prev_data.len() as u64;
                if prev_end > address {
                    panic!(
                        "overlapping memory map: {prev:#x}..{prev_end:#x} > {address:#x}..={end:#x}"
                    );
                }
                if prev_end == address && prev_writable == writable {
                    prev_data.extend_from_slice(&data);
                    return;
                }
            }

            self.map.insert(address, (writable, data));
        }
    }

    impl VirtioTestMemoryAccess {
        fn new() -> Arc<Self> {
            Default::default()
        }

        fn modify_memory_map(&self, address: u64, data: &[u8], writeable: bool) {
            self.memory_map.lock().insert(address, data, writeable);
        }

        fn memory_map_get_u16(&self, address: u64) -> u16 {
            let mut map = self.memory_map.lock();
            let (_, data) = map.get(address, 2).unwrap();
            u16::from_le_bytes(data.try_into().unwrap())
        }

        fn memory_map_get_u32(&self, address: u64) -> u32 {
            let mut map = self.memory_map.lock();
            let (_, data) = map.get(address, 4).unwrap();
            u32::from_le_bytes(data.try_into().unwrap())
        }
    }

    // SAFETY: test code
    unsafe impl GuestMemoryAccess for VirtioTestMemoryAccess {
        fn mapping(&self) -> Option<NonNull<u8>> {
            None
        }

        fn max_address(&self) -> u64 {
            // No real bound, so use the max physical address width on
            // AMD64/ARM64.
            1 << 52
        }

        unsafe fn read_fallback(
            &self,
            address: u64,
            dest: *mut u8,
            len: usize,
        ) -> Result<(), GuestMemoryBackingError> {
            match self.memory_map.lock().get(address, len) {
                Some((_, value)) => {
                    // SAFETY: guaranteed by caller
                    unsafe {
                        std::ptr::copy(value.as_ptr(), dest, len);
                    }
                }
                None => panic!("Unexpected read request at address {:x}", address),
            }
            Ok(())
        }

        unsafe fn write_fallback(
            &self,
            address: u64,
            src: *const u8,
            len: usize,
        ) -> Result<(), GuestMemoryBackingError> {
            match self.memory_map.lock().get(address, len) {
                Some((true, value)) => {
                    // SAFETY: guaranteed by caller
                    unsafe {
                        std::ptr::copy(src, value.as_mut_ptr(), len);
                    }
                }
                _ => panic!("Unexpected write request at address {:x}", address),
            }
            Ok(())
        }

        fn fill_fallback(
            &self,
            address: u64,
            val: u8,
            len: usize,
        ) -> Result<(), GuestMemoryBackingError> {
            match self.memory_map.lock().get(address, len) {
                Some((true, value)) => value.fill(val),
                _ => panic!("Unexpected write request at address {:x}", address),
            };
            Ok(())
        }
    }

    struct DoorbellEntry;
    impl Drop for DoorbellEntry {
        fn drop(&mut self) {}
    }

    impl DoorbellRegistration for VirtioTestMemoryAccess {
        fn register_doorbell(
            &self,
            _: u64,
            _: Option<u64>,
            _: Option<u32>,
            _: &Event,
        ) -> io::Result<Box<dyn Send + Sync>> {
            Ok(Box::new(DoorbellEntry))
        }
    }

    type VirtioTestWorkCallback =
        Box<dyn Fn(anyhow::Result<VirtioQueueCallbackWork>) -> bool + Sync + Send>;

    struct CreateDirectQueueParams {
        process_work: VirtioTestWorkCallback,
        notify: Interrupt,
        event: Event,
    }

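    /// Simulates the guest driver side of a set of queues: backs the ring
    /// memory, hands out descriptors, and publishes them on the available
    /// rings.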
    struct VirtioTestGuest {
        test_mem: Arc<VirtioTestMemoryAccess>,
        driver: DefaultDriver,
        num_queues: u16,
        queue_size: u16,
        use_ring_event_index: bool,
        last_avail_index: Vec<u16>,
        last_used_index: Vec<u16>,
        avail_descriptors: Vec<Vec<bool>>,
        exit_event: event_listener::Event,
    }

    impl VirtioTestGuest {
        fn new(
            driver: &DefaultDriver,
            test_mem: &Arc<VirtioTestMemoryAccess>,
            num_queues: u16,
            queue_size: u16,
            use_ring_event_index: bool,
        ) -> Self {
            let last_avail_index: Vec<u16> = vec![0; num_queues as usize];
            let last_used_index: Vec<u16> = vec![0; num_queues as usize];
            let avail_descriptors: Vec<Vec<bool>> =
                vec![vec![true; queue_size as usize]; num_queues as usize];
            let test_guest = Self {
                test_mem: test_mem.clone(),
                driver: driver.clone(),
                num_queues,
                queue_size,
                use_ring_event_index,
                last_avail_index,
                last_used_index,
                avail_descriptors,
                exit_event: event_listener::Event::new(),
            };
            for i in 0..num_queues {
                test_guest.add_queue_memory(i);
            }
            test_guest
        }

        fn mem(&self) -> GuestMemory {
            GuestMemory::new("test", self.test_mem.clone())
        }

        fn create_direct_queues<F>(
            &self,
            f: F,
        ) -> Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>
        where
            F: Fn(u16) -> CreateDirectQueueParams,
        {
            (0..self.num_queues)
                .map(|i| {
                    let params = f(i);
                    let worker = VirtioQueueWorker::new(
                        self.driver.clone(),
                        Box::new(VirtioTestWork {
                            callback: params.process_work,
                        }),
                    );
                    worker.into_running_task(
                        "virtio-test-queue".to_string(),
                        self.mem(),
                        self.queue_features(),
                        QueueResources {
                            params: self.queue_params(i),
                            notify: params.notify,
                            event: params.event,
                        },
                        self.exit_event.listen(),
                    )
                })
                .collect::<Vec<_>>()
        }

        fn queue_features(&self) -> u64 {
            if self.use_ring_event_index {
                VIRTIO_F_RING_EVENT_IDX as u64
            } else {
                0
            }
        }

        fn queue_params(&self, i: u16) -> QueueParams {
            QueueParams {
                size: self.queue_size,
                enable: true,
                desc_addr: self.get_queue_descriptor_base_address(i),
                avail_addr: self.get_queue_available_base_address(i),
                used_addr: self.get_queue_used_base_address(i),
            }
        }

        fn get_queue_base_address(&self, index: u16) -> u64 {
            0x10000000 * index as u64
        }

        fn get_queue_descriptor_base_address(&self, index: u16) -> u64 {
            self.get_queue_base_address(index) + 0x1000
        }

        fn get_queue_available_base_address(&self, index: u16) -> u64 {
            self.get_queue_base_address(index) + 0x2000
        }

        fn get_queue_used_base_address(&self, index: u16) -> u64 {
            self.get_queue_base_address(index) + 0x3000
        }

        fn get_queue_descriptor_backing_memory_address(&self, index: u16) -> u64 {
            self.get_queue_base_address(index) + 0x4000
        }

        fn setup_chipset_device(&self, dev: &mut VirtioMmioDevice, driver_features: u64) {
            // device status: ACKNOWLEDGE, then DRIVER
            dev.write_u32(112, VIRTIO_ACKNOWLEDGE);
            dev.write_u32(112, VIRTIO_DRIVER);
            // driver features, banks 0 and 1
            dev.write_u32(36, 0);
            dev.write_u32(32, driver_features as u32);
            dev.write_u32(36, 1);
            dev.write_u32(32, (driver_features >> 32) as u32);
            dev.write_u32(112, VIRTIO_FEATURES_OK);
            for i in 0..self.num_queues {
                let queue_index = i;
                // select the queue and set its size
                dev.write_u32(48, i as u32);
                dev.write_u32(56, self.queue_size as u32);
                // queue ring addresses (low/high)
                let desc_addr = self.get_queue_descriptor_base_address(queue_index);
                dev.write_u32(128, desc_addr as u32);
                dev.write_u32(132, (desc_addr >> 32) as u32);
                let avail_addr = self.get_queue_available_base_address(queue_index);
                dev.write_u32(144, avail_addr as u32);
                dev.write_u32(148, (avail_addr >> 32) as u32);
                let used_addr = self.get_queue_used_base_address(queue_index);
                dev.write_u32(160, used_addr as u32);
                dev.write_u32(164, (used_addr >> 32) as u32);
                // enable the queue
                dev.write_u32(68, 1);
            }
            dev.write_u32(112, VIRTIO_DRIVER_OK);
            // config generation should have advanced
            assert_eq!(dev.read_u32(0xfc), 2);
        }

        fn setup_pci_device(&self, dev: &mut VirtioPciTestDevice, driver_features: u64) {
            let bar_address1: u64 = 0x10000000000;
            dev.pci_device
                .pci_cfg_write(0x14, (bar_address1 >> 32) as u32)
                .unwrap();
            dev.pci_device
                .pci_cfg_write(0x10, bar_address1 as u32)
                .unwrap();

            let bar_address2: u64 = 0x20000000000;
            dev.pci_device
                .pci_cfg_write(0x1c, (bar_address2 >> 32) as u32)
                .unwrap();
            dev.pci_device
                .pci_cfg_write(0x18, bar_address2 as u32)
                .unwrap();

            dev.pci_device
                .pci_cfg_write(
                    0x4,
                    cfg_space::Command::new()
                        .with_mmio_enabled(true)
                        .into_bits() as u32,
                )
                .unwrap();

            let mut device_status = VIRTIO_ACKNOWLEDGE as u8;
            dev.pci_device
                .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
                .unwrap();
            device_status = VIRTIO_DRIVER as u8;
            dev.pci_device
                .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
                .unwrap();
            dev.write_u32(bar_address1 + 8, 0);
            dev.write_u32(bar_address1 + 12, driver_features as u32);
            dev.write_u32(bar_address1 + 8, 1);
            dev.write_u32(bar_address1 + 12, (driver_features >> 32) as u32);
            device_status = VIRTIO_FEATURES_OK as u8;
            dev.pci_device
                .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
                .unwrap();
            // setup config interrupt
            dev.pci_device
                .mmio_write(bar_address2, &0_u64.to_le_bytes())
                .unwrap(); // vector
            dev.pci_device
                .mmio_write(bar_address2 + 8, &0_u32.to_le_bytes())
                .unwrap(); // data
            dev.pci_device
                .mmio_write(bar_address2 + 12, &0_u32.to_le_bytes())
                .unwrap();
            for i in 0..self.num_queues {
                let queue_index = i;
                dev.pci_device
                    .mmio_write(bar_address1 + 22, &queue_index.to_le_bytes())
                    .unwrap();
                dev.pci_device
                    .mmio_write(bar_address1 + 24, &self.queue_size.to_le_bytes())
                    .unwrap();
                // setup MSI information for the queue
                let msix_vector = queue_index + 1;
                let address = bar_address2 + 0x10 * msix_vector as u64;
                dev.pci_device
                    .mmio_write(address, &(msix_vector as u64).to_le_bytes())
                    .unwrap();
                let address = bar_address2 + 0x10 * msix_vector as u64 + 8;
                dev.pci_device
                    .mmio_write(address, &0_u32.to_le_bytes())
                    .unwrap();
                let address = bar_address2 + 0x10 * msix_vector as u64 + 12;
                dev.pci_device
                    .mmio_write(address, &0_u32.to_le_bytes())
                    .unwrap();
                dev.pci_device
                    .mmio_write(bar_address1 + 26, &msix_vector.to_le_bytes())
                    .unwrap();
                // setup queue addresses
                let desc_addr = self.get_queue_descriptor_base_address(queue_index);
                dev.write_u32(bar_address1 + 32, desc_addr as u32);
                dev.write_u32(bar_address1 + 36, (desc_addr >> 32) as u32);
                let avail_addr = self.get_queue_available_base_address(queue_index);
                dev.write_u32(bar_address1 + 40, avail_addr as u32);
                dev.write_u32(bar_address1 + 44, (avail_addr >> 32) as u32);
                let used_addr = self.get_queue_used_base_address(queue_index);
                dev.write_u32(bar_address1 + 48, used_addr as u32);
                dev.write_u32(bar_address1 + 52, (used_addr >> 32) as u32);
                // enable the queue
                let enabled: u16 = 1;
                dev.pci_device
                    .mmio_write(bar_address1 + 28, &enabled.to_le_bytes())
                    .unwrap();
            }
            // enable all device MSI interrupts
            dev.pci_device.pci_cfg_write(0x40, 0x80000000).unwrap();
            // run device
            device_status = VIRTIO_DRIVER_OK as u8;
            dev.pci_device
                .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
                .unwrap();
            let mut config_generation: [u8; 1] = [0];
            dev.pci_device
                .mmio_read(bar_address1 + 21, &mut config_generation)
                .unwrap();
            assert_eq!(config_generation[0], 2);
        }

        fn get_queue_descriptor(&self, queue_index: u16, descriptor_index: u16) -> u64 {
            self.get_queue_descriptor_base_address(queue_index) + 0x10 * descriptor_index as u64
        }

        fn add_queue_memory(&self, queue_index: u16) {
            // descriptors
            for i in 0..self.queue_size {
                let base = self.get_queue_descriptor(queue_index, i);
                // physical address
                self.test_mem.modify_memory_map(
                    base,
                    &(self.get_queue_descriptor_backing_memory_address(queue_index)
                        + 0x1000 * i as u64)
                        .to_le_bytes(),
                    false,
                );
                // length
                self.test_mem
                    .modify_memory_map(base + 8, &0x1000u32.to_le_bytes(), false);
                // flags
                self.test_mem
                    .modify_memory_map(base + 12, &0u16.to_le_bytes(), false);
                // next index
                self.test_mem
                    .modify_memory_map(base + 14, &0u16.to_le_bytes(), false);
            }

            // available queue (flags, index)
            let base = self.get_queue_available_base_address(queue_index);
            self.test_mem
                .modify_memory_map(base, &0u16.to_le_bytes(), false);
            self.test_mem
                .modify_memory_map(base + 2, &0u16.to_le_bytes(), false);
            // available queue ring buffer
            for i in 0..self.queue_size {
                let base = base + 4 + 2 * i as u64;
                self.test_mem
                    .modify_memory_map(base, &0u16.to_le_bytes(), false);
            }
            // used event
            if self.use_ring_event_index {
                self.test_mem.modify_memory_map(
                    base + 4 + 2 * self.queue_size as u64,
                    &0u16.to_le_bytes(),
                    false,
                );
            }

            // used queue (flags, index)
            let base = self.get_queue_used_base_address(queue_index);
            self.test_mem
                .modify_memory_map(base, &0u16.to_le_bytes(), true);
            self.test_mem
                .modify_memory_map(base + 2, &0u16.to_le_bytes(), true);
            for i in 0..self.queue_size {
                let base = base + 4 + 8 * i as u64;
                // index
                self.test_mem
                    .modify_memory_map(base, &0u32.to_le_bytes(), true);
                // length
                self.test_mem
                    .modify_memory_map(base + 4, &0u32.to_le_bytes(), true);
            }
            // available event
            if self.use_ring_event_index {
                self.test_mem.modify_memory_map(
                    base + 4 + 8 * self.queue_size as u64,
                    &0u16.to_le_bytes(),
                    true,
                );
            }
        }

        fn reserve_descriptor(&mut self, queue_index: u16) -> u16 {
            let avail_descriptors = &mut self.avail_descriptors[queue_index as usize];
            for (i, desc) in avail_descriptors.iter_mut().enumerate() {
                if *desc {
                    *desc = false;
                    return i as u16;
                }
            }

            panic!("No descriptors are available!");
        }

        fn free_descriptor(&mut self, queue_index: u16, desc_index: u16) {
            assert!(desc_index < self.queue_size);
            let desc_addr = self.get_queue_descriptor(queue_index, desc_index);
            let flags: DescriptorFlags = self.test_mem.memory_map_get_u16(desc_addr + 12).into();
            if flags.next() {
                let next = self.test_mem.memory_map_get_u16(desc_addr + 14);
                self.free_descriptor(queue_index, next);
            }
            let avail_descriptors = &mut self.avail_descriptors[queue_index as usize];
            assert!(!avail_descriptors[desc_index as usize]);
            avail_descriptors[desc_index as usize] = true;
        }

        fn queue_available_desc(&mut self, queue_index: u16, desc_index: u16) {
            let avail_base_addr = self.get_queue_available_base_address(queue_index);
            let last_avail_index = &mut self.last_avail_index[queue_index as usize];
            let next_index = *last_avail_index % self.queue_size;
            *last_avail_index = last_avail_index.wrapping_add(1);
            self.test_mem.modify_memory_map(
                avail_base_addr + 4 + 2 * next_index as u64,
                &desc_index.to_le_bytes(),
                false,
            );
            self.test_mem.modify_memory_map(
                avail_base_addr + 2,
                &last_avail_index.to_le_bytes(),
                false,
            );
        }

        fn add_to_avail_queue(&mut self, queue_index: u16) {
            let next_descriptor = self.reserve_descriptor(queue_index);
            // flags
            self.test_mem.modify_memory_map(
                self.get_queue_descriptor(queue_index, next_descriptor) + 12,
                &0u16.to_le_bytes(),
                false,
            );
            self.queue_available_desc(queue_index, next_descriptor);
        }

        fn add_indirect_to_avail_queue(&mut self, queue_index: u16) {
            let next_descriptor = self.reserve_descriptor(queue_index);
            // flags
            self.test_mem.modify_memory_map(
                self.get_queue_descriptor(queue_index, next_descriptor) + 12,
                &u16::from(DescriptorFlags::new().with_indirect(true)).to_le_bytes(),
                false,
            );
            // create another (indirect) descriptor in the buffer
            let buffer_addr = self.get_queue_descriptor_backing_memory_address(queue_index);
            // physical address
            self.test_mem.modify_memory_map(
                buffer_addr,
                &0xffffffff00000000u64.to_le_bytes(),
                false,
            );
            // length
            self.test_mem
                .modify_memory_map(buffer_addr + 8, &0x1000u32.to_le_bytes(), false);
            // flags
            self.test_mem
                .modify_memory_map(buffer_addr + 12, &0u16.to_le_bytes(), false);
            // next index
            self.test_mem
                .modify_memory_map(buffer_addr + 14, &0u16.to_le_bytes(), false);
            self.queue_available_desc(queue_index, next_descriptor);
        }

        fn add_linked_to_avail_queue(&mut self, queue_index: u16, desc_count: u16) {
            let mut descriptors = Vec::with_capacity(desc_count as usize);
            for _ in 0..desc_count {
                descriptors.push(self.reserve_descriptor(queue_index));
            }

            for i in 0..descriptors.len() {
                let base = self.get_queue_descriptor(queue_index, descriptors[i]);
                let flags = if i < descriptors.len() - 1 {
                    u16::from(DescriptorFlags::new().with_next(true))
                } else {
                    0
                };
                self.test_mem
                    .modify_memory_map(base + 12, &flags.to_le_bytes(), false);
                let next = if i < descriptors.len() - 1 {
                    descriptors[i + 1]
                } else {
                    0
                };
                self.test_mem
                    .modify_memory_map(base + 14, &next.to_le_bytes(), false);
            }
            self.queue_available_desc(queue_index, descriptors[0]);
        }

        fn add_indirect_linked_to_avail_queue(&mut self, queue_index: u16, desc_count: u16) {
            let next_descriptor = self.reserve_descriptor(queue_index);
            // flags
            self.test_mem.modify_memory_map(
                self.get_queue_descriptor(queue_index, next_descriptor) + 12,
                &u16::from(DescriptorFlags::new().with_indirect(true)).to_le_bytes(),
                false,
            );
            // create indirect descriptors in the buffer
            let buffer_addr = self.get_queue_descriptor_backing_memory_address(queue_index);
            for i in 0..desc_count {
                let base = buffer_addr + 0x10 * i as u64;
                let indirect_buffer_addr = 0xffffffff00000000u64 + 0x1000 * i as u64;
                // physical address
                self.test_mem
                    .modify_memory_map(base, &indirect_buffer_addr.to_le_bytes(), false);
                // length
                self.test_mem
                    .modify_memory_map(base + 8, &0x1000u32.to_le_bytes(), false);
                // flags
                let flags = if i < desc_count - 1 {
                    u16::from(DescriptorFlags::new().with_next(true))
                } else {
                    0
                };
                self.test_mem
                    .modify_memory_map(base + 12, &flags.to_le_bytes(), false);
                // next index
                let next = if i < desc_count - 1 { i + 1 } else { 0 };
                self.test_mem
                    .modify_memory_map(base + 14, &next.to_le_bytes(), false);
            }
            self.queue_available_desc(queue_index, next_descriptor);
        }

        fn get_next_completed(&mut self, queue_index: u16) -> Option<(u16, u32)> {
            let avail_base_addr = self.get_queue_available_base_address(queue_index);
            let used_base_addr = self.get_queue_used_base_address(queue_index);
            let cur_used_index = self.test_mem.memory_map_get_u16(used_base_addr + 2);
            let last_used_index = &mut self.last_used_index[queue_index as usize];
            if *last_used_index == cur_used_index {
                return None;
            }

            if self.use_ring_event_index {
                self.test_mem.modify_memory_map(
                    avail_base_addr + 4 + 2 * self.queue_size as u64,
                    &cur_used_index.to_le_bytes(),
                    false,
                );
            }

            let next_index = *last_used_index % self.queue_size;
            *last_used_index = last_used_index.wrapping_add(1);
            let desc_index = self
                .test_mem
                .memory_map_get_u32(used_base_addr + 4 + 8 * next_index as u64);
            let desc_index = desc_index as u16;
            let bytes_written = self
                .test_mem
                .memory_map_get_u32(used_base_addr + 8 + 8 * next_index as u64);
            self.free_descriptor(queue_index, desc_index);
            Some((desc_index, bytes_written))
        }
    }

    struct VirtioTestWork {
        callback: VirtioTestWorkCallback,
    }

    #[async_trait]
    impl VirtioQueueWorkerContext for VirtioTestWork {
        async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
            (self.callback)(work)
        }
    }

    struct VirtioPciTestDevice {
        pci_device: VirtioPciDevice,
        test_intc: Arc<TestPciInterruptController>,
    }

    type TestDeviceQueueWorkFn = Arc<dyn Fn(u16, VirtioQueueCallbackWork) + Send + Sync>;

    struct TestDevice {
        traits: DeviceTraits,
        queue_work: Option<TestDeviceQueueWorkFn>,
    }

    impl TestDevice {
        fn new(traits: DeviceTraits, queue_work: Option<TestDeviceQueueWorkFn>) -> Self {
            Self { traits, queue_work }
        }
    }

    impl LegacyVirtioDevice for TestDevice {
        fn traits(&self) -> DeviceTraits {
            self.traits
        }

        fn read_registers_u32(&self, _offset: u16) -> u32 {
            0
        }

        fn write_registers_u32(&mut self, _offset: u16, _val: u32) {}

        fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send> {
            Box::new(TestDeviceWorker {
                index,
                queue_work: self.queue_work.clone(),
            })
        }

        fn state_change(&mut self, _state: &VirtioState) {}
    }

    struct TestDeviceWorker {
        index: u16,
        queue_work: Option<TestDeviceQueueWorkFn>,
    }

    #[async_trait]
    impl VirtioQueueWorkerContext for TestDeviceWorker {
        async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
            if let Err(err) = work {
                panic!(
                    "Invalid virtio queue state index {} error {}",
                    self.index,
                    err.as_ref() as &dyn std::error::Error
                );
            }
            if let Some(ref func) = self.queue_work {
                (func)(self.index, work.unwrap());
            }
            true
        }
    }

    impl VirtioPciTestDevice {
        fn new(
            driver: &DefaultDriver,
            num_queues: u16,
            test_mem: &Arc<VirtioTestMemoryAccess>,
            queue_work: Option<TestDeviceQueueWorkFn>,
        ) -> Self {
            let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
            let mem = GuestMemory::new("test", test_mem.clone());
            let mut msi_set = MsiInterruptSet::new();

            let dev = VirtioPciDevice::new(
                Box::new(LegacyWrapper::new(
                    &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())),
                    TestDevice::new(
                        DeviceTraits {
                            device_id: 3,
                            device_features: 2,
                            max_queues: num_queues,
                            device_register_length: 12,
                            ..Default::default()
                        },
                        queue_work,
                    ),
                    &mem,
                )),
                PciInterruptModel::Msix(&mut msi_set),
                Some(doorbell_registration),
                &mut ExternallyManagedMmioIntercepts,
                None,
            )
            .unwrap();

            let test_intc = Arc::new(TestPciInterruptController::new());
            msi_set.connect(test_intc.as_ref());

            Self {
                pci_device: dev,
                test_intc,
            }
        }

        fn read_u32(&mut self, address: u64) -> u32 {
            let mut value = [0; 4];
            self.pci_device.mmio_read(address, &mut value).unwrap();
            u32::from_ne_bytes(value)
        }

        fn write_u32(&mut self, address: u64, value: u32) {
            self.pci_device
                .mmio_write(address, &value.to_ne_bytes())
                .unwrap();
        }
    }

1448    #[async_test]
1449    async fn verify_chipset_config(driver: DefaultDriver) {
1450        let mem = VirtioTestMemoryAccess::new();
1451        let doorbell_registration: Arc<dyn DoorbellRegistration> = mem.clone();
1452        let mem = GuestMemory::new("test", mem);
1453        let interrupt = LineInterrupt::detached();
1454
1455        let mut dev = VirtioMmioDevice::new(
1456            Box::new(LegacyWrapper::new(
1457                &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
1458                TestDevice::new(
1459                    DeviceTraits {
1460                        device_id: 3,
1461                        device_features: 2,
1462                        max_queues: 1,
1463                        device_register_length: 0,
1464                        ..Default::default()
1465                    },
1466                    None,
1467                ),
1468                &mem,
1469            )),
1470            interrupt,
1471            Some(doorbell_registration),
1472            0,
1473            1,
1474        );
1475        // magic value
1476        assert_eq!(dev.read_u32(0), u32::from_le_bytes(*b"virt"));
1477        // version
1478        assert_eq!(dev.read_u32(4), 2);
1479        // device ID
1480        assert_eq!(dev.read_u32(8), 3);
1481        // vendor ID
1482        assert_eq!(dev.read_u32(12), 0x1af4);
1483        // device feature (bank 0)
1484        assert_eq!(
1485            dev.read_u32(16),
1486            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
1487        );
1488        // device feature bank index
1489        assert_eq!(dev.read_u32(20), 0);
1490        // device feature (bank 1)
1491        dev.write_u32(20, 1);
1492        assert_eq!(dev.read_u32(20), 1);
1493        assert_eq!(dev.read_u32(16), VIRTIO_F_VERSION_1);
1494        // device feature (bank 2)
1495        dev.write_u32(20, 2);
1496        assert_eq!(dev.read_u32(16), 0);
1497        // driver feature (bank 0)
1498        assert_eq!(dev.read_u32(32), 0);
1499        dev.write_u32(32, 2);
1500        assert_eq!(dev.read_u32(32), 2);
1501        dev.write_u32(32, 0xffffffff);
1502        assert_eq!(
1503            dev.read_u32(32),
1504            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
1505        );
1506        // driver feature bank index
1507        assert_eq!(dev.read_u32(36), 0);
1508        dev.write_u32(36, 1);
1509        assert_eq!(dev.read_u32(36), 1);
1510        // driver feature (bank 1)
1511        assert_eq!(dev.read_u32(32), 0);
1512        dev.write_u32(32, 0xffffffff);
1513        assert_eq!(dev.read_u32(32), VIRTIO_F_VERSION_1);
1514        // driver feature (bank 2)
1515        dev.write_u32(36, 2);
1516        assert_eq!(dev.read_u32(32), 0);
1517        dev.write_u32(32, 0xffffffff);
1518        assert_eq!(dev.read_u32(32), 0);
1519        // host notify
1520        assert_eq!(dev.read_u32(80), 0);
1521        // interrupt status
1522        assert_eq!(dev.read_u32(96), 0);
1524        // interrupt ACK (reads as zero)
1524        assert_eq!(dev.read_u32(100), 0);
1525        // device status
1526        assert_eq!(dev.read_u32(112), 0);
1527        // config generation
1528        assert_eq!(dev.read_u32(0xfc), 0);
1529
1530        // queue index
1531        assert_eq!(dev.read_u32(48), 0);
1532        // queue max size (queue 0)
1533        assert_eq!(dev.read_u32(52), 0x40);
1534        // queue size (queue 0)
1535        assert_eq!(dev.read_u32(56), 0x40);
1536        dev.write_u32(56, 0x20);
1537        assert_eq!(dev.read_u32(56), 0x20);
1538        // queue enable (queue 0)
1539        assert_eq!(dev.read_u32(68), 0);
1540        dev.write_u32(68, 1);
1541        assert_eq!(dev.read_u32(68), 1);
1542        dev.write_u32(68, 0xffffffff);
1543        assert_eq!(dev.read_u32(68), 1);
1544        dev.write_u32(68, 0);
1545        assert_eq!(dev.read_u32(68), 0);
1546        // queue descriptor address low (queue 0)
1547        assert_eq!(dev.read_u32(128), 0);
1548        dev.write_u32(128, 0xffff);
1549        assert_eq!(dev.read_u32(128), 0xffff);
1550        // queue descriptor address high (queue 0)
1551        assert_eq!(dev.read_u32(132), 0);
1552        dev.write_u32(132, 1);
1553        assert_eq!(dev.read_u32(132), 1);
1554        // queue available address low (queue 0)
1555        assert_eq!(dev.read_u32(144), 0);
1556        dev.write_u32(144, 0xeeee);
1557        assert_eq!(dev.read_u32(144), 0xeeee);
1558        // queue available address high (queue 0)
1559        assert_eq!(dev.read_u32(148), 0);
1560        dev.write_u32(148, 2);
1561        assert_eq!(dev.read_u32(148), 2);
1562        // queue used address low (queue 0)
1563        assert_eq!(dev.read_u32(160), 0);
1564        dev.write_u32(160, 0xdddd);
1565        assert_eq!(dev.read_u32(160), 0xdddd);
1566        // queue used address high (queue 0)
1567        assert_eq!(dev.read_u32(164), 0);
1568        dev.write_u32(164, 3);
1569        assert_eq!(dev.read_u32(164), 3);
1570
1571        // switch to queue #1
1572        dev.write_u32(48, 1);
1573        assert_eq!(dev.read_u32(48), 1);
1574        // queue max size (queue 1)
1575        assert_eq!(dev.read_u32(52), 0);
1576        // queue size (queue 1)
1577        assert_eq!(dev.read_u32(56), 0);
1578        dev.write_u32(56, 2);
1579        assert_eq!(dev.read_u32(56), 0);
1580        // queue enable (queue 1)
1581        assert_eq!(dev.read_u32(68), 0);
1582        dev.write_u32(68, 1);
1583        assert_eq!(dev.read_u32(68), 0);
1584        // queue descriptor address low (queue 1)
1585        assert_eq!(dev.read_u32(128), 0);
1586        dev.write_u32(128, 1);
1587        assert_eq!(dev.read_u32(128), 0);
1588        // queue descriptor address high (queue 1)
1589        assert_eq!(dev.read_u32(132), 0);
1590        dev.write_u32(132, 1);
1591        assert_eq!(dev.read_u32(132), 0);
1592        // queue available address low (queue 1)
1593        assert_eq!(dev.read_u32(144), 0);
1594        dev.write_u32(144, 1);
1595        assert_eq!(dev.read_u32(144), 0);
1596        // queue available address high (queue 1)
1597        assert_eq!(dev.read_u32(148), 0);
1598        dev.write_u32(148, 1);
1599        assert_eq!(dev.read_u32(148), 0);
1600        // queue used address low (queue 1)
1601        assert_eq!(dev.read_u32(160), 0);
1602        dev.write_u32(160, 1);
1603        assert_eq!(dev.read_u32(160), 0);
1604        // queue used address high (queue 1)
1605        assert_eq!(dev.read_u32(164), 0);
1606        dev.write_u32(164, 1);
1607        assert_eq!(dev.read_u32(164), 0);
1608    }
1609
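    // Walks the PCI capability list and checks the MSI-X capability plus the
    // vendor-specific virtio capabilities (common, notify, ISR, device cfg).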
1610    #[async_test]
1611    async fn verify_pci_config(driver: DefaultDriver) {
1612        let mut pci_test_device =
1613            VirtioPciTestDevice::new(&driver, 1, &VirtioTestMemoryAccess::new(), None);
1614        let mut capabilities = 0;
1615        pci_test_device
1616            .pci_device
1617            .pci_cfg_read(4, &mut capabilities)
1618            .unwrap();
1619        assert_eq!(
1620            capabilities,
1621            (cfg_space::Status::new()
1622                .with_capabilities_list(true)
1623                .into_bits() as u32)
1624                << 16
1625        );
1626        let mut next_cap_offset = 0;
1627        pci_test_device
1628            .pci_device
1629            .pci_cfg_read(0x34, &mut next_cap_offset)
1630            .unwrap();
1631        assert_ne!(next_cap_offset, 0);
1632
1633        let mut header = 0;
1634        pci_test_device
1635            .pci_device
1636            .pci_cfg_read(next_cap_offset as u16, &mut header)
1637            .unwrap();
1638        let header = header.to_le_bytes();
1639        assert_eq!(header[0], CapabilityId::MSIX.0);
1640        next_cap_offset = header[1] as u32;
1641        assert_ne!(next_cap_offset, 0);
1642
1643        let mut header = 0;
1644        pci_test_device
1645            .pci_device
1646            .pci_cfg_read(next_cap_offset as u16, &mut header)
1647            .unwrap();
1648        let header = header.to_le_bytes();
1649        assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1650        assert_eq!(header[3], VIRTIO_PCI_CAP_COMMON_CFG);
1651        assert_eq!(header[2], 16);
1652
1653        let mut buf = 0;
1654        pci_test_device
1655            .pci_device
1656            .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1657            .unwrap();
1658        assert_eq!(buf, 0);
1659        pci_test_device
1660            .pci_device
1661            .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1662            .unwrap();
1663        assert_eq!(buf, 0);
1664        pci_test_device
1665            .pci_device
1666            .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1667            .unwrap();
1668        assert_eq!(buf, 0x38);
1669        next_cap_offset = header[1] as u32;
1670        assert_ne!(next_cap_offset, 0);
1671
1672        let mut header = 0;
1673        pci_test_device
1674            .pci_device
1675            .pci_cfg_read(next_cap_offset as u16, &mut header)
1676            .unwrap();
1677        let header = header.to_le_bytes();
1678        assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1679        assert_eq!(header[3], VIRTIO_PCI_CAP_NOTIFY_CFG);
1680        assert_eq!(header[2], 20);
1681        pci_test_device
1682            .pci_device
1683            .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1684            .unwrap();
1685        assert_eq!(buf, 0);
1686        pci_test_device
1687            .pci_device
1688            .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1689            .unwrap();
1690        assert_eq!(buf, 0x38);
1691        pci_test_device
1692            .pci_device
1693            .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1694            .unwrap();
1695        assert_eq!(buf, 4);
1696        next_cap_offset = header[1] as u32;
1697        assert_ne!(next_cap_offset, 0);
1698
1699        let mut header = 0;
1700        pci_test_device
1701            .pci_device
1702            .pci_cfg_read(next_cap_offset as u16, &mut header)
1703            .unwrap();
1704        let header = header.to_le_bytes();
1705        assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1706        assert_eq!(header[3], VIRTIO_PCI_CAP_ISR_CFG);
1707        assert_eq!(header[2], 16);
1708        pci_test_device
1709            .pci_device
1710            .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1711            .unwrap();
1712        assert_eq!(buf, 0);
1713        pci_test_device
1714            .pci_device
1715            .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1716            .unwrap();
1717        assert_eq!(buf, 0x3c);
1718        pci_test_device
1719            .pci_device
1720            .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1721            .unwrap();
1722        assert_eq!(buf, 4);
1723        next_cap_offset = header[1] as u32;
1724        assert_ne!(next_cap_offset, 0);
1725
1726        let mut header = 0;
1727        pci_test_device
1728            .pci_device
1729            .pci_cfg_read(next_cap_offset as u16, &mut header)
1730            .unwrap();
1731        let header = header.to_le_bytes();
1732        assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
1733        assert_eq!(header[3], VIRTIO_PCI_CAP_DEVICE_CFG);
1734        assert_eq!(header[2], 16);
1735        pci_test_device
1736            .pci_device
1737            .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
1738            .unwrap();
1739        assert_eq!(buf, 0);
1740        pci_test_device
1741            .pci_device
1742            .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
1743            .unwrap();
1744        assert_eq!(buf, 0x40);
1745        pci_test_device
1746            .pci_device
1747            .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
1748            .unwrap();
1749        assert_eq!(buf, 12);
1750        next_cap_offset = header[1] as u32;
1751        assert_eq!(next_cap_offset, 0);
1752    }
1753
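    // Programs the BARs, enables MMIO, and exercises the virtio PCI common
    // configuration registers, including feature banks and per-queue state.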
1754    #[async_test]
1755    async fn verify_pci_registers(driver: DefaultDriver) {
1756        let mut pci_test_device =
1757            VirtioPciTestDevice::new(&driver, 1, &VirtioTestMemoryAccess::new(), None);
1758        let bar_address1: u64 = 0x2000000000;
1759        pci_test_device
1760            .pci_device
1761            .pci_cfg_write(0x14, (bar_address1 >> 32) as u32)
1762            .unwrap();
1763        pci_test_device
1764            .pci_device
1765            .pci_cfg_write(0x10, bar_address1 as u32)
1766            .unwrap();
1767
1768        let bar_address2: u64 = 0x4000;
1769        pci_test_device
1770            .pci_device
1771            .pci_cfg_write(0x1c, (bar_address2 >> 32) as u32)
1772            .unwrap();
1773        pci_test_device
1774            .pci_device
1775            .pci_cfg_write(0x18, bar_address2 as u32)
1776            .unwrap();
1777
1778        pci_test_device
1779            .pci_device
1780            .pci_cfg_write(
1781                0x4,
1782                cfg_space::Command::new()
1783                    .with_mmio_enabled(true)
1784                    .into_bits() as u32,
1785            )
1786            .unwrap();
1787
1788        // device feature bank index
1789        assert_eq!(pci_test_device.read_u32(bar_address1), 0);
1790        // device feature (bank 0)
1791        assert_eq!(
1792            pci_test_device.read_u32(bar_address1 + 4),
1793            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
1794        );
1795        // device feature (bank 1)
1796        pci_test_device.write_u32(bar_address1, 1);
1797        assert_eq!(pci_test_device.read_u32(bar_address1), 1);
1798        assert_eq!(
1799            pci_test_device.read_u32(bar_address1 + 4),
1800            VIRTIO_F_VERSION_1
1801        );
1802        // device feature (bank 2)
1803        pci_test_device.write_u32(bar_address1, 2);
1804        assert_eq!(pci_test_device.read_u32(bar_address1), 2);
1805        assert_eq!(pci_test_device.read_u32(bar_address1 + 4), 0);
1806        // driver feature bank index
1807        assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 0);
1808        // driver feature (bank 0)
1809        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
1810        pci_test_device.write_u32(bar_address1 + 12, 2);
1811        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 2);
1812        pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
1813        assert_eq!(
1814            pci_test_device.read_u32(bar_address1 + 12),
1815            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
1816        );
1817        // driver feature (bank 1)
1818        pci_test_device.write_u32(bar_address1 + 8, 1);
1819        assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 1);
1820        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
1821        pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
1822        assert_eq!(
1823            pci_test_device.read_u32(bar_address1 + 12),
1824            VIRTIO_F_VERSION_1
1825        );
1826        // driver feature (bank 2)
1827        pci_test_device.write_u32(bar_address1 + 8, 2);
1828        assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 2);
1829        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
1830        pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
1831        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
1832        // max queues and the msix vector for config changes
1833        assert_eq!(pci_test_device.read_u32(bar_address1 + 16), 1 << 16);
1834        // queue index, config generation and device status
1835        assert_eq!(pci_test_device.read_u32(bar_address1 + 20), 0);
1836        // current queue size and msix vector
1837        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0x40);
1838        pci_test_device.write_u32(bar_address1 + 24, 0x20);
1839        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0x20);
1840        // current queue enabled and notify offset
1841        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
1842        pci_test_device.write_u32(bar_address1 + 28, 1);
1843        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 1);
1844        pci_test_device.write_u32(bar_address1 + 28, 0xffff);
1845        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 1);
1846        pci_test_device.write_u32(bar_address1 + 28, 0);
1847        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
1848        // current queue descriptor table address (low)
1849        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
1850        pci_test_device.write_u32(bar_address1 + 32, 0xffff);
1851        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0xffff);
1852        // current queue descriptor table address (high)
1853        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
1854        pci_test_device.write_u32(bar_address1 + 36, 1);
1855        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 1);
1856        // current queue available ring address (low)
1857        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
1858        pci_test_device.write_u32(bar_address1 + 40, 0xeeee);
1859        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0xeeee);
1860        // current queue available ring address (high)
1861        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
1862        pci_test_device.write_u32(bar_address1 + 44, 2);
1863        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 2);
1864        // current queue used ring address (low)
1865        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
1866        pci_test_device.write_u32(bar_address1 + 48, 0xdddd);
1867        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0xdddd);
1868        // current queue used ring address (high)
1869        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
1870        pci_test_device.write_u32(bar_address1 + 52, 3);
1871        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 3);
1872        // VIRTIO_PCI_CAP_NOTIFY_CFG notification register
1873        assert_eq!(pci_test_device.read_u32(bar_address1 + 56), 0);
1874        // VIRTIO_PCI_CAP_ISR_CFG register
1875        assert_eq!(pci_test_device.read_u32(bar_address1 + 60), 0);
1876
1877        // switch to queue #1 (disabled, only one queue on this device)
1878        let queue_index: u16 = 1;
1879        pci_test_device
1880            .pci_device
1881            .mmio_write(bar_address1 + 22, &queue_index.to_le_bytes())
1882            .unwrap();
1883        assert_eq!(pci_test_device.read_u32(bar_address1 + 20), 1 << 24);
1884        // current queue size and msix vector
1885        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0);
1886        pci_test_device.write_u32(bar_address1 + 24, 2);
1887        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0);
1888        // current queue enabled and notify offset
1889        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
1890        pci_test_device.write_u32(bar_address1 + 28, 1);
1891        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
1892        // current queue descriptor table address (low)
1893        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
1894        pci_test_device.write_u32(bar_address1 + 32, 0x10);
1895        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
1896        // current queue descriptor table address (high)
1897        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
1898        pci_test_device.write_u32(bar_address1 + 36, 0x10);
1899        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
1900        // current queue available ring address (low)
1901        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
1902        pci_test_device.write_u32(bar_address1 + 40, 0x10);
1903        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
1904        // current queue available ring address (high)
1905        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
1906        pci_test_device.write_u32(bar_address1 + 44, 0x10);
1907        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
1908        // current queue used ring address (low)
1909        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
1910        pci_test_device.write_u32(bar_address1 + 48, 0x10);
1911        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
1912        // current queue used ring address (high)
1913        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
1914        pci_test_device.write_u32(bar_address1 + 52, 0x10);
1915        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
1916    }
1917
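    // A single descriptor placed on the avail ring should be delivered as one
    // payload and completed back on the used ring.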
1918    #[async_test]
1919    async fn verify_queue_simple(driver: DefaultDriver) {
1920        let test_mem = VirtioTestMemoryAccess::new();
1921        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
1922        let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
1923        let (tx, mut rx) = mesh::mpsc_channel();
1924        let event = Event::new();
1925        let mut queues = guest.create_direct_queues(|i| {
1926            let tx = tx.clone();
1927            CreateDirectQueueParams {
1928                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
1929                    let mut work = work.expect("Queue failure");
1930                    assert_eq!(work.payload.len(), 1);
1931                    assert_eq!(work.payload[0].address, base_addr);
1932                    assert_eq!(work.payload[0].length, 0x1000);
1933                    work.complete(123);
1934                    true
1935                }),
1936                notify: Interrupt::from_fn(move || {
1937                    tx.send(i as usize);
1938                }),
1939                event: event.clone(),
1940            }
1941        });
1942
1943        guest.add_to_avail_queue(0);
1944        event.signal();
1945        must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
1946        let (desc, len) = guest.get_next_completed(0).unwrap();
1947        assert_eq!(desc, 0u16);
1948        assert_eq!(len, 123);
1949        assert!(guest.get_next_completed(0).is_none());
1950        queues[0].stop().await;
1951    }
1952
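    // Same as verify_queue_simple, but the buffer is published through an
    // indirect descriptor table.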
1953    #[async_test]
1954    async fn verify_queue_indirect(driver: DefaultDriver) {
1955        let test_mem = VirtioTestMemoryAccess::new();
1956        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
1957        let (tx, mut rx) = mesh::mpsc_channel();
1958        let event = Event::new();
1959        let mut queues = guest.create_direct_queues(|i| {
1960            let tx = tx.clone();
1961            CreateDirectQueueParams {
1962                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
1963                    let mut work = work.expect("Queue failure");
1964                    assert_eq!(work.payload.len(), 1);
1965                    assert_eq!(work.payload[0].address, 0xffffffff00000000u64);
1966                    assert_eq!(work.payload[0].length, 0x1000);
1967                    work.complete(123);
1968                    true
1969                }),
1970                notify: Interrupt::from_fn(move || {
1971                    tx.send(i as usize);
1972                }),
1973                event: event.clone(),
1974            }
1975        });
1976
1977        guest.add_indirect_to_avail_queue(0);
1978        event.signal();
1979        must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
1980        let (desc, len) = guest.get_next_completed(0).unwrap();
1981        assert_eq!(desc, 0u16);
1982        assert_eq!(len, 123);
1983        assert!(guest.get_next_completed(0).is_none());
1984        queues[0].stop().await;
1985    }
1986
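    // A chain of three linked descriptors should arrive as a single work item
    // carrying three payloads.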
1987    #[async_test]
1988    async fn verify_queue_linked(driver: DefaultDriver) {
1989        let test_mem = VirtioTestMemoryAccess::new();
1990        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 5, true);
1991        let (tx, mut rx) = mesh::mpsc_channel();
1992        let base_address = guest.get_queue_descriptor_backing_memory_address(0);
1993        let event = Event::new();
1994        let mut queues = guest.create_direct_queues(|i| {
1995            let tx = tx.clone();
1996            CreateDirectQueueParams {
1997                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
1998                    let mut work = work.expect("Queue failure");
1999                    assert_eq!(work.payload.len(), 3);
2000                    for (i, payload) in work.payload.iter().enumerate() {
2001                        assert_eq!(payload.address, base_address + 0x1000 * i as u64);
2002                        assert_eq!(payload.length, 0x1000);
2003                    }
2004                    work.complete(123 * 3);
2005                    true
2006                }),
2007                notify: Interrupt::from_fn(move || {
2008                    tx.send(i as usize);
2009                }),
2010                event: event.clone(),
2011            }
2012        });
2013
2014        guest.add_linked_to_avail_queue(0, 3);
2015        event.signal();
2016        must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2017        let (desc, len) = guest.get_next_completed(0).unwrap();
2018        assert_eq!(desc, 0u16);
2019        assert_eq!(len, 123 * 3);
2020        assert!(guest.get_next_completed(0).is_none());
2021        queues[0].stop().await;
2022    }
2023
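    // A three-entry chain inside an indirect descriptor table should likewise
    // arrive as a single work item with three payloads.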
2024    #[async_test]
2025    async fn verify_queue_indirect_linked(driver: DefaultDriver) {
2026        let test_mem = VirtioTestMemoryAccess::new();
2027        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 5, true);
2028        let (tx, mut rx) = mesh::mpsc_channel();
2029        let event = Event::new();
2030        let mut queues = guest.create_direct_queues(|i| {
2031            let tx = tx.clone();
2032            CreateDirectQueueParams {
2033                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
2034                    let mut work = work.expect("Queue failure");
2035                    assert_eq!(work.payload.len(), 3);
2036                    for (i, payload) in work.payload.iter().enumerate() {
2037                        assert_eq!(
2038                            payload.address,
2039                            0xffffffff00000000u64 + 0x1000 * i as u64
2040                        );
2041                        assert_eq!(payload.length, 0x1000);
2042                    }
2043                    work.complete(123 * 3);
2044                    true
2045                }),
2046                notify: Interrupt::from_fn(move || {
2047                    tx.send(i as usize);
2048                }),
2049                event: event.clone(),
2050            }
2051        });
2052
2053        guest.add_indirect_linked_to_avail_queue(0, 3);
2054        event.signal();
2055        must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2056        let (desc, len) = guest.get_next_completed(0).unwrap();
2057        assert_eq!(desc, 0u16);
2058        assert_eq!(len, 123 * 3);
2059        assert!(guest.get_next_completed(0).is_none());
2060        queues[0].stop().await;
2061    }
2062
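    // Reusing a two-entry ring three times forces the avail index to wrap,
    // which must not disturb descriptor delivery or completion.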
2063    #[async_test]
2064    async fn verify_queue_avail_rollover(driver: DefaultDriver) {
2065        let test_mem = VirtioTestMemoryAccess::new();
2066        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
2067        let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
2068        let (tx, mut rx) = mesh::mpsc_channel();
2069        let event = Event::new();
2070        let mut queues = guest.create_direct_queues(|i| {
2071            let tx = tx.clone();
2072            CreateDirectQueueParams {
2073                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
2074                    let mut work = work.expect("Queue failure");
2075                    assert_eq!(work.payload.len(), 1);
2076                    assert_eq!(work.payload[0].address, base_addr);
2077                    assert_eq!(work.payload[0].length, 0x1000);
2078                    work.complete(123);
2079                    true
2080                }),
2081                notify: Interrupt::from_fn(move || {
2082                    tx.send(i as usize);
2083                }),
2084                event: event.clone(),
2085            }
2086        });
2087
2088        for _ in 0..3 {
2089            guest.add_to_avail_queue(0);
2090            event.signal();
2091            must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2092            let (desc, len) = guest.get_next_completed(0).unwrap();
2093            assert_eq!(desc, 0u16);
2094            assert_eq!(len, 123);
2095            assert!(guest.get_next_completed(0).is_none());
2096        }
2097
2098        queues[0].stop().await;
2099    }
2100
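    // Each of five queues gets its own descriptor and completion; results
    // must stay associated with the originating queue.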
2101    #[async_test]
2102    async fn verify_multi_queue(driver: DefaultDriver) {
2103        let test_mem = VirtioTestMemoryAccess::new();
2104        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 5, 2, true);
2105        let (tx, mut rx) = mesh::mpsc_channel();
2106        let events = (0..guest.num_queues)
2107            .map(|_| Event::new())
2108            .collect::<Vec<_>>();
2109        let mut queues = guest.create_direct_queues(|queue_index| {
2110            let tx = tx.clone();
2111            let base_addr = guest.get_queue_descriptor_backing_memory_address(queue_index);
2112            CreateDirectQueueParams {
2113                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
2114                    let mut work = work.expect("Queue failure");
2115                    assert_eq!(work.payload.len(), 1);
2116                    assert_eq!(work.payload[0].address, base_addr);
2117                    assert_eq!(work.payload[0].length, 0x1000);
2118                    work.complete(123 * queue_index as u32);
2119                    true
2120                }),
2121                notify: Interrupt::from_fn(move || {
2122                    tx.send(queue_index as usize);
2123                }),
2124                event: events[queue_index as usize].clone(),
2125            }
2126        });
2127
2128        for (i, event) in events.iter().enumerate() {
2129            let queue_index = i as u16;
2130            guest.add_to_avail_queue(queue_index);
2131            event.signal();
2132        }
2133        // wait for all queue processing to finish
2134        for _ in 0..guest.num_queues {
2135            must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
2136        }
2137        // check results
2138        for queue_index in 0..guest.num_queues {
2139            let (desc, len) = guest.get_next_completed(queue_index).unwrap();
2140            assert_eq!(desc, 0u16);
2141            assert_eq!(len, 123 * queue_index as u32);
2142        }
2143        // verify no extraneous completions
2144        for (i, queue) in queues.iter_mut().enumerate() {
2145            let queue_index = i as u16;
2146            assert!(guest.get_next_completed(queue_index).is_none());
2147            queue.stop().await;
2148        }
2149    }
2150
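    // Reads the interrupt status register and acknowledges the given bits.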
2151    fn take_mmio_interrupt_status(dev: &mut VirtioMmioDevice, mask: u32) -> u32 {
2152        let mut v = [0; 4];
2153        dev.mmio_read(96, &mut v).unwrap();
2154        dev.mmio_write(100, &mask.to_ne_bytes()).unwrap();
2155        u32::from_ne_bytes(v)
2156    }
2157
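    // Waits for the interrupt line to go high, acknowledges the expected
    // status bits, and checks that the line drops unless more interrupts are
    // expected.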
2158    async fn expect_mmio_interrupt(
2159        dev: &mut VirtioMmioDevice,
2160        target: &TestLineInterruptTarget,
2161        mask: u32,
2162        multiple_expected: bool,
2163    ) {
2164        poll_fn(|cx| target.poll_high(cx, 0)).await;
2165        let v = take_mmio_interrupt_status(dev, mask);
2166        assert_eq!(v & mask, mask);
2167        assert!(multiple_expected || !target.is_high(0));
2168    }
2169
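    // End-to-end MMIO device test: negotiate features, enable a queue, and
    // verify the used-buffer interrupt fires for a completed descriptor.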
2170    #[async_test]
2171    async fn verify_device_queue_simple(driver: DefaultDriver) {
2172        let test_mem = VirtioTestMemoryAccess::new();
2173        let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
2174        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
2175        let mem = guest.mem();
2176        let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
2177        let target = TestLineInterruptTarget::new_arc();
2178        let interrupt = LineInterrupt::new_with_target("test", target.clone(), 0);
2179        let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
2180        let queue_work = Arc::new(move |_: u16, mut work: VirtioQueueCallbackWork| {
2181            assert_eq!(work.payload.len(), 1);
2182            assert_eq!(work.payload[0].address, base_addr);
2183            assert_eq!(work.payload[0].length, 0x1000);
2184            work.complete(123);
2185        });
2186        let mut dev = VirtioMmioDevice::new(
2187            Box::new(LegacyWrapper::new(
2188                &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
2189                TestDevice::new(
2190                    DeviceTraits {
2191                        device_id: 3,
2192                        device_features: features,
2193                        max_queues: 1,
2194                        device_register_length: 0,
2195                        ..Default::default()
2196                    },
2197                    Some(queue_work),
2198                ),
2199                &mem,
2200            )),
2201            interrupt,
2202            Some(doorbell_registration),
2203            0,
2204            1,
2205        );
2206
2207        guest.setup_chipset_device(&mut dev, features);
2208        expect_mmio_interrupt(
2209            &mut dev,
2210            &target,
2211            VIRTIO_MMIO_INTERRUPT_STATUS_CONFIG_CHANGE,
2212            false,
2213        )
2214        .await;
2215        guest.add_to_avail_queue(0);
2216        // notify device
2217        dev.write_u32(80, 0);
2218        expect_mmio_interrupt(
2219            &mut dev,
2220            &target,
2221            VIRTIO_MMIO_INTERRUPT_STATUS_USED_BUFFER,
2222            false,
2223        )
2224        .await;
2225        let (desc, len) = guest.get_next_completed(0).unwrap();
2226        assert_eq!(desc, 0u16);
2227        assert_eq!(len, 123);
2228        assert!(guest.get_next_completed(0).is_none());
2229        // reset the device
2230        dev.write_u32(112, 0);
2231        drop(dev);
2232    }
2233
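    // Like verify_device_queue_simple, but drives five queues and drains
    // used-buffer interrupts until every queue has completed.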
2234    #[async_test]
2235    async fn verify_device_multi_queue(driver: DefaultDriver) {
2236        let num_queues = 5;
2237        let test_mem = VirtioTestMemoryAccess::new();
2238        let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
2239        let mut guest = VirtioTestGuest::new(&driver, &test_mem, num_queues, 2, true);
2240        let mem = guest.mem();
2241        let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
2242        let target = TestLineInterruptTarget::new_arc();
2243        let interrupt = LineInterrupt::new_with_target("test", target.clone(), 0);
2244        let base_addr: Vec<_> = (0..num_queues)
2245            .map(|i| guest.get_queue_descriptor_backing_memory_address(i))
2246            .collect();
2247        let queue_work = Arc::new(move |i: u16, mut work: VirtioQueueCallbackWork| {
2248            assert_eq!(work.payload.len(), 1);
2249            assert_eq!(work.payload[0].address, base_addr[i as usize]);
2250            assert_eq!(work.payload[0].length, 0x1000);
2251            work.complete(123 * i as u32);
2252        });
2253        let mut dev = VirtioMmioDevice::new(
2254            Box::new(LegacyWrapper::new(
2255                &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
2256                TestDevice::new(
2257                    DeviceTraits {
2258                        device_id: 3,
2259                        device_features: features,
2260                        max_queues: num_queues + 1,
2261                        device_register_length: 0,
2262                        ..Default::default()
2263                    },
2264                    Some(queue_work),
2265                ),
2266                &mem,
2267            )),
2268            interrupt,
2269            Some(doorbell_registration),
2270            0,
2271            1,
2272        );
2273        guest.setup_chipset_device(&mut dev, features);
2274        expect_mmio_interrupt(
2275            &mut dev,
2276            &target,
2277            VIRTIO_MMIO_INTERRUPT_STATUS_CONFIG_CHANGE,
2278            false,
2279        )
2280        .await;
2281        for i in 0..num_queues {
2282            guest.add_to_avail_queue(i);
2283            // notify device
2284            dev.write_u32(80, i as u32);
2285        }
2286        // check results
2287        for i in 0..num_queues {
2288            let (desc, len) = loop {
2289                if let Some(x) = guest.get_next_completed(i) {
2290                    break x;
2291                }
2292                expect_mmio_interrupt(
2293                    &mut dev,
2294                    &target,
2295                    VIRTIO_MMIO_INTERRUPT_STATUS_USED_BUFFER,
2296                    i < (num_queues - 1),
2297                )
2298                .await;
2299            };
2300            assert_eq!(desc, 0u16);
2301            assert_eq!(len, 123 * i as u32);
2302        }
2303        // verify no extraneous completions
2304        for i in 0..num_queues {
2305            assert!(guest.get_next_completed(i).is_none());
2306        }
2307        // reset the device
2308        dev.write_u32(112, 0);
2309        drop(dev);
2310    }
2311
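    // PCI transport variant of the multi-queue test, with MSI-X interrupts
    // delivered through the test interrupt controller.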
2312    #[async_test]
2313    async fn verify_device_multi_queue_pci(driver: DefaultDriver) {
2314        let num_queues = 5;
2315        let test_mem = VirtioTestMemoryAccess::new();
2316        let mut guest = VirtioTestGuest::new(&driver, &test_mem, num_queues, 2, true);
2317        let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
2318        let base_addr: Vec<_> = (0..num_queues)
2319            .map(|i| guest.get_queue_descriptor_backing_memory_address(i))
2320            .collect();
2321        let mut dev = VirtioPciTestDevice::new(
2322            &driver,
2323            num_queues + 1,
2324            &test_mem,
2325            Some(Arc::new(move |i, mut work| {
2326                assert_eq!(work.payload.len(), 1);
2327                assert_eq!(work.payload[0].address, base_addr[i as usize]);
2328                assert_eq!(work.payload[0].length, 0x1000);
2329                work.complete(123 * i as u32);
2330            })),
2331        );
2332
2333        guest.setup_pci_device(&mut dev, features);
2334
2335        let mut timer = PolledTimer::new(&driver);
2336
2337        // expect a config generation interrupt
2338        timer.sleep(Duration::from_millis(100)).await;
2339        let delivered = dev.test_intc.get_next_interrupt().unwrap();
2340        assert_eq!(delivered.0, 0);
2341        assert!(dev.test_intc.get_next_interrupt().is_none());
2342
2343        for i in 0..num_queues {
2344            guest.add_to_avail_queue(i);
2345            // notify device
2346            dev.write_u32(0x10000000000 + 0x38, i as u32);
2347        }
2348        // verify all queue processing finished
2349        timer.sleep(Duration::from_millis(100)).await;
2350        for _ in 0..num_queues {
2351            let delivered = dev.test_intc.get_next_interrupt();
2352            assert!(delivered.is_some());
2353        }
2354        // check results
2355        for i in 0..num_queues {
2356            let (desc, len) = guest.get_next_completed(i).unwrap();
2357            assert_eq!(desc, 0u16);
2358            assert_eq!(len, 123 * i as u32);
2359        }
2360        // verify no extraneous completions
2361        for i in 0..num_queues {
2362            assert!(guest.get_next_completed(i).is_none());
2363        }
2364        // reset the device
2365        let device_status: u8 = 0;
2366        dev.pci_device
2367            .mmio_write(0x10000000000 + 20, &device_status.to_le_bytes())
2368            .unwrap();
2369        drop(dev);
2370    }
2371}