user_driver_emulated_mock/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
//! This crate provides a collection of wrapper structs around things like devices and memory. Through the wrappers,
//! it provides functionality to emulate devices such as NVMe and MANA, and gives some additional control over things
//! like [`GuestMemory`] to make testing devices easier.
//! Everything in this crate is meant for TESTING PURPOSES ONLY, and it should only ever be added as a dev-dependency
//! (with a few exceptions, such as using it for fuzzing).
7
8mod guest_memory_access_wrapper;
9
10use crate::guest_memory_access_wrapper::GuestMemoryAccessWrapper;
11
12use anyhow::Context;
13use chipset_device::mmio::MmioIntercept;
14use chipset_device::pci::PciConfigSpace;
15use guestmem::GuestMemory;
16use inspect::Inspect;
17use inspect::InspectMut;
18use memory_range::MemoryRange;
19use page_pool_alloc::PagePool;
20use page_pool_alloc::PagePoolAllocator;
21use page_pool_alloc::TestMapper;
22use parking_lot::Mutex;
23use pci_core::chipset_device_ext::PciChipsetDeviceExt;
24use pci_core::msi::MsiConnection;
25use pci_core::msi::SignalMsi;
26use std::sync::Arc;
27use user_driver::DeviceBacking;
28use user_driver::DeviceRegisterIo;
29use user_driver::DmaClient;
30use user_driver::interrupt::DeviceInterrupt;
31use user_driver::interrupt::DeviceInterruptSource;
32use user_driver::memory::PAGE_SIZE64;
33
34/// A wrapper around any user_driver device T. It provides device emulation by providing access to the memory shared with the device and thus
35/// allowing the user to control device behaviour to a certain extent. Can be used with devices such as the `NvmeController`
36pub struct EmulatedDevice<T, U> {
37    device: Arc<Mutex<T>>,
38    controller: Arc<MsiController>,
39    dma_client: Arc<U>,
40    bar0_len: usize,
41}
42
43impl<T: InspectMut, U> Inspect for EmulatedDevice<T, U> {
44    fn inspect(&self, req: inspect::Request<'_>) {
45        self.device.lock().inspect_mut(req);
46    }
47}
48
49struct MsiController {
50    events: Arc<[DeviceInterruptSource]>,
51}
52
53impl MsiController {
54    fn new(n: usize) -> Self {
55        Self {
56            events: (0..n).map(|_| DeviceInterruptSource::new()).collect(),
57        }
58    }
59}
60
61impl SignalMsi for MsiController {
62    fn signal_msi(&self, rid: u32, address: u64, _data: u32) {
63        let index = address as usize;
64        if rid != 0 {
65            return;
66        }
67        if let Some(event) = self.events.get(index) {
68            tracing::debug!(index, "signaling interrupt");
69            event.signal_uncached();
70        } else {
71            tracing::info!("interrupt ignored");
72        }
73    }
74}
75
76impl<T: PciConfigSpace + MmioIntercept, U: DmaClient> Clone for EmulatedDevice<T, U> {
77    fn clone(&self) -> Self {
78        Self {
79            device: self.device.clone(),
80            controller: self.controller.clone(),
81            dma_client: self.dma_client.clone(),
82            bar0_len: self.bar0_len,
83        }
84    }
85}
86
87impl<T: PciConfigSpace + MmioIntercept, U: DmaClient> EmulatedDevice<T, U> {
88    /// Creates a new emulated device, wrapping `device` of type T, using the provided MSI Interrupt Set. Dma_client should point to memory
89    /// shared with the device.
90    pub fn new(mut device: T, msi_conn: MsiConnection, dma_client: Arc<U>) -> Self {
91        let bars = device.probe_bar_masks();
92        let bar0_len = !(bars[0] & !0xf) as usize + 1;
93
94        // Enable BAR0 at 0, BAR4 at X.
95        device.pci_cfg_write(0x20, 0).unwrap();
96        device.pci_cfg_write(0x24, 0x1).unwrap();
97        device
98            .pci_cfg_write(
99                0x4,
100                pci_core::spec::cfg_space::Command::new()
101                    .with_mmio_enabled(true)
102                    .into_bits() as u32,
103            )
104            .unwrap();
105
106        // Determine the number of MSI-X vectors.
107        let msix_table_size = {
108            let mut n = 0;
109            device.pci_cfg_read(0x40, &mut n).unwrap();
110            ((n >> 16) & 0x7ff) + 1
111        } as usize;
112
113        // Connect an interrupt controller.
114        let controller = Arc::new(MsiController::new(msix_table_size));
115        msi_conn.connect(controller.clone());
116
117        // Enable MSIX.
118        for i in 0u64..64 {
119            device
120                .mmio_write((0x1 << 32) + i * 16, &i.to_ne_bytes())
121                .unwrap();
122            device
123                .mmio_write((0x1 << 32) + i * 16 + 12, &0u32.to_ne_bytes())
124                .unwrap();
125        }
126        device.pci_cfg_write(0x40, 0x80000000).unwrap();
127
128        Self {
129            device: Arc::new(Mutex::new(device)),
130            controller,
131            dma_client,
132            bar0_len,
133        }
134    }
135}
136
137/// A memory mapping for an [`EmulatedDevice`].
138#[derive(Inspect)]
139pub struct Mapping<T> {
140    #[inspect(skip)]
141    device: Arc<Mutex<T>>,
142    addr: u64,
143    len: usize,
144}
145
146impl<T: 'static + Send + InspectMut + MmioIntercept, U: 'static + Send + DmaClient> DeviceBacking
147    for EmulatedDevice<T, U>
148{
149    type Registers = Mapping<T>;
150
151    fn id(&self) -> &str {
152        "emulated"
153    }
154
155    fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
156        if n != 0 {
157            anyhow::bail!("invalid bar {n}");
158        }
159        Ok(Mapping {
160            device: self.device.clone(),
161            addr: (n as u64) << 32,
162            len: self.bar0_len,
163        })
164    }
165
166    fn dma_client(&self) -> Arc<dyn DmaClient> {
167        self.dma_client.clone()
168    }
169
170    fn dma_client_for(&self, _pool: user_driver::DmaPool) -> anyhow::Result<Arc<dyn DmaClient>> {
171        // In the emulated device, we only have one dma client.
172        Ok(self.dma_client.clone())
173    }
174
175    fn max_interrupt_count(&self) -> u32 {
176        self.controller.events.len() as u32
177    }
178
179    fn map_interrupt(&mut self, msix: u32, _cpu: u32) -> anyhow::Result<DeviceInterrupt> {
180        Ok(self
181            .controller
182            .events
183            .get(msix as usize)
184            .with_context(|| format!("invalid msix index {msix}"))?
185            .new_target())
186    }
187}
188
189impl<T: MmioIntercept + Send> DeviceRegisterIo for Mapping<T> {
190    fn len(&self) -> usize {
191        self.len
192    }
193
194    fn read_u32(&self, offset: usize) -> u32 {
195        let mut n = [0; 4];
196        self.device
197            .lock()
198            .mmio_read(self.addr + offset as u64, &mut n)
199            .unwrap();
200        u32::from_ne_bytes(n)
201    }
202
203    fn read_u64(&self, offset: usize) -> u64 {
204        let mut n = [0; 8];
205        self.device
206            .lock()
207            .mmio_read(self.addr + offset as u64, &mut n)
208            .unwrap();
209        u64::from_ne_bytes(n)
210    }
211
212    fn write_u32(&self, offset: usize, data: u32) {
213        self.device
214            .lock()
215            .mmio_write(self.addr + offset as u64, &data.to_ne_bytes())
216            .unwrap();
217    }
218
219    fn write_u64(&self, offset: usize, data: u64) {
220        self.device
221            .lock()
222            .mmio_write(self.addr + offset as u64, &data.to_ne_bytes())
223            .unwrap();
224    }
225}
226
227/// A wrapper around the [`TestMapper`] that generates both [`GuestMemory`] and [`PagePoolAllocator`] backed
228/// by the same underlying memory. Meant to provide shared memory for testing devices.
229pub struct DeviceTestMemory {
230    guest_mem: GuestMemory,
231    payload_mem: GuestMemory,
232    _pool: PagePool,
233    allocator: Arc<PagePoolAllocator>,
234}
235
236impl DeviceTestMemory {
237    /// Creates test memory that leverages the [`TestMapper`] as the backing. It creates 3 accessors for the underlying memory:
238    /// guest_memory [`GuestMemory`] - Has access to the entire range.
239    /// payload_memory [`GuestMemory`] - Has access to the second half of the range.
240    /// dma_client [`PagePoolAllocator`] - Has access to the first half of the range.
241    /// If the `allow_dma` switch is enabled, both guest_memory and payload_memory will report a base_iova of 0.
242    pub fn new(num_pages: u64, allow_dma: bool, pool_name: &str) -> Self {
243        let test_mapper = TestMapper::new(num_pages).unwrap();
244        let sparse_mmap = test_mapper.sparse_mapping();
245        let guest_mem = GuestMemoryAccessWrapper::create_test_guest_memory(sparse_mmap, allow_dma);
246        let pool = PagePool::new(
247            &[MemoryRange::from_4k_gpn_range(0..num_pages / 2)],
248            test_mapper,
249        )
250        .unwrap();
251
252        // Save page pool so that it is not dropped.
253        let allocator = pool.allocator(pool_name.into()).unwrap();
254        let range_half = num_pages / 2 * PAGE_SIZE64;
255        Self {
256            guest_mem: guest_mem.clone(),
257            payload_mem: guest_mem.subrange(range_half, range_half, false).unwrap(),
258            _pool: pool,
259            allocator: Arc::new(allocator),
260        }
261    }
262
263    /// Returns [`GuestMemory`] accessor to the underlying memory. Reports base_iova as 0 if `allow_dma` switch is enabled.
264    pub fn guest_memory(&self) -> GuestMemory {
265        self.guest_mem.clone()
266    }
267
268    /// Returns [`GuestMemory`] accessor to the second half of underlying memory. Reports base_iova as 0 if `allow_dma` switch is enabled.
269    pub fn payload_mem(&self) -> GuestMemory {
270        self.payload_mem.clone()
271    }
272
273    /// Returns [`PagePoolAllocator`] with access to the first half of the underlying memory.
274    pub fn dma_client(&self) -> Arc<PagePoolAllocator> {
275        self.allocator.clone()
276    }
277}
278
279/// Callbacks for the [`DeviceTestDmaClient`]. Tests supply these to customize the behaviour of the dma client.
280pub trait DeviceTestDmaClientCallbacks: Sync + Send {
281    /// Called when the DMA client needs to allocate a new DMA buffer.
282    fn allocate_dma_buffer(
283        &self,
284        allocator: &PagePoolAllocator,
285        total_size: usize,
286    ) -> anyhow::Result<user_driver::memory::MemoryBlock>;
287
288    /// Called when the DMA client needs to attach pending buffers.
289    fn attach_pending_buffers(
290        &self,
291        inner: &PagePoolAllocator,
292    ) -> anyhow::Result<Vec<user_driver::memory::MemoryBlock>>;
293}
294
295/// A DMA client that uses a [`PagePoolAllocator`] as the backing. It can be customized through the use of
296/// [`DeviceTestDmaClientCallbacks`] to modify its behaviour for testing purposes.
297///
298/// # Example
299/// ```rust
300/// use std::sync::Arc;
301/// use user_driver::DmaClient;
302/// use user_driver_emulated_mock::DeviceTestDmaClient;
303/// use page_pool_alloc::PagePoolAllocator;
304///
305/// struct MyCallbacks;
306/// impl user_driver_emulated_mock::DeviceTestDmaClientCallbacks for MyCallbacks {
307///     fn allocate_dma_buffer(
308///         &self,
309///         allocator: &page_pool_alloc::PagePoolAllocator,
310///         total_size: usize,
311///     ) -> anyhow::Result<user_driver::memory::MemoryBlock> {
312///         // Custom test logic here, for example:
313///         anyhow::bail!("allocation failed for testing");
314///     }
315///
316///     fn attach_pending_buffers(
317///         &self,
318///         allocator: &page_pool_alloc::PagePoolAllocator,
319///     ) -> anyhow::Result<Vec<user_driver::memory::MemoryBlock>> {
320///         // Custom test logic here, for example:
321///         anyhow::bail!("attachment failed for testing");
322///     }
323/// }
324///
325/// // Use the above in a test ...
326/// fn test_dma_client() {
327///     let pages = 1000;
328///     let device_test_memory = user_driver_emulated_mock::DeviceTestMemory::new(
329///         pages,
330///         true,
331///         "test_dma_client",
332///     );
333///     let page_pool_allocator = device_test_memory.dma_client();
334///     let dma_client = DeviceTestDmaClient::new(page_pool_allocator, MyCallbacks);
335///
336///     // Use dma_client in tests...
337///     assert!(dma_client.allocate_dma_buffer(4096).is_err());
338/// }
339/// ```
340#[derive(Inspect)]
341#[inspect(transparent)]
342pub struct DeviceTestDmaClient<C>
343where
344    C: DeviceTestDmaClientCallbacks,
345{
346    inner: Arc<PagePoolAllocator>,
347    #[inspect(skip)]
348    callbacks: C,
349}
350
351impl<C: DeviceTestDmaClientCallbacks> DeviceTestDmaClient<C> {
352    /// Creates a new [`DeviceTestDmaClient`] with the given inner allocator.
353    pub fn new(inner: Arc<PagePoolAllocator>, callbacks: C) -> Self {
354        Self { inner, callbacks }
355    }
356}
357
358impl<C: DeviceTestDmaClientCallbacks> DmaClient for DeviceTestDmaClient<C> {
359    fn allocate_dma_buffer(
360        &self,
361        total_size: usize,
362    ) -> anyhow::Result<user_driver::memory::MemoryBlock> {
363        self.callbacks.allocate_dma_buffer(&self.inner, total_size)
364    }
365
366    fn attach_pending_buffers(&self) -> anyhow::Result<Vec<user_driver::memory::MemoryBlock>> {
367        self.callbacks.attach_pending_buffers(&self.inner)
368    }
369}