virtio/common.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use crate::queue::QueueCore;
use crate::queue::QueueError;
use crate::queue::QueueParams;
use crate::queue::VirtioQueuePayload;
use async_trait::async_trait;
use futures::FutureExt;
use futures::Stream;
use futures::StreamExt;
use guestmem::DoorbellRegistration;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use guestmem::MappedMemoryRegion;
use pal_async::DefaultPool;
use pal_async::driver::Driver;
use pal_async::task::Spawn;
use pal_async::wait::PolledWait;
use pal_event::Event;
use parking_lot::Mutex;
use std::io::Error;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::task::ready;
use task_control::AsyncRun;
use task_control::StopTask;
use task_control::TaskControl;
use thiserror::Error;
use vmcore::interrupt::Interrupt;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;

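/// Per-queue work processor. `process_work` is called once for each available
/// descriptor chain (or queue error) and returns `false` to stop the worker.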
#[async_trait]
pub trait VirtioQueueWorkerContext {
    async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool;
}

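/// Shared used-ring state for a queue: completes descriptors back to the
/// guest, tracks how many are still outstanding, and delivers the guest
/// notification interrupt.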
#[derive(Debug)]
pub struct VirtioQueueUsedHandler {
    core: QueueCore,
    last_used_index: u16,
    outstanding_desc_count: Arc<Mutex<(u16, event_listener::Event)>>,
    notify_guest: Interrupt,
}

impl VirtioQueueUsedHandler {
    fn new(core: QueueCore, notify_guest: Interrupt) -> Self {
        Self {
            core,
            last_used_index: 0,
            outstanding_desc_count: Arc::new(Mutex::new((0, event_listener::Event::new()))),
            notify_guest,
        }
    }

    pub fn add_outstanding_descriptor(&self) {
        let (count, _) = &mut *self.outstanding_desc_count.lock();
        *count += 1;
    }

    pub fn await_outstanding_descriptors(&self) -> event_listener::EventListener {
        let (count, event) = &*self.outstanding_desc_count.lock();
        let listener = event.listen();
        if *count == 0 {
            event.notify(usize::MAX);
        }
        listener
    }

    pub fn complete_descriptor(&mut self, descriptor_index: u16, bytes_written: u32) {
        match self.core.complete_descriptor(
            &mut self.last_used_index,
            descriptor_index,
            bytes_written,
        ) {
            Ok(true) => {
                self.notify_guest.deliver();
            }
            Ok(false) => {}
            Err(err) => {
                tracelimit::error_ratelimited!(
                    error = &err as &dyn std::error::Error,
                    "failed to complete descriptor"
                );
            }
        }
        {
            let (count, event) = &mut *self.outstanding_desc_count.lock();
            *count -= 1;
            if *count == 0 {
                event.notify(usize::MAX);
            }
        }
    }
}

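/// A descriptor chain pulled from the available ring, paired with the
/// bookkeeping needed to retire it to the used ring. Dropping a work item
/// without calling [`complete`](Self::complete) retires the descriptor with
/// zero bytes written.
///
/// A minimal consumer sketch (hypothetical; assumes a mutable `work` item and
/// a `mem: GuestMemory` in scope):
///
/// ```ignore
/// let len = work.get_payload_length(false) as usize;
/// let mut buf = vec![0u8; len];
/// work.read(&mem, &mut buf)?;
/// work.complete(0); // nothing written back to the guest
/// ```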
pub struct VirtioQueueCallbackWork {
    pub payload: Vec<VirtioQueuePayload>,
    used_queue_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    descriptor_index: u16,
    completed: bool,
}

impl VirtioQueueCallbackWork {
    pub fn new(
        payload: Vec<VirtioQueuePayload>,
        used_queue_handler: &Arc<Mutex<VirtioQueueUsedHandler>>,
        descriptor_index: u16,
    ) -> Self {
        let used_queue_handler = used_queue_handler.clone();
        used_queue_handler.lock().add_outstanding_descriptor();
        Self {
            payload,
            used_queue_handler,
            descriptor_index,
            completed: false,
        }
    }

    /// Completes the descriptor, reporting `bytes_written` to the guest via
    /// the used ring. Must be called at most once.
    pub fn complete(&mut self, bytes_written: u32) {
        assert!(!self.completed);
        self.used_queue_handler
            .lock()
            .complete_descriptor(self.descriptor_index, bytes_written);
        self.completed = true;
    }

    pub fn descriptor_index(&self) -> u16 {
        self.descriptor_index
    }

    /// Returns the total length, in bytes, of all readable (`writeable ==
    /// false`) or all writeable payload buffers.
    pub fn get_payload_length(&self, writeable: bool) -> u64 {
        self.payload
            .iter()
            .filter(|x| x.writeable == writeable)
            .fold(0, |acc, x| acc + x.length as u64)
    }

    /// Reads from the readable payload buffers into `target`, returning the
    /// number of bytes copied.
    pub fn read(&self, mem: &GuestMemory, target: &mut [u8]) -> Result<usize, GuestMemoryError> {
        let mut remaining = target;
        let mut read_bytes: usize = 0;
        for payload in &self.payload {
            if payload.writeable {
                continue;
            }

            let size = std::cmp::min(payload.length as usize, remaining.len());
            let (current, next) = remaining.split_at_mut(size);
            mem.read_at(payload.address, current)?;
            read_bytes += size;
            if next.is_empty() {
                break;
            }

            remaining = next;
        }

        Ok(read_bytes)
    }

    /// Writes `source` into the writeable payload buffers, starting `offset`
    /// bytes into the writeable region.
    pub fn write_at_offset(
        &self,
        offset: u64,
        mem: &GuestMemory,
        source: &[u8],
    ) -> Result<(), VirtioWriteError> {
        let mut skip_bytes = offset;
        let mut remaining = source;
        for payload in &self.payload {
            if !payload.writeable {
                continue;
            }

            let payload_length = payload.length as u64;
            if skip_bytes >= payload_length {
                skip_bytes -= payload_length;
                continue;
            }

            let size = std::cmp::min(
                payload_length as usize - skip_bytes as usize,
                remaining.len(),
            );
            let (current, next) = remaining.split_at(size);
            mem.write_at(payload.address + skip_bytes, current)?;
            remaining = next;
            if remaining.is_empty() {
                break;
            }
            skip_bytes = 0;
        }

        if !remaining.is_empty() {
            return Err(VirtioWriteError::NotAllWritten(source.len()));
        }

        Ok(())
    }

    /// Writes `source` into the writeable payload buffers, starting at the
    /// beginning of the writeable region.
    pub fn write(&self, mem: &GuestMemory, source: &[u8]) -> Result<(), VirtioWriteError> {
        self.write_at_offset(0, mem, source)
    }
}

#[derive(Debug, Error)]
pub enum VirtioWriteError {
    #[error(transparent)]
    Memory(#[from] GuestMemoryError),
    #[error("could not write all {0:#x} bytes")]
    NotAllWritten(usize),
}

impl Drop for VirtioQueueCallbackWork {
    fn drop(&mut self) {
        if !self.completed {
            self.complete(0);
        }
    }
}

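/// An endpoint for processing a single virtio queue. Implements [`Stream`],
/// yielding one [`VirtioQueueCallbackWork`] item per available descriptor
/// chain and waking on the queue's notification event.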
#[derive(Debug)]
pub struct VirtioQueue {
    core: QueueCore,
    last_avail_index: u16,
    used_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    queue_event: PolledWait<Event>,
}

impl VirtioQueue {
    pub fn new(
        features: u64,
        params: QueueParams,
        mem: GuestMemory,
        notify: Interrupt,
        queue_event: PolledWait<Event>,
    ) -> Result<Self, QueueError> {
        let core = QueueCore::new(features, mem, params)?;
        let used_handler = Arc::new(Mutex::new(VirtioQueueUsedHandler::new(
            core.clone(),
            notify,
        )));
        Ok(Self {
            core,
            last_avail_index: 0,
            used_handler,
            queue_event,
        })
    }

    async fn wait_for_outstanding_descriptors(&self) {
        let wait_for_descriptors = self.used_handler.lock().await_outstanding_descriptors();
        wait_for_descriptors.await;
    }

    fn poll_next_buffer(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<VirtioQueueCallbackWork>, QueueError>> {
        let descriptor_index = loop {
            if let Some(descriptor_index) = self.core.descriptor_index(self.last_avail_index)? {
                break descriptor_index;
            };
            ready!(self.queue_event.wait().poll_unpin(cx)).expect("waits on Event cannot fail");
        };
        let payload = self
            .core
            .reader(descriptor_index)
            .collect::<Result<Vec<_>, _>>()?;

        self.last_avail_index = self.last_avail_index.wrapping_add(1);
        Poll::Ready(Ok(Some(VirtioQueueCallbackWork::new(
            payload,
            &self.used_handler,
            descriptor_index,
        ))))
    }
}

impl Drop for VirtioQueue {
    fn drop(&mut self) {
        if Arc::get_mut(&mut self.used_handler).is_none() {
            tracing::error!("virtio queue dropped with outstanding work pending")
        }
    }
}

impl Stream for VirtioQueue {
    type Item = Result<VirtioQueueCallbackWork, Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Some(r) = ready!(self.get_mut().poll_next_buffer(cx)).transpose() else {
            return Poll::Ready(None);
        };

        Poll::Ready(Some(r.map_err(Error::other)))
    }
}

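/// State machine for a queue worker task: the queue's construction inputs
/// before the first run, then the live queue once initialization succeeds.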
enum VirtioQueueStateInner {
    Initializing {
        mem: GuestMemory,
        features: u64,
        params: QueueParams,
        event: Event,
        notify: Interrupt,
        exit_event: event_listener::EventListener,
    },
    InitializationInProgress,
    Running {
        queue: VirtioQueue,
        exit_event: event_listener::EventListener,
    },
}

pub struct VirtioQueueState {
    inner: VirtioQueueStateInner,
}

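/// Drives a [`VirtioQueue`] on its own task, handing each descriptor chain to
/// a [`VirtioQueueWorkerContext`].
///
/// A hypothetical setup sketch (`MyContext` is assumed to implement
/// [`VirtioQueueWorkerContext`]):
///
/// ```ignore
/// let worker = VirtioQueueWorker::new(driver.clone(), Box::new(MyContext));
/// let task = worker.into_running_task("my-queue", mem, features, resources, exit.listen());
/// // ...later: task.stop().await;
/// ```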
pub struct VirtioQueueWorker {
    driver: Box<dyn Driver>,
    context: Box<dyn VirtioQueueWorkerContext + Send>,
}

impl VirtioQueueWorker {
    pub fn new(driver: impl Driver, context: Box<dyn VirtioQueueWorkerContext + Send>) -> Self {
        Self {
            driver: Box::new(driver),
            context,
        }
    }

    pub fn into_running_task(
        self,
        name: impl Into<String>,
        mem: GuestMemory,
        features: u64,
        queue_resources: QueueResources,
        exit_event: event_listener::EventListener,
    ) -> TaskControl<VirtioQueueWorker, VirtioQueueState> {
        let name = name.into();
        let (_, driver) = DefaultPool::spawn_on_thread(&name);

        let mut task = TaskControl::new(self);
        task.insert(
            driver,
            name,
            VirtioQueueState {
                inner: VirtioQueueStateInner::Initializing {
                    mem,
                    features,
                    params: queue_resources.params,
                    event: queue_resources.event,
                    notify: queue_resources.notify,
                    exit_event,
                },
            },
        );
        task.start();
        task
    }

    async fn run_queue(&mut self, state: &mut VirtioQueueState) -> bool {
        match &mut state.inner {
            VirtioQueueStateInner::InitializationInProgress => unreachable!(),
            VirtioQueueStateInner::Initializing { .. } => {
                let VirtioQueueStateInner::Initializing {
                    mem,
                    features,
                    params,
                    event,
                    notify,
                    exit_event,
                } = std::mem::replace(
                    &mut state.inner,
                    VirtioQueueStateInner::InitializationInProgress,
                )
                else {
                    unreachable!()
                };
                let queue_event = PolledWait::new(&self.driver, event).unwrap();
                match VirtioQueue::new(features, params, mem, notify, queue_event) {
                    Ok(queue) => {
                        state.inner = VirtioQueueStateInner::Running { queue, exit_event };
                        true
                    }
                    Err(err) => {
                        tracing::error!(
                            err = &err as &dyn std::error::Error,
                            "failed to start queue"
                        );
                        false
                    }
                }
            }
            VirtioQueueStateInner::Running { queue, exit_event } => {
                let mut exit = exit_event.fuse();
                let mut queue_ready = queue.next().fuse();
                let work = futures::select_biased! {
                    _ = exit => return false,
                    work = queue_ready => work.expect("queue will never complete").map_err(anyhow::Error::from),
                };
                self.context.process_work(work).await
            }
        }
    }
}

impl AsyncRun<VirtioQueueState> for VirtioQueueWorker {
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut VirtioQueueState,
    ) -> Result<(), task_control::Cancelled> {
        while stop.until_stopped(self.run_queue(state)).await? {}
        Ok(())
    }
}

pub struct VirtioRunningState {
    pub features: u64,
    pub enabled_queues: Vec<bool>,
}

pub enum VirtioState {
    Unknown,
    Running(VirtioRunningState),
    Stopped,
}

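/// Tracks doorbell (notify) registrations for a device, dropping them all on
/// [`clear`](Self::clear).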
pub(crate) struct VirtioDoorbells {
    registration: Option<Arc<dyn DoorbellRegistration>>,
    doorbells: Vec<Box<dyn Send + Sync>>,
}

impl VirtioDoorbells {
    pub fn new(registration: Option<Arc<dyn DoorbellRegistration>>) -> Self {
        Self {
            registration,
            doorbells: Vec::new(),
        }
    }

    pub fn add(&mut self, address: u64, value: Option<u64>, length: Option<u32>, event: &Event) {
        if let Some(registration) = &mut self.registration {
            let doorbell = registration.register_doorbell(address, value, length, event);
            if let Ok(doorbell) = doorbell {
                self.doorbells.push(doorbell);
            }
        }
    }

    pub fn clear(&mut self) {
        self.doorbells.clear();
    }
}

#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraitsSharedMemory {
    pub id: u8,
    pub size: u64,
}

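/// Static properties a virtio device reports to its transport.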
#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraits {
    pub device_id: u16,
    pub device_features: u64,
    pub max_queues: u16,
    pub device_register_length: u32,
    pub shared_memory: DeviceTraitsSharedMemory,
}

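/// A virtio device whose queues are processed by per-queue worker callbacks.
/// Wrap one in [`LegacyWrapper`] to get a [`VirtioDevice`].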
pub trait LegacyVirtioDevice: Send {
    fn traits(&self) -> DeviceTraits;
    fn read_registers_u32(&self, offset: u16) -> u32;
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send>;
    fn state_change(&mut self, state: &VirtioState);
}

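/// A virtio device as seen by a transport (MMIO or PCI): device register
/// access plus enable/disable with the negotiated resources.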
pub trait VirtioDevice: Send {
    fn traits(&self) -> DeviceTraits;
    fn read_registers_u32(&self, offset: u16) -> u32;
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    fn enable(&mut self, resources: Resources);
    fn disable(&mut self);
}

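/// Per-queue resources provided to the device when it is enabled.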
pub struct QueueResources {
    pub params: QueueParams,
    pub notify: Interrupt,
    pub event: Event,
}

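/// Resources negotiated by the transport, passed to [`VirtioDevice::enable`].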
pub struct Resources {
    pub features: u64,
    pub queues: Vec<QueueResources>,
    pub shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
    pub shared_memory_size: u64,
}

/// Wraps an object implementing [`LegacyVirtioDevice`] and implements [`VirtioDevice`].
pub struct LegacyWrapper<T: LegacyVirtioDevice> {
    device: T,
    driver: VmTaskDriver,
    mem: GuestMemory,
    workers: Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>,
    exit_event: event_listener::Event,
}

impl<T: LegacyVirtioDevice> LegacyWrapper<T> {
    pub fn new(driver_source: &VmTaskDriverSource, device: T, mem: &GuestMemory) -> Self {
        Self {
            device,
            driver: driver_source.simple(),
            mem: mem.clone(),
            workers: Vec::new(),
            exit_event: event_listener::Event::new(),
        }
    }
}

impl<T: LegacyVirtioDevice> VirtioDevice for LegacyWrapper<T> {
    fn traits(&self) -> DeviceTraits {
        self.device.traits()
    }

    fn read_registers_u32(&self, offset: u16) -> u32 {
        self.device.read_registers_u32(offset)
    }

    fn write_registers_u32(&mut self, offset: u16, val: u32) {
        self.device.write_registers_u32(offset, val)
    }

    fn enable(&mut self, resources: Resources) {
        let running_state = VirtioRunningState {
            features: resources.features,
            enabled_queues: resources
                .queues
                .iter()
                .map(|QueueResources { params, .. }| params.enable)
                .collect(),
        };

        self.device
            .state_change(&VirtioState::Running(running_state));
        self.workers = resources
            .queues
            .into_iter()
            .enumerate()
            .filter_map(|(i, queue_resources)| {
                if !queue_resources.params.enable {
                    return None;
                }
                let worker = VirtioQueueWorker::new(
                    self.driver.clone(),
                    self.device.get_work_callback(i as u16),
                );
                Some(worker.into_running_task(
                    "virtio-queue",
                    self.mem.clone(),
                    resources.features,
                    queue_resources,
                    self.exit_event.listen(),
                ))
            })
            .collect();
    }

    fn disable(&mut self) {
        if self.workers.is_empty() {
            return;
        }
        self.exit_event.notify(usize::MAX);
        self.device.state_change(&VirtioState::Stopped);
        let mut workers = self.workers.drain(..).collect::<Vec<_>>();
        self.driver
            .spawn("shutdown-legacy-virtio-queues".to_owned(), async move {
                futures::future::join_all(workers.iter_mut().map(async |worker| {
                    worker.stop().await;
                    if let Some(VirtioQueueStateInner::Running { queue, .. }) =
                        worker.state_mut().map(|s| &s.inner)
                    {
                        queue.wait_for_outstanding_descriptors().await;
                    }
                }))
                .await;
            })
            .detach();
    }
}

impl<T: LegacyVirtioDevice> Drop for LegacyWrapper<T> {
    fn drop(&mut self) {
        self.disable();
    }
}

// UNSAFETY: test code implements a custom `GuestMemory` backing, which requires
// unsafe.
#[expect(unsafe_code)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PciInterruptModel;
    use crate::spec::pci::*;
    use crate::spec::queue::*;
    use crate::spec::*;
    use crate::transport::VirtioMmioDevice;
    use crate::transport::VirtioPciDevice;
    use chipset_device::mmio::ExternallyManagedMmioIntercepts;
    use chipset_device::mmio::MmioIntercept;
    use chipset_device::pci::PciConfigSpace;
    use futures::StreamExt;
    use guestmem::GuestMemoryAccess;
    use guestmem::GuestMemoryBackingError;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::timer::PolledTimer;
    use pci_core::msi::MsiInterruptSet;
    use pci_core::spec::caps::CapabilityId;
    use pci_core::spec::cfg_space;
    use pci_core::test_helpers::TestPciInterruptController;
    use std::collections::BTreeMap;
    use std::future::poll_fn;
    use std::io;
    use std::ptr::NonNull;
    use std::time::Duration;
    use test_with_tracing::test;
    use vmcore::line_interrupt::LineInterrupt;
    use vmcore::line_interrupt::test_helpers::TestLineInterruptTarget;
    use vmcore::vm_task::SingleDriverBackend;
    use vmcore::vm_task::VmTaskDriverSource;

    async fn must_recv_in_timeout<T: 'static + Send>(
        recv: &mut mesh::Receiver<T>,
        timeout: Duration,
    ) -> T {
        mesh::CancelContext::new()
            .with_timeout(timeout)
            .until_cancelled(recv.next())
            .await
            .unwrap()
            .unwrap()
    }

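    // A test-only `GuestMemory` backing built from a sparse map of
    // (writable, bytes) ranges keyed by base address; accesses outside the
    // map panic.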
    #[derive(Default)]
    struct VirtioTestMemoryAccess {
        memory_map: Mutex<MemoryMap>,
    }

    #[derive(Default)]
    struct MemoryMap {
        map: BTreeMap<u64, (bool, Vec<u8>)>,
    }

    impl MemoryMap {
        fn get(&mut self, address: u64, len: usize) -> Option<(bool, &mut [u8])> {
            let (&base, &mut (writable, ref mut data)) = self.map.range_mut(..=address).last()?;
            let data = data
                .get_mut(usize::try_from(address - base).ok()?..)?
                .get_mut(..len)?;

            Some((writable, data))
        }

        fn insert(&mut self, address: u64, data: &[u8], writable: bool) {
            // Overwrite in place if the range is already mapped with the same
            // permissions.
            if let Some((is_writable, v)) = self.get(address, data.len()) {
                assert_eq!(writable, is_writable);
                v.copy_from_slice(data);
                return;
            }

            let end = address + data.len() as u64;
            let mut data = data.to_vec();
            // Coalesce with the next range if it is adjacent and has the same
            // permissions.
            if let Some((&next, &(next_writable, ref next_data))) = self.map.range(address..).next()
            {
                if end > next {
                    let next_end = next + next_data.len() as u64;
                    panic!(
                        "overlapping memory map: {address:#x}..{end:#x} > {next:#x}..={next_end:#x}"
                    );
                }
                if end == next && next_writable == writable {
                    data.extend(next_data.as_slice());
                    self.map.remove(&next).unwrap();
                }
            }

            // Coalesce with the previous range if it is adjacent and has the
            // same permissions.
            if let Some((&prev, &mut (prev_writable, ref mut prev_data))) =
                self.map.range_mut(..address).last()
            {
                let prev_end = prev + prev_data.len() as u64;
                if prev_end > address {
                    panic!(
                        "overlapping memory map: {prev:#x}..{prev_end:#x} > {address:#x}..={end:#x}"
                    );
                }
                if prev_end == address && prev_writable == writable {
                    prev_data.extend_from_slice(&data);
                    return;
                }
            }

            self.map.insert(address, (writable, data));
        }
    }

    impl VirtioTestMemoryAccess {
        fn new() -> Arc<Self> {
            Default::default()
        }

        fn modify_memory_map(&self, address: u64, data: &[u8], writeable: bool) {
            self.memory_map.lock().insert(address, data, writeable);
        }

        fn memory_map_get_u16(&self, address: u64) -> u16 {
            let mut map = self.memory_map.lock();
            let (_, data) = map.get(address, 2).unwrap();
            u16::from_le_bytes(data.try_into().unwrap())
        }

        fn memory_map_get_u32(&self, address: u64) -> u32 {
            let mut map = self.memory_map.lock();
            let (_, data) = map.get(address, 4).unwrap();
            u32::from_le_bytes(data.try_into().unwrap())
        }
    }

    // SAFETY: test code
    unsafe impl GuestMemoryAccess for VirtioTestMemoryAccess {
        fn mapping(&self) -> Option<NonNull<u8>> {
            None
        }

        fn max_address(&self) -> u64 {
            // No real bound, so use the max physical address width on
            // AMD64/ARM64.
            1 << 52
        }

        unsafe fn read_fallback(
            &self,
            address: u64,
            dest: *mut u8,
            len: usize,
        ) -> Result<(), GuestMemoryBackingError> {
            match self.memory_map.lock().get(address, len) {
                Some((_, value)) => {
                    // SAFETY: guaranteed by caller
                    unsafe {
                        std::ptr::copy(value.as_ptr(), dest, len);
                    }
                }
                None => panic!("Unexpected read request at address {address:#x}"),
            }
            Ok(())
        }

        unsafe fn write_fallback(
            &self,
            address: u64,
            src: *const u8,
            len: usize,
        ) -> Result<(), GuestMemoryBackingError> {
            match self.memory_map.lock().get(address, len) {
                Some((true, value)) => {
                    // SAFETY: guaranteed by caller
                    unsafe {
                        std::ptr::copy(src, value.as_mut_ptr(), len);
                    }
                }
                _ => panic!("Unexpected write request at address {address:#x}"),
            }
            Ok(())
        }

        fn fill_fallback(
            &self,
            address: u64,
            val: u8,
            len: usize,
        ) -> Result<(), GuestMemoryBackingError> {
            match self.memory_map.lock().get(address, len) {
                Some((true, value)) => value.fill(val),
                _ => panic!("Unexpected write request at address {address:#x}"),
            };
            Ok(())
        }
    }

    struct DoorbellEntry;
    impl Drop for DoorbellEntry {
        fn drop(&mut self) {}
    }

    impl DoorbellRegistration for VirtioTestMemoryAccess {
        fn register_doorbell(
            &self,
            _: u64,
            _: Option<u64>,
            _: Option<u32>,
            _: &Event,
        ) -> io::Result<Box<dyn Send + Sync>> {
            Ok(Box::new(DoorbellEntry))
        }
    }

    type VirtioTestWorkCallback =
        Box<dyn Fn(anyhow::Result<VirtioQueueCallbackWork>) -> bool + Sync + Send>;

    struct CreateDirectQueueParams {
        process_work: VirtioTestWorkCallback,
        notify: Interrupt,
        event: Event,
    }

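    // Emulates a guest driver: owns the ring memory layout and tracks which
    // descriptors are free, in flight, or completed for each queue.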
    struct VirtioTestGuest {
        test_mem: Arc<VirtioTestMemoryAccess>,
        driver: DefaultDriver,
        num_queues: u16,
        queue_size: u16,
        use_ring_event_index: bool,
        last_avail_index: Vec<u16>,
        last_used_index: Vec<u16>,
        avail_descriptors: Vec<Vec<bool>>,
        exit_event: event_listener::Event,
    }

    impl VirtioTestGuest {
        fn new(
            driver: &DefaultDriver,
            test_mem: &Arc<VirtioTestMemoryAccess>,
            num_queues: u16,
            queue_size: u16,
            use_ring_event_index: bool,
        ) -> Self {
            let last_avail_index: Vec<u16> = vec![0; num_queues as usize];
            let last_used_index: Vec<u16> = vec![0; num_queues as usize];
            let avail_descriptors: Vec<Vec<bool>> =
                vec![vec![true; queue_size as usize]; num_queues as usize];
            let test_guest = Self {
                test_mem: test_mem.clone(),
                driver: driver.clone(),
                num_queues,
                queue_size,
                use_ring_event_index,
                last_avail_index,
                last_used_index,
                avail_descriptors,
                exit_event: event_listener::Event::new(),
            };
            for i in 0..num_queues {
                test_guest.add_queue_memory(i);
            }
            test_guest
        }

        fn mem(&self) -> GuestMemory {
            GuestMemory::new("test", self.test_mem.clone())
        }

        fn create_direct_queues<F>(
            &self,
            f: F,
        ) -> Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>
        where
            F: Fn(u16) -> CreateDirectQueueParams,
        {
            (0..self.num_queues)
                .map(|i| {
                    let params = f(i);
                    let worker = VirtioQueueWorker::new(
                        self.driver.clone(),
                        Box::new(VirtioTestWork {
                            callback: params.process_work,
                        }),
                    );
                    worker.into_running_task(
                        "virtio-test-queue",
                        self.mem(),
                        self.queue_features(),
                        QueueResources {
                            params: self.queue_params(i),
                            notify: params.notify,
                            event: params.event,
                        },
                        self.exit_event.listen(),
                    )
                })
                .collect::<Vec<_>>()
        }

        fn queue_features(&self) -> u64 {
            if self.use_ring_event_index {
                VIRTIO_F_RING_EVENT_IDX as u64
            } else {
                0
            }
        }

        fn queue_params(&self, i: u16) -> QueueParams {
            QueueParams {
                size: self.queue_size,
                enable: true,
                desc_addr: self.get_queue_descriptor_base_address(i),
                avail_addr: self.get_queue_available_base_address(i),
                used_addr: self.get_queue_used_base_address(i),
            }
        }

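        // Per-queue guest memory layout used by these tests: each queue gets
        // its own 0x10000000-byte region; descriptor table at +0x1000,
        // available ring at +0x2000, used ring at +0x3000, and data buffers
        // starting at +0x4000.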
        fn get_queue_base_address(&self, index: u16) -> u64 {
            0x10000000 * index as u64
        }

        fn get_queue_descriptor_base_address(&self, index: u16) -> u64 {
            self.get_queue_base_address(index) + 0x1000
        }

        fn get_queue_available_base_address(&self, index: u16) -> u64 {
            self.get_queue_base_address(index) + 0x2000
        }

        fn get_queue_used_base_address(&self, index: u16) -> u64 {
            self.get_queue_base_address(index) + 0x3000
        }

        fn get_queue_descriptor_backing_memory_address(&self, index: u16) -> u64 {
            self.get_queue_base_address(index) + 0x4000
        }

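        // Drives a virtio-mmio device through initialization. The decimal
        // offsets are virtio-mmio registers: 112 = Status (0x70), 32/36 =
        // DriverFeatures/DriverFeaturesSel (0x20/0x24), 48 = QueueSel (0x30),
        // 56 = QueueNum (0x38), 68 = QueueReady (0x44), 128/132, 144/148, and
        // 160/164 = descriptor/available/used ring addresses (low/high), and
        // 0xfc = ConfigGeneration.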
        fn setup_chipset_device(&self, dev: &mut VirtioMmioDevice, driver_features: u64) {
            dev.write_u32(112, VIRTIO_ACKNOWLEDGE);
            dev.write_u32(112, VIRTIO_DRIVER);
            dev.write_u32(36, 0);
            dev.write_u32(32, driver_features as u32);
            dev.write_u32(36, 1);
            dev.write_u32(32, (driver_features >> 32) as u32);
            dev.write_u32(112, VIRTIO_FEATURES_OK);
            for i in 0..self.num_queues {
                let queue_index = i;
                dev.write_u32(48, i as u32);
                dev.write_u32(56, self.queue_size as u32);
                let desc_addr = self.get_queue_descriptor_base_address(queue_index);
                dev.write_u32(128, desc_addr as u32);
                dev.write_u32(132, (desc_addr >> 32) as u32);
                let avail_addr = self.get_queue_available_base_address(queue_index);
                dev.write_u32(144, avail_addr as u32);
                dev.write_u32(148, (avail_addr >> 32) as u32);
                let used_addr = self.get_queue_used_base_address(queue_index);
                dev.write_u32(160, used_addr as u32);
                dev.write_u32(164, (used_addr >> 32) as u32);
                // enable the queue
                dev.write_u32(68, 1);
            }
            dev.write_u32(112, VIRTIO_DRIVER_OK);
            assert_eq!(dev.read_u32(0xfc), 2);
        }

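        // Drives a virtio-pci device through initialization. BAR1 holds the
        // common configuration structure (20 = device_status, 22 =
        // queue_select, 24 = queue_size, 26 = queue_msix_vector, 28 =
        // queue_enable, 32/40/48 = descriptor/available/used ring addresses);
        // BAR2 holds the MSI-X table (16-byte entries: address, data, control).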
        fn setup_pci_device(&self, dev: &mut VirtioPciTestDevice, driver_features: u64) {
            let bar_address1: u64 = 0x10000000000;
            dev.pci_device
                .pci_cfg_write(0x14, (bar_address1 >> 32) as u32)
                .unwrap();
            dev.pci_device
                .pci_cfg_write(0x10, bar_address1 as u32)
                .unwrap();

            let bar_address2: u64 = 0x20000000000;
            dev.pci_device
                .pci_cfg_write(0x1c, (bar_address2 >> 32) as u32)
                .unwrap();
            dev.pci_device
                .pci_cfg_write(0x18, bar_address2 as u32)
                .unwrap();

            dev.pci_device
                .pci_cfg_write(
                    0x4,
                    cfg_space::Command::new()
                        .with_mmio_enabled(true)
                        .into_bits() as u32,
                )
                .unwrap();

            let mut device_status = VIRTIO_ACKNOWLEDGE as u8;
            dev.pci_device
                .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
                .unwrap();
            device_status = VIRTIO_DRIVER as u8;
            dev.pci_device
                .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
                .unwrap();
            dev.write_u32(bar_address1 + 8, 0);
            dev.write_u32(bar_address1 + 12, driver_features as u32);
            dev.write_u32(bar_address1 + 8, 1);
            dev.write_u32(bar_address1 + 12, (driver_features >> 32) as u32);
            device_status = VIRTIO_FEATURES_OK as u8;
            dev.pci_device
                .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
                .unwrap();
            // setup config interrupt
            dev.pci_device
                .mmio_write(bar_address2, &0_u64.to_le_bytes())
                .unwrap(); // vector
            dev.pci_device
                .mmio_write(bar_address2 + 8, &0_u32.to_le_bytes())
                .unwrap(); // data
            dev.pci_device
                .mmio_write(bar_address2 + 12, &0_u32.to_le_bytes())
                .unwrap();
            for i in 0..self.num_queues {
                let queue_index = i;
                dev.pci_device
                    .mmio_write(bar_address1 + 22, &queue_index.to_le_bytes())
                    .unwrap();
                dev.pci_device
                    .mmio_write(bar_address1 + 24, &self.queue_size.to_le_bytes())
                    .unwrap();
                // setup MSI information for the queue
                let msix_vector = queue_index + 1;
                let address = bar_address2 + 0x10 * msix_vector as u64;
                dev.pci_device
                    .mmio_write(address, &(msix_vector as u64).to_le_bytes())
                    .unwrap();
                let address = bar_address2 + 0x10 * msix_vector as u64 + 8;
                dev.pci_device
                    .mmio_write(address, &0_u32.to_le_bytes())
                    .unwrap();
                let address = bar_address2 + 0x10 * msix_vector as u64 + 12;
                dev.pci_device
                    .mmio_write(address, &0_u32.to_le_bytes())
                    .unwrap();
                dev.pci_device
                    .mmio_write(bar_address1 + 26, &msix_vector.to_le_bytes())
                    .unwrap();
                // setup queue addresses
                let desc_addr = self.get_queue_descriptor_base_address(queue_index);
                dev.write_u32(bar_address1 + 32, desc_addr as u32);
                dev.write_u32(bar_address1 + 36, (desc_addr >> 32) as u32);
                let avail_addr = self.get_queue_available_base_address(queue_index);
                dev.write_u32(bar_address1 + 40, avail_addr as u32);
                dev.write_u32(bar_address1 + 44, (avail_addr >> 32) as u32);
                let used_addr = self.get_queue_used_base_address(queue_index);
                dev.write_u32(bar_address1 + 48, used_addr as u32);
                dev.write_u32(bar_address1 + 52, (used_addr >> 32) as u32);
                // enable the queue
                let enabled: u16 = 1;
                dev.pci_device
                    .mmio_write(bar_address1 + 28, &enabled.to_le_bytes())
                    .unwrap();
            }
            // enable all device MSI interrupts
            dev.pci_device.pci_cfg_write(0x40, 0x80000000).unwrap();
            // run device
            device_status = VIRTIO_DRIVER_OK as u8;
            dev.pci_device
                .mmio_write(bar_address1 + 20, &device_status.to_le_bytes())
                .unwrap();
            let mut config_generation: [u8; 1] = [0];
            dev.pci_device
                .mmio_read(bar_address1 + 21, &mut config_generation)
                .unwrap();
            assert_eq!(config_generation[0], 2);
        }

        fn get_queue_descriptor(&self, queue_index: u16, descriptor_index: u16) -> u64 {
            self.get_queue_descriptor_base_address(queue_index) + 0x10 * descriptor_index as u64
        }

        fn add_queue_memory(&self, queue_index: u16) {
            // descriptors
            for i in 0..self.queue_size {
                let base = self.get_queue_descriptor(queue_index, i);
                // physical address
                self.test_mem.modify_memory_map(
                    base,
                    &(self.get_queue_descriptor_backing_memory_address(queue_index)
                        + 0x1000 * i as u64)
                        .to_le_bytes(),
                    false,
                );
                // length
                self.test_mem
                    .modify_memory_map(base + 8, &0x1000u32.to_le_bytes(), false);
                // flags
                self.test_mem
                    .modify_memory_map(base + 12, &0u16.to_le_bytes(), false);
                // next index
                self.test_mem
                    .modify_memory_map(base + 14, &0u16.to_le_bytes(), false);
            }

            // available queue (flags, index)
            let base = self.get_queue_available_base_address(queue_index);
            self.test_mem
                .modify_memory_map(base, &0u16.to_le_bytes(), false);
            self.test_mem
                .modify_memory_map(base + 2, &0u16.to_le_bytes(), false);
            // available queue ring buffer
            for i in 0..self.queue_size {
                let base = base + 4 + 2 * i as u64;
                self.test_mem
                    .modify_memory_map(base, &0u16.to_le_bytes(), false);
            }
            // used event
            if self.use_ring_event_index {
                self.test_mem.modify_memory_map(
                    base + 4 + 2 * self.queue_size as u64,
                    &0u16.to_le_bytes(),
                    false,
                );
            }

            // used queue (flags, index)
            let base = self.get_queue_used_base_address(queue_index);
            self.test_mem
                .modify_memory_map(base, &0u16.to_le_bytes(), true);
            self.test_mem
                .modify_memory_map(base + 2, &0u16.to_le_bytes(), true);
            for i in 0..self.queue_size {
                let base = base + 4 + 8 * i as u64;
                // index
                self.test_mem
                    .modify_memory_map(base, &0u32.to_le_bytes(), true);
                // length
                self.test_mem
                    .modify_memory_map(base + 4, &0u32.to_le_bytes(), true);
            }
            // available event
            if self.use_ring_event_index {
                self.test_mem.modify_memory_map(
                    base + 4 + 8 * self.queue_size as u64,
                    &0u16.to_le_bytes(),
                    true,
                );
            }
        }

        fn reserve_descriptor(&mut self, queue_index: u16) -> u16 {
            let avail_descriptors = &mut self.avail_descriptors[queue_index as usize];
            for (i, desc) in avail_descriptors.iter_mut().enumerate() {
                if *desc {
                    *desc = false;
                    return i as u16;
                }
            }

            panic!("No descriptors are available!");
        }

        fn free_descriptor(&mut self, queue_index: u16, desc_index: u16) {
            assert!(desc_index < self.queue_size);
            let desc_addr = self.get_queue_descriptor(queue_index, desc_index);
            let flags: DescriptorFlags = self.test_mem.memory_map_get_u16(desc_addr + 12).into();
            if flags.next() {
                let next = self.test_mem.memory_map_get_u16(desc_addr + 14);
                self.free_descriptor(queue_index, next);
            }
            let avail_descriptors = &mut self.avail_descriptors[queue_index as usize];
            assert!(!avail_descriptors[desc_index as usize]);
            avail_descriptors[desc_index as usize] = true;
        }

        fn queue_available_desc(&mut self, queue_index: u16, desc_index: u16) {
            let avail_base_addr = self.get_queue_available_base_address(queue_index);
            let last_avail_index = &mut self.last_avail_index[queue_index as usize];
            let next_index = *last_avail_index % self.queue_size;
            *last_avail_index = last_avail_index.wrapping_add(1);
            self.test_mem.modify_memory_map(
                avail_base_addr + 4 + 2 * next_index as u64,
                &desc_index.to_le_bytes(),
                false,
            );
            self.test_mem.modify_memory_map(
                avail_base_addr + 2,
                &last_avail_index.to_le_bytes(),
                false,
            );
        }

        fn add_to_avail_queue(&mut self, queue_index: u16) {
            let next_descriptor = self.reserve_descriptor(queue_index);
            // flags
            self.test_mem.modify_memory_map(
                self.get_queue_descriptor(queue_index, next_descriptor) + 12,
                &0u16.to_le_bytes(),
                false,
            );
            self.queue_available_desc(queue_index, next_descriptor);
        }

        fn add_indirect_to_avail_queue(&mut self, queue_index: u16) {
            let next_descriptor = self.reserve_descriptor(queue_index);
            // flags
            self.test_mem.modify_memory_map(
                self.get_queue_descriptor(queue_index, next_descriptor) + 12,
                &u16::from(DescriptorFlags::new().with_indirect(true)).to_le_bytes(),
                false,
            );
            // create another (indirect) descriptor in the buffer
            let buffer_addr = self.get_queue_descriptor_backing_memory_address(queue_index);
            // physical address
            self.test_mem.modify_memory_map(
                buffer_addr,
                &0xffffffff00000000u64.to_le_bytes(),
                false,
            );
            // length
            self.test_mem
                .modify_memory_map(buffer_addr + 8, &0x1000u32.to_le_bytes(), false);
            // flags
            self.test_mem
                .modify_memory_map(buffer_addr + 12, &0u16.to_le_bytes(), false);
            // next index
            self.test_mem
                .modify_memory_map(buffer_addr + 14, &0u16.to_le_bytes(), false);
            self.queue_available_desc(queue_index, next_descriptor);
        }

        fn add_linked_to_avail_queue(&mut self, queue_index: u16, desc_count: u16) {
            let mut descriptors = Vec::with_capacity(desc_count as usize);
            for _ in 0..desc_count {
                descriptors.push(self.reserve_descriptor(queue_index));
            }

            for i in 0..descriptors.len() {
                let base = self.get_queue_descriptor(queue_index, descriptors[i]);
                let flags = if i < descriptors.len() - 1 {
                    u16::from(DescriptorFlags::new().with_next(true))
                } else {
                    0
                };
                self.test_mem
                    .modify_memory_map(base + 12, &flags.to_le_bytes(), false);
                let next = if i < descriptors.len() - 1 {
                    descriptors[i + 1]
                } else {
                    0
                };
                self.test_mem
                    .modify_memory_map(base + 14, &next.to_le_bytes(), false);
            }
            self.queue_available_desc(queue_index, descriptors[0]);
        }

        fn add_indirect_linked_to_avail_queue(&mut self, queue_index: u16, desc_count: u16) {
            let next_descriptor = self.reserve_descriptor(queue_index);
            // flags
            self.test_mem.modify_memory_map(
                self.get_queue_descriptor(queue_index, next_descriptor) + 12,
                &u16::from(DescriptorFlags::new().with_indirect(true)).to_le_bytes(),
                false,
            );
            // create indirect descriptors in the buffer
            let buffer_addr = self.get_queue_descriptor_backing_memory_address(queue_index);
            for i in 0..desc_count {
                let base = buffer_addr + 0x10 * i as u64;
                let indirect_buffer_addr = 0xffffffff00000000u64 + 0x1000 * i as u64;
                // physical address
                self.test_mem
                    .modify_memory_map(base, &indirect_buffer_addr.to_le_bytes(), false);
                // length
                self.test_mem
                    .modify_memory_map(base + 8, &0x1000u32.to_le_bytes(), false);
                // flags
                let flags = if i < desc_count - 1 {
                    u16::from(DescriptorFlags::new().with_next(true))
                } else {
                    0
                };
                self.test_mem
                    .modify_memory_map(base + 12, &flags.to_le_bytes(), false);
                // next index
                let next = if i < desc_count - 1 { i + 1 } else { 0 };
                self.test_mem
                    .modify_memory_map(base + 14, &next.to_le_bytes(), false);
            }
            self.queue_available_desc(queue_index, next_descriptor);
        }

        fn get_next_completed(&mut self, queue_index: u16) -> Option<(u16, u32)> {
            let avail_base_addr = self.get_queue_available_base_address(queue_index);
            let used_base_addr = self.get_queue_used_base_address(queue_index);
            let cur_used_index = self.test_mem.memory_map_get_u16(used_base_addr + 2);
            let last_used_index = &mut self.last_used_index[queue_index as usize];
            if *last_used_index == cur_used_index {
                return None;
            }

            if self.use_ring_event_index {
                self.test_mem.modify_memory_map(
                    avail_base_addr + 4 + 2 * self.queue_size as u64,
                    &cur_used_index.to_le_bytes(),
                    false,
                );
            }

            let next_index = *last_used_index % self.queue_size;
            *last_used_index = last_used_index.wrapping_add(1);
            let desc_index = self
                .test_mem
                .memory_map_get_u32(used_base_addr + 4 + 8 * next_index as u64);
            let desc_index = desc_index as u16;
            let bytes_written = self
                .test_mem
                .memory_map_get_u32(used_base_addr + 8 + 8 * next_index as u64);
            self.free_descriptor(queue_index, desc_index);
            Some((desc_index, bytes_written))
        }
    }

    struct VirtioTestWork {
        callback: VirtioTestWorkCallback,
    }

    #[async_trait]
    impl VirtioQueueWorkerContext for VirtioTestWork {
        async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
            (self.callback)(work)
        }
    }

    struct VirtioPciTestDevice {
        pci_device: VirtioPciDevice,
        test_intc: Arc<TestPciInterruptController>,
    }

    type TestDeviceQueueWorkFn = Arc<dyn Fn(u16, VirtioQueueCallbackWork) + Send + Sync>;

    struct TestDevice {
        traits: DeviceTraits,
        queue_work: Option<TestDeviceQueueWorkFn>,
    }

    impl TestDevice {
        fn new(traits: DeviceTraits, queue_work: Option<TestDeviceQueueWorkFn>) -> Self {
            Self { traits, queue_work }
        }
    }

    impl LegacyVirtioDevice for TestDevice {
        fn traits(&self) -> DeviceTraits {
            self.traits
        }

        fn read_registers_u32(&self, _offset: u16) -> u32 {
            0
        }

        fn write_registers_u32(&mut self, _offset: u16, _val: u32) {}

        fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send> {
            Box::new(TestDeviceWorker {
                index,
                queue_work: self.queue_work.clone(),
            })
        }

        fn state_change(&mut self, _state: &VirtioState) {}
    }

    struct TestDeviceWorker {
        index: u16,
        queue_work: Option<TestDeviceQueueWorkFn>,
    }

    #[async_trait]
    impl VirtioQueueWorkerContext for TestDeviceWorker {
        async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
            if let Err(err) = work {
                panic!(
                    "Invalid virtio queue state index {} error {}",
                    self.index,
                    err.as_ref() as &dyn std::error::Error
                );
            }
            if let Some(ref func) = self.queue_work {
                (func)(self.index, work.unwrap());
            }
            true
        }
    }

    impl VirtioPciTestDevice {
        fn new(
            driver: &DefaultDriver,
            num_queues: u16,
            test_mem: &Arc<VirtioTestMemoryAccess>,
            queue_work: Option<TestDeviceQueueWorkFn>,
        ) -> Self {
            let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
            let mem = GuestMemory::new("test", test_mem.clone());
            let mut msi_set = MsiInterruptSet::new();

            let dev = VirtioPciDevice::new(
                Box::new(LegacyWrapper::new(
                    &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())),
                    TestDevice::new(
                        DeviceTraits {
                            device_id: 3,
                            device_features: 2,
                            max_queues: num_queues,
                            device_register_length: 12,
                            ..Default::default()
                        },
                        queue_work,
                    ),
                    &mem,
                )),
                PciInterruptModel::Msix(&mut msi_set),
                Some(doorbell_registration),
                &mut ExternallyManagedMmioIntercepts,
                None,
            )
            .unwrap();

            let test_intc = Arc::new(TestPciInterruptController::new());
            msi_set.connect(test_intc.as_ref());

            Self {
                pci_device: dev,
                test_intc,
            }
        }

        fn read_u32(&mut self, address: u64) -> u32 {
            let mut value = [0; 4];
            self.pci_device.mmio_read(address, &mut value).unwrap();
            u32::from_ne_bytes(value)
        }

        fn write_u32(&mut self, address: u64, value: u32) {
            self.pci_device
                .mmio_write(address, &value.to_ne_bytes())
                .unwrap();
        }
    }

    #[async_test]
    async fn verify_chipset_config(driver: DefaultDriver) {
        let mem = VirtioTestMemoryAccess::new();
        let doorbell_registration: Arc<dyn DoorbellRegistration> = mem.clone();
        let mem = GuestMemory::new("test", mem);
        let interrupt = LineInterrupt::detached();

        let mut dev = VirtioMmioDevice::new(
            Box::new(LegacyWrapper::new(
                &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
                TestDevice::new(
                    DeviceTraits {
                        device_id: 3,
                        device_features: 2,
                        max_queues: 1,
                        device_register_length: 0,
                        ..Default::default()
                    },
                    None,
                ),
                &mem,
            )),
            interrupt,
            Some(doorbell_registration),
            0,
            1,
        );
        // magic value
        assert_eq!(dev.read_u32(0), u32::from_le_bytes(*b"virt"));
        // version
        assert_eq!(dev.read_u32(4), 2);
        // device ID
        assert_eq!(dev.read_u32(8), 3);
        // vendor ID
        assert_eq!(dev.read_u32(12), 0x1af4);
        // device feature (bank 0)
        assert_eq!(
            dev.read_u32(16),
            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
        );
        // device feature bank index
        assert_eq!(dev.read_u32(20), 0);
        // device feature (bank 1)
        dev.write_u32(20, 1);
        assert_eq!(dev.read_u32(20), 1);
        assert_eq!(dev.read_u32(16), VIRTIO_F_VERSION_1);
        // device feature (bank 2)
        dev.write_u32(20, 2);
        assert_eq!(dev.read_u32(16), 0);
        // driver feature (bank 0)
        assert_eq!(dev.read_u32(32), 0);
        dev.write_u32(32, 2);
        assert_eq!(dev.read_u32(32), 2);
        dev.write_u32(32, 0xffffffff);
        assert_eq!(
            dev.read_u32(32),
            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
        );
        // driver feature bank index
        assert_eq!(dev.read_u32(36), 0);
        dev.write_u32(36, 1);
        assert_eq!(dev.read_u32(36), 1);
        // driver feature (bank 1)
        assert_eq!(dev.read_u32(32), 0);
        dev.write_u32(32, 0xffffffff);
        assert_eq!(dev.read_u32(32), VIRTIO_F_VERSION_1);
        // driver feature (bank 2)
        dev.write_u32(36, 2);
        assert_eq!(dev.read_u32(32), 0);
        dev.write_u32(32, 0xffffffff);
        assert_eq!(dev.read_u32(32), 0);
        // host notify
        assert_eq!(dev.read_u32(80), 0);
        // interrupt status
        assert_eq!(dev.read_u32(96), 0);
        // interrupt ACK
        assert_eq!(dev.read_u32(100), 0);
        // device status
        assert_eq!(dev.read_u32(112), 0);
        // config generation
        assert_eq!(dev.read_u32(0xfc), 0);

        // queue index
        assert_eq!(dev.read_u32(48), 0);
        // queue max size (queue 0)
        assert_eq!(dev.read_u32(52), 0x40);
        // queue size (queue 0)
        assert_eq!(dev.read_u32(56), 0x40);
        dev.write_u32(56, 0x20);
        assert_eq!(dev.read_u32(56), 0x20);
        // queue enable (queue 0)
        assert_eq!(dev.read_u32(68), 0);
        dev.write_u32(68, 1);
        assert_eq!(dev.read_u32(68), 1);
        dev.write_u32(68, 0xffffffff);
        assert_eq!(dev.read_u32(68), 1);
        dev.write_u32(68, 0);
        assert_eq!(dev.read_u32(68), 0);
        // queue descriptor address low (queue 0)
        assert_eq!(dev.read_u32(128), 0);
        dev.write_u32(128, 0xffff);
        assert_eq!(dev.read_u32(128), 0xffff);
        // queue descriptor address high (queue 0)
        assert_eq!(dev.read_u32(132), 0);
        dev.write_u32(132, 1);
        assert_eq!(dev.read_u32(132), 1);
        // queue available address low (queue 0)
        assert_eq!(dev.read_u32(144), 0);
        dev.write_u32(144, 0xeeee);
        assert_eq!(dev.read_u32(144), 0xeeee);
        // queue available address high (queue 0)
        assert_eq!(dev.read_u32(148), 0);
        dev.write_u32(148, 2);
        assert_eq!(dev.read_u32(148), 2);
        // queue used address low (queue 0)
        assert_eq!(dev.read_u32(160), 0);
        dev.write_u32(160, 0xdddd);
        assert_eq!(dev.read_u32(160), 0xdddd);
        // queue used address high (queue 0)
        assert_eq!(dev.read_u32(164), 0);
        dev.write_u32(164, 3);
        assert_eq!(dev.read_u32(164), 3);

        // switch to queue #1
        dev.write_u32(48, 1);
        assert_eq!(dev.read_u32(48), 1);
        // queue max size (queue 1)
        assert_eq!(dev.read_u32(52), 0);
        // queue size (queue 1)
        assert_eq!(dev.read_u32(56), 0);
        dev.write_u32(56, 2);
        assert_eq!(dev.read_u32(56), 0);
        // queue enable (queue 1)
        assert_eq!(dev.read_u32(68), 0);
        dev.write_u32(68, 1);
        assert_eq!(dev.read_u32(68), 0);
        // queue descriptor address low (queue 1)
        assert_eq!(dev.read_u32(128), 0);
        dev.write_u32(128, 1);
        assert_eq!(dev.read_u32(128), 0);
        // queue descriptor address high (queue 1)
        assert_eq!(dev.read_u32(132), 0);
        dev.write_u32(132, 1);
        assert_eq!(dev.read_u32(132), 0);
        // queue available address low (queue 1)
        assert_eq!(dev.read_u32(144), 0);
        dev.write_u32(144, 1);
        assert_eq!(dev.read_u32(144), 0);
        // queue available address high (queue 1)
        assert_eq!(dev.read_u32(148), 0);
        dev.write_u32(148, 1);
        assert_eq!(dev.read_u32(148), 0);
        // queue used address low (queue 1)
        assert_eq!(dev.read_u32(160), 0);
        dev.write_u32(160, 1);
        assert_eq!(dev.read_u32(160), 0);
        // queue used address high (queue 1)
        assert_eq!(dev.read_u32(164), 0);
        dev.write_u32(164, 1);
        assert_eq!(dev.read_u32(164), 0);
    }

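    // Walks the PCI capability chain and verifies the MSI-X capability plus
    // the virtio vendor-specific capabilities (common, notify, ISR, and
    // device config) with their expected BAR offsets and lengths.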
    #[async_test]
    async fn verify_pci_config(driver: DefaultDriver) {
        let mut pci_test_device =
            VirtioPciTestDevice::new(&driver, 1, &VirtioTestMemoryAccess::new(), None);
        let mut capabilities = 0;
        pci_test_device
            .pci_device
            .pci_cfg_read(4, &mut capabilities)
            .unwrap();
        assert_eq!(
            capabilities,
            (cfg_space::Status::new()
                .with_capabilities_list(true)
                .into_bits() as u32)
                << 16
        );
        let mut next_cap_offset = 0;
        pci_test_device
            .pci_device
            .pci_cfg_read(0x34, &mut next_cap_offset)
            .unwrap();
        assert_ne!(next_cap_offset, 0);

        let mut header = 0;
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16, &mut header)
            .unwrap();
        let header = header.to_le_bytes();
        assert_eq!(header[0], CapabilityId::MSIX.0);
        next_cap_offset = header[1] as u32;
        assert_ne!(next_cap_offset, 0);

        let mut header = 0;
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16, &mut header)
            .unwrap();
        let header = header.to_le_bytes();
        assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
        assert_eq!(header[3], VIRTIO_PCI_CAP_COMMON_CFG);
        assert_eq!(header[2], 16);
        let mut buf = 0;

        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
            .unwrap();
        assert_eq!(buf, 0);
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
            .unwrap();
        assert_eq!(buf, 0);
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
            .unwrap();
        assert_eq!(buf, 0x38);
        next_cap_offset = header[1] as u32;
        assert_ne!(next_cap_offset, 0);

        let mut header = 0;
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16, &mut header)
            .unwrap();
        let header = header.to_le_bytes();
        assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
        assert_eq!(header[3], VIRTIO_PCI_CAP_NOTIFY_CFG);
        assert_eq!(header[2], 20);
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
            .unwrap();
        assert_eq!(buf, 0);
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
            .unwrap();
        assert_eq!(buf, 0x38);
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
            .unwrap();
        assert_eq!(buf, 4);
        next_cap_offset = header[1] as u32;
        assert_ne!(next_cap_offset, 0);

        let mut header = 0;
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16, &mut header)
            .unwrap();
        let header = header.to_le_bytes();
        assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
        assert_eq!(header[3], VIRTIO_PCI_CAP_ISR_CFG);
        assert_eq!(header[2], 16);
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
            .unwrap();
        assert_eq!(buf, 0);
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
            .unwrap();
        assert_eq!(buf, 0x3c);
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
            .unwrap();
        assert_eq!(buf, 4);
        next_cap_offset = header[1] as u32;
        assert_ne!(next_cap_offset, 0);

        let mut header = 0;
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16, &mut header)
            .unwrap();
        let header = header.to_le_bytes();
        assert_eq!(header[0], CapabilityId::VENDOR_SPECIFIC.0);
        assert_eq!(header[3], VIRTIO_PCI_CAP_DEVICE_CFG);
        assert_eq!(header[2], 16);
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 4, &mut buf)
            .unwrap();
        assert_eq!(buf, 0);
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 8, &mut buf)
            .unwrap();
        assert_eq!(buf, 0x40);
        pci_test_device
            .pci_device
            .pci_cfg_read(next_cap_offset as u16 + 12, &mut buf)
            .unwrap();
        assert_eq!(buf, 12);
        next_cap_offset = header[1] as u32;
        assert_eq!(next_cap_offset, 0);
    }

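    // Programs the BARs, enables MMIO, and exercises the virtio common
    // configuration registers through BAR 0, including that a queue index
    // beyond the device's queue count reads as zero and ignores writes.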
    #[async_test]
    async fn verify_pci_registers(driver: DefaultDriver) {
        let mut pci_test_device =
            VirtioPciTestDevice::new(&driver, 1, &VirtioTestMemoryAccess::new(), None);
        let bar_address1: u64 = 0x2000000000;
        pci_test_device
            .pci_device
            .pci_cfg_write(0x14, (bar_address1 >> 32) as u32)
            .unwrap();
        pci_test_device
            .pci_device
            .pci_cfg_write(0x10, bar_address1 as u32)
            .unwrap();

        let bar_address2: u64 = 0x4000;
        pci_test_device
            .pci_device
            .pci_cfg_write(0x1c, (bar_address2 >> 32) as u32)
            .unwrap();
        pci_test_device
            .pci_device
            .pci_cfg_write(0x18, bar_address2 as u32)
            .unwrap();

        pci_test_device
            .pci_device
            .pci_cfg_write(
                0x4,
                cfg_space::Command::new()
                    .with_mmio_enabled(true)
                    .into_bits() as u32,
            )
            .unwrap();

        // device feature bank index
        assert_eq!(pci_test_device.read_u32(bar_address1), 0);
        // device feature (bank 0)
        assert_eq!(
            pci_test_device.read_u32(bar_address1 + 4),
            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
        );
        // device feature (bank 1)
        pci_test_device.write_u32(bar_address1, 1);
        assert_eq!(pci_test_device.read_u32(bar_address1), 1);
        assert_eq!(
            pci_test_device.read_u32(bar_address1 + 4),
            VIRTIO_F_VERSION_1
        );
        // device feature (bank 2)
        pci_test_device.write_u32(bar_address1, 2);
        assert_eq!(pci_test_device.read_u32(bar_address1), 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 4), 0);
        // driver feature bank index
        assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 0);
        // driver feature (bank 0)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
        pci_test_device.write_u32(bar_address1 + 12, 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 2);
        pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
        assert_eq!(
            pci_test_device.read_u32(bar_address1 + 12),
            VIRTIO_F_RING_INDIRECT_DESC | VIRTIO_F_RING_EVENT_IDX | 2
        );
        // driver feature (bank 1)
        pci_test_device.write_u32(bar_address1 + 8, 1);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 1);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
        pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
        assert_eq!(
            pci_test_device.read_u32(bar_address1 + 12),
            VIRTIO_F_VERSION_1
        );
        // driver feature (bank 2)
        pci_test_device.write_u32(bar_address1 + 8, 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 8), 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
        pci_test_device.write_u32(bar_address1 + 12, 0xffffffff);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 12), 0);
        // max queues and the msix vector for config changes
        assert_eq!(pci_test_device.read_u32(bar_address1 + 16), 1 << 16);
        // queue index, config generation and device status
        assert_eq!(pci_test_device.read_u32(bar_address1 + 20), 0);
        // current queue size and msix vector
        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0x40);
        pci_test_device.write_u32(bar_address1 + 24, 0x20);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0x20);
        // current queue enabled and notify offset
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
        pci_test_device.write_u32(bar_address1 + 28, 1);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 1);
        pci_test_device.write_u32(bar_address1 + 28, 0xffff);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 1);
        pci_test_device.write_u32(bar_address1 + 28, 0);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
        // current queue descriptor table address (low)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
        pci_test_device.write_u32(bar_address1 + 32, 0xffff);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0xffff);
        // current queue descriptor table address (high)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
        pci_test_device.write_u32(bar_address1 + 36, 1);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 1);
        // current queue available ring address (low)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
        pci_test_device.write_u32(bar_address1 + 40, 0xeeee);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0xeeee);
        // current queue available ring address (high)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
        pci_test_device.write_u32(bar_address1 + 44, 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 2);
        // current queue used ring address (low)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
        pci_test_device.write_u32(bar_address1 + 48, 0xdddd);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0xdddd);
        // current queue used ring address (high)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
        pci_test_device.write_u32(bar_address1 + 52, 3);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 3);
        // VIRTIO_PCI_CAP_NOTIFY_CFG notification register
        assert_eq!(pci_test_device.read_u32(bar_address1 + 56), 0);
        // VIRTIO_PCI_CAP_ISR_CFG register
        assert_eq!(pci_test_device.read_u32(bar_address1 + 60), 0);

        // switch to queue #1 (disabled, only one queue on this device)
        let queue_index: u16 = 1;
        pci_test_device
            .pci_device
            .mmio_write(bar_address1 + 22, &queue_index.to_le_bytes())
            .unwrap();
        assert_eq!(pci_test_device.read_u32(bar_address1 + 20), 1 << 24);
        // current queue size and msix vector
        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0);
        pci_test_device.write_u32(bar_address1 + 24, 2);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 24), 0);
        // current queue enabled and notify offset
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
        pci_test_device.write_u32(bar_address1 + 28, 1);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 28), 0);
        // current queue descriptor table address (low)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
        pci_test_device.write_u32(bar_address1 + 32, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 32), 0);
        // current queue descriptor table address (high)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
        pci_test_device.write_u32(bar_address1 + 36, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 36), 0);
        // current queue available ring address (low)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
        pci_test_device.write_u32(bar_address1 + 40, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 40), 0);
        // current queue available ring address (high)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
        pci_test_device.write_u32(bar_address1 + 44, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 44), 0);
        // current queue used ring address (low)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
        pci_test_device.write_u32(bar_address1 + 48, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 48), 0);
        // current queue used ring address (high)
        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
        pci_test_device.write_u32(bar_address1 + 52, 0x10);
        assert_eq!(pci_test_device.read_u32(bar_address1 + 52), 0);
    }

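    // Submits a single direct descriptor and verifies the completion length
    // reported back through the used ring.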
    #[async_test]
    async fn verify_queue_simple(driver: DefaultDriver) {
        let test_mem = VirtioTestMemoryAccess::new();
        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
        let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
        let (tx, mut rx) = mesh::mpsc_channel();
        let event = Event::new();
        let mut queues = guest.create_direct_queues(|i| {
            let tx = tx.clone();
            CreateDirectQueueParams {
                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
                    let mut work = work.expect("Queue failure");
                    assert_eq!(work.payload.len(), 1);
                    assert_eq!(work.payload[0].address, base_addr);
                    assert_eq!(work.payload[0].length, 0x1000);
                    work.complete(123);
                    true
                }),
                notify: Interrupt::from_fn(move || {
                    tx.send(i as usize);
                }),
                event: event.clone(),
            }
        });

        guest.add_to_avail_queue(0);
        event.signal();
        must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
        let (desc, len) = guest.get_next_completed(0).unwrap();
        assert_eq!(desc, 0u16);
        assert_eq!(len, 123);
        assert!(guest.get_next_completed(0).is_none());
        queues[0].stop().await;
    }

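    // Same flow as verify_queue_simple, but the buffer is described through
    // an indirect descriptor table.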
    #[async_test]
    async fn verify_queue_indirect(driver: DefaultDriver) {
        let test_mem = VirtioTestMemoryAccess::new();
        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
        let (tx, mut rx) = mesh::mpsc_channel();
        let event = Event::new();
        let mut queues = guest.create_direct_queues(|i| {
            let tx = tx.clone();
            CreateDirectQueueParams {
                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
                    let mut work = work.expect("Queue failure");
                    assert_eq!(work.payload.len(), 1);
                    assert_eq!(work.payload[0].address, 0xffffffff00000000u64);
                    assert_eq!(work.payload[0].length, 0x1000);
                    work.complete(123);
                    true
                }),
                notify: Interrupt::from_fn(move || {
                    tx.send(i as usize);
                }),
                event: event.clone(),
            }
        });

        guest.add_indirect_to_avail_queue(0);
        event.signal();
        must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
        let (desc, len) = guest.get_next_completed(0).unwrap();
        assert_eq!(desc, 0u16);
        assert_eq!(len, 123);
        assert!(guest.get_next_completed(0).is_none());
        queues[0].stop().await;
    }

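    // Submits a chain of three linked descriptors and verifies they are
    // delivered as one work item with three payload segments.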
    #[async_test]
    async fn verify_queue_linked(driver: DefaultDriver) {
        let test_mem = VirtioTestMemoryAccess::new();
        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 5, true);
        let (tx, mut rx) = mesh::mpsc_channel();
        let base_address = guest.get_queue_descriptor_backing_memory_address(0);
        let event = Event::new();
        let mut queues = guest.create_direct_queues(|i| {
            let tx = tx.clone();
            CreateDirectQueueParams {
                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
                    let mut work = work.expect("Queue failure");
                    assert_eq!(work.payload.len(), 3);
                    for i in 0..work.payload.len() {
                        assert_eq!(work.payload[i].address, base_address + 0x1000 * i as u64);
                        assert_eq!(work.payload[i].length, 0x1000);
                    }
                    work.complete(123 * 3);
                    true
                }),
                notify: Interrupt::from_fn(move || {
                    tx.send(i as usize);
                }),
                event: event.clone(),
            }
        });

        guest.add_linked_to_avail_queue(0, 3);
        event.signal();
        must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
        let (desc, len) = guest.get_next_completed(0).unwrap();
        assert_eq!(desc, 0u16);
        assert_eq!(len, 123 * 3);
        assert!(guest.get_next_completed(0).is_none());
        queues[0].stop().await;
    }

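    // Combines the two previous cases: a three-descriptor chain published
    // through an indirect descriptor table.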
    #[async_test]
    async fn verify_queue_indirect_linked(driver: DefaultDriver) {
        let test_mem = VirtioTestMemoryAccess::new();
        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 5, true);
        let (tx, mut rx) = mesh::mpsc_channel();
        let event = Event::new();
        let mut queues = guest.create_direct_queues(|i| {
            let tx = tx.clone();
            CreateDirectQueueParams {
                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
                    let mut work = work.expect("Queue failure");
                    assert_eq!(work.payload.len(), 3);
                    for i in 0..work.payload.len() {
                        assert_eq!(
                            work.payload[i].address,
                            0xffffffff00000000u64 + 0x1000 * i as u64
                        );
                        assert_eq!(work.payload[i].length, 0x1000);
                    }
                    work.complete(123 * 3);
                    true
                }),
                notify: Interrupt::from_fn(move || {
                    tx.send(i as usize);
                }),
                event: event.clone(),
            }
        });

        guest.add_indirect_linked_to_avail_queue(0, 3);
        event.signal();
        must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
        let (desc, len) = guest.get_next_completed(0).unwrap();
        assert_eq!(desc, 0u16);
        assert_eq!(len, 123 * 3);
        assert!(guest.get_next_completed(0).is_none());
        queues[0].stop().await;
    }

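    // Submits and completes a descriptor three times through a two-entry
    // ring so that the available ring index wraps, verifying rollover
    // handling.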
    #[async_test]
    async fn verify_queue_avail_rollover(driver: DefaultDriver) {
        let test_mem = VirtioTestMemoryAccess::new();
        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
        let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
        let (tx, mut rx) = mesh::mpsc_channel();
        let event = Event::new();
        let mut queues = guest.create_direct_queues(|i| {
            let tx = tx.clone();
            CreateDirectQueueParams {
                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
                    let mut work = work.expect("Queue failure");
                    assert_eq!(work.payload.len(), 1);
                    assert_eq!(work.payload[0].address, base_addr);
                    assert_eq!(work.payload[0].length, 0x1000);
                    work.complete(123);
                    true
                }),
                notify: Interrupt::from_fn(move || {
                    tx.send(i as usize);
                }),
                event: event.clone(),
            }
        });

        for _ in 0..3 {
            guest.add_to_avail_queue(0);
            event.signal();
            must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
            let (desc, len) = guest.get_next_completed(0).unwrap();
            assert_eq!(desc, 0u16);
            assert_eq!(len, 123);
            assert!(guest.get_next_completed(0).is_none());
        }

        queues[0].stop().await;
    }

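    // Drives five queues in parallel, each with its own signaling event and
    // notify channel, and verifies completions are reported on the correct
    // queue.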
    #[async_test]
    async fn verify_multi_queue(driver: DefaultDriver) {
        let test_mem = VirtioTestMemoryAccess::new();
        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 5, 2, true);
        let (tx, mut rx) = mesh::mpsc_channel();
        let events = (0..guest.num_queues)
            .map(|_| Event::new())
            .collect::<Vec<_>>();
        let mut queues = guest.create_direct_queues(|queue_index| {
            let tx = tx.clone();
            let base_addr = guest.get_queue_descriptor_backing_memory_address(queue_index);
            CreateDirectQueueParams {
                process_work: Box::new(move |work: anyhow::Result<VirtioQueueCallbackWork>| {
                    let mut work = work.expect("Queue failure");
                    assert_eq!(work.payload.len(), 1);
                    assert_eq!(work.payload[0].address, base_addr);
                    assert_eq!(work.payload[0].length, 0x1000);
                    work.complete(123 * queue_index as u32);
                    true
                }),
                notify: Interrupt::from_fn(move || {
                    tx.send(queue_index as usize);
                }),
                event: events[queue_index as usize].clone(),
            }
        });

        for (i, event) in events.iter().enumerate() {
            let queue_index = i as u16;
            guest.add_to_avail_queue(queue_index);
            event.signal();
        }
        // wait for all queue processing to finish
        for _ in 0..guest.num_queues {
            must_recv_in_timeout(&mut rx, Duration::from_millis(100)).await;
        }
        // check results
        for queue_index in 0..guest.num_queues {
            let (desc, len) = guest.get_next_completed(queue_index).unwrap();
            assert_eq!(desc, 0u16);
            assert_eq!(len, 123 * queue_index as u32);
        }
        // verify no extraneous completions
        for (i, queue) in queues.iter_mut().enumerate() {
            let queue_index = i as u16;
            assert!(guest.get_next_completed(queue_index).is_none());
            queue.stop().await;
        }
    }

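    // Reads the MMIO interrupt status register and acknowledges the given bits.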
    fn take_mmio_interrupt_status(dev: &mut VirtioMmioDevice, mask: u32) -> u32 {
        let mut v = [0; 4];
        dev.mmio_read(96, &mut v).unwrap();
        dev.mmio_write(100, &mask.to_ne_bytes()).unwrap();
        u32::from_ne_bytes(v)
    }

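    // Waits for the device's line interrupt to assert, acknowledges the
    // expected status bits, and checks that the line drops unless further
    // interrupts are expected.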
    async fn expect_mmio_interrupt(
        dev: &mut VirtioMmioDevice,
        target: &TestLineInterruptTarget,
        mask: u32,
        multiple_expected: bool,
    ) {
        poll_fn(|cx| target.poll_high(cx, 0)).await;
        let v = take_mmio_interrupt_status(dev, mask);
        assert_eq!(v & mask, mask);
        assert!(multiple_expected || !target.is_high(0));
    }

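    // End-to-end virtio-mmio test: negotiates features through the chipset
    // registers, kicks queue 0 via the notify register, and waits for the
    // used-buffer interrupt before reading the completion.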
    #[async_test]
    async fn verify_device_queue_simple(driver: DefaultDriver) {
        let test_mem = VirtioTestMemoryAccess::new();
        let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
        let mut guest = VirtioTestGuest::new(&driver, &test_mem, 1, 2, true);
        let mem = guest.mem();
        let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
        let target = TestLineInterruptTarget::new_arc();
        let interrupt = LineInterrupt::new_with_target("test", target.clone(), 0);
        let base_addr = guest.get_queue_descriptor_backing_memory_address(0);
        let queue_work = Arc::new(move |_: u16, mut work: VirtioQueueCallbackWork| {
            assert_eq!(work.payload.len(), 1);
            assert_eq!(work.payload[0].address, base_addr);
            assert_eq!(work.payload[0].length, 0x1000);
            work.complete(123);
        });
        let mut dev = VirtioMmioDevice::new(
            Box::new(LegacyWrapper::new(
                &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
                TestDevice::new(
                    DeviceTraits {
                        device_id: 3,
                        device_features: features,
                        max_queues: 1,
                        device_register_length: 0,
                        ..Default::default()
                    },
                    Some(queue_work),
                ),
                &mem,
            )),
            interrupt,
            Some(doorbell_registration),
            0,
            1,
        );

        guest.setup_chipset_device(&mut dev, features);
        expect_mmio_interrupt(
            &mut dev,
            &target,
            VIRTIO_MMIO_INTERRUPT_STATUS_CONFIG_CHANGE,
            false,
        )
        .await;
        guest.add_to_avail_queue(0);
        // notify device
        dev.write_u32(80, 0);
        expect_mmio_interrupt(
            &mut dev,
            &target,
            VIRTIO_MMIO_INTERRUPT_STATUS_USED_BUFFER,
            false,
        )
        .await;
        let (desc, len) = guest.get_next_completed(0).unwrap();
        assert_eq!(desc, 0u16);
        assert_eq!(len, 123);
        assert!(guest.get_next_completed(0).is_none());
        // reset the device
        dev.write_u32(112, 0);
        drop(dev);
    }

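    // End-to-end virtio-mmio test across five queues, draining used-buffer
    // interrupts until every queue reports its completion.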
    #[async_test]
    async fn verify_device_multi_queue(driver: DefaultDriver) {
        let num_queues = 5;
        let test_mem = VirtioTestMemoryAccess::new();
        let doorbell_registration: Arc<dyn DoorbellRegistration> = test_mem.clone();
        let mut guest = VirtioTestGuest::new(&driver, &test_mem, num_queues, 2, true);
        let mem = guest.mem();
        let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
        let target = TestLineInterruptTarget::new_arc();
        let interrupt = LineInterrupt::new_with_target("test", target.clone(), 0);
        let base_addr: Vec<_> = (0..num_queues)
            .map(|i| guest.get_queue_descriptor_backing_memory_address(i))
            .collect();
        let queue_work = Arc::new(move |i: u16, mut work: VirtioQueueCallbackWork| {
            assert_eq!(work.payload.len(), 1);
            assert_eq!(work.payload[0].address, base_addr[i as usize]);
            assert_eq!(work.payload[0].length, 0x1000);
            work.complete(123 * i as u32);
        });
        let mut dev = VirtioMmioDevice::new(
            Box::new(LegacyWrapper::new(
                &VmTaskDriverSource::new(SingleDriverBackend::new(driver)),
                TestDevice::new(
                    DeviceTraits {
                        device_id: 3,
                        device_features: features,
                        max_queues: num_queues + 1,
                        device_register_length: 0,
                        ..Default::default()
                    },
                    Some(queue_work),
                ),
                &mem,
            )),
            interrupt,
            Some(doorbell_registration),
            0,
            1,
        );
        guest.setup_chipset_device(&mut dev, features);
        expect_mmio_interrupt(
            &mut dev,
            &target,
            VIRTIO_MMIO_INTERRUPT_STATUS_CONFIG_CHANGE,
            false,
        )
        .await;
        for i in 0..num_queues {
            guest.add_to_avail_queue(i);
            // notify device
            dev.write_u32(80, i as u32);
        }
        // check results
        for i in 0..num_queues {
            let (desc, len) = loop {
                if let Some(x) = guest.get_next_completed(i) {
                    break x;
                }
                expect_mmio_interrupt(
                    &mut dev,
                    &target,
                    VIRTIO_MMIO_INTERRUPT_STATUS_USED_BUFFER,
                    i < (num_queues - 1),
                )
                .await;
            };
            assert_eq!(desc, 0u16);
            assert_eq!(len, 123 * i as u32);
        }
        // verify no extraneous completions
        for i in 0..num_queues {
            assert!(guest.get_next_completed(i).is_none());
        }
        // reset the device
        dev.write_u32(112, 0);
        drop(dev);
    }

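    // End-to-end virtio-pci test: configures queues through the common
    // configuration structure, kicks each queue via the notify region, and
    // observes MSI-X delivery through the test interrupt controller.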
    #[async_test]
    async fn verify_device_multi_queue_pci(driver: DefaultDriver) {
        let num_queues = 5;
        let test_mem = VirtioTestMemoryAccess::new();
        let mut guest = VirtioTestGuest::new(&driver, &test_mem, num_queues, 2, true);
        let features = ((VIRTIO_F_VERSION_1 as u64) << 32) | VIRTIO_F_RING_EVENT_IDX as u64 | 2;
        let base_addr: Vec<_> = (0..num_queues)
            .map(|i| guest.get_queue_descriptor_backing_memory_address(i))
            .collect();
        let mut dev = VirtioPciTestDevice::new(
            &driver,
            num_queues + 1,
            &test_mem,
            Some(Arc::new(move |i, mut work| {
                assert_eq!(work.payload.len(), 1);
                assert_eq!(work.payload[0].address, base_addr[i as usize]);
                assert_eq!(work.payload[0].length, 0x1000);
                work.complete(123 * i as u32);
            })),
        );

        guest.setup_pci_device(&mut dev, features);

        let mut timer = PolledTimer::new(&driver);

        // expect a configuration change interrupt after device setup
        timer.sleep(Duration::from_millis(100)).await;
        let delivered = dev.test_intc.get_next_interrupt().unwrap();
        assert_eq!(delivered.0, 0);
        assert!(dev.test_intc.get_next_interrupt().is_none());

        for i in 0..num_queues {
            guest.add_to_avail_queue(i);
            // notify device
            dev.write_u32(0x10000000000 + 0x38, i as u32);
        }
        // verify all queue processing finished
        timer.sleep(Duration::from_millis(100)).await;
        for _ in 0..num_queues {
            let delivered = dev.test_intc.get_next_interrupt();
            assert!(delivered.is_some());
        }
        // check results
        for i in 0..num_queues {
            let (desc, len) = guest.get_next_completed(i).unwrap();
            assert_eq!(desc, 0u16);
            assert_eq!(len, 123 * i as u32);
        }
        // verify no extraneous completions
        for i in 0..num_queues {
            assert!(guest.get_next_completed(i).is_none());
        }
        // reset the device
        let device_status: u8 = 0;
        dev.pci_device
            .mmio_write(0x10000000000 + 20, &device_status.to_le_bytes())
            .unwrap();
        drop(dev);
    }
}