user_driver_emulated_mock/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
//! This crate provides a collection of wrapper structs around things like devices and memory.
//! Through these wrappers, it provides functionality to emulate devices such as NVMe and MANA, and
//! gives additional control over things like [`GuestMemory`] to make testing devices easier.
//! Everything in this crate is meant for TESTING PURPOSES ONLY and it should only ever be added as
//! a dev-dependency (with a few exceptions, such as using it for fuzzing).
7
8mod guest_memory_access_wrapper;
9
10use crate::guest_memory_access_wrapper::GuestMemoryAccessWrapper;
11
12use anyhow::Context;
13use chipset_device::mmio::MmioIntercept;
14use chipset_device::pci::PciConfigSpace;
15use guestmem::GuestMemory;
16use inspect::Inspect;
17use inspect::InspectMut;
18use memory_range::MemoryRange;
19use page_pool_alloc::PagePool;
20use page_pool_alloc::PagePoolAllocator;
21use page_pool_alloc::TestMapper;
22use parking_lot::Mutex;
23use pci_core::chipset_device_ext::PciChipsetDeviceExt;
24use pci_core::msi::MsiControl;
25use pci_core::msi::MsiInterruptSet;
26use pci_core::msi::MsiInterruptTarget;
27use std::sync::Arc;
28use user_driver::DeviceBacking;
29use user_driver::DeviceRegisterIo;
30use user_driver::DmaClient;
31use user_driver::interrupt::DeviceInterrupt;
32use user_driver::interrupt::DeviceInterruptSource;
33use user_driver::memory::PAGE_SIZE64;
34
/// A wrapper around an emulated PCI device `T` that exposes it to `user_driver` code as a
/// [`DeviceBacking`]. It gives tests access to the device's BAR0 registers and MSI-X interrupts,
/// and pairs the device with a DMA client `U` over memory shared with the device, so tests can
/// control device behaviour to a certain extent. Can be used with devices such as the
/// `NvmeController`.
pub struct EmulatedDevice<T, U> {
    // The wrapped device; shared with every `Mapping` handed out by `map_bar`.
    device: Arc<Mutex<T>>,
    // Interrupt sources that the device's MSI-X writes are routed to.
    controller: Arc<MsiController>,
    // DMA client returned by `DeviceBacking::dma_client`.
    dma_client: Arc<U>,
    // Size of BAR0 in bytes, derived from the BAR mask probed in `new`.
    bar0_len: usize,
}
43
44impl<T: InspectMut, U> Inspect for EmulatedDevice<T, U> {
45    fn inspect(&self, req: inspect::Request<'_>) {
46        self.device.lock().inspect_mut(req);
47    }
48}
49
/// Routes MSI writes coming from the emulated device to [`DeviceInterruptSource`]s.
struct MsiController {
    // One interrupt source per vector. The MSI address written by the device is used directly
    // as the index into this slice (see `new_interrupt`).
    events: Arc<[DeviceInterruptSource]>,
}
53
54impl MsiController {
55    fn new(n: usize) -> Self {
56        Self {
57            events: (0..n).map(|_| DeviceInterruptSource::new()).collect(),
58        }
59    }
60}
61
62impl MsiInterruptTarget for MsiController {
63    fn new_interrupt(&self) -> Box<dyn MsiControl> {
64        let events = self.events.clone();
65        Box::new(move |address, _data| {
66            let index = address as usize;
67            if let Some(event) = events.get(index) {
68                tracing::debug!(index, "signaling interrupt");
69                event.signal_uncached();
70            } else {
71                tracing::info!("interrupt ignored");
72            }
73        })
74    }
75}
76
77impl<T: PciConfigSpace + MmioIntercept, U: DmaClient> Clone for EmulatedDevice<T, U> {
78    fn clone(&self) -> Self {
79        Self {
80            device: self.device.clone(),
81            controller: self.controller.clone(),
82            dma_client: self.dma_client.clone(),
83            bar0_len: self.bar0_len,
84        }
85    }
86}
87
impl<T: PciConfigSpace + MmioIntercept, U: DmaClient> EmulatedDevice<T, U> {
    /// Creates a new emulated device wrapping `device`, connecting `msi_set` to an internal
    /// interrupt controller. `dma_client` should point to memory shared with the device.
    ///
    /// This also programs the device's PCI configuration space (BAR addresses, command register)
    /// and its MSI-X table, so the device is ready to be driven through [`DeviceBacking`].
    ///
    /// # Panics
    /// Panics if any of the configuration-space or MMIO writes performed here return an error.
    pub fn new(mut device: T, msi_set: MsiInterruptSet, dma_client: Arc<U>) -> Self {
        // Connect an interrupt controller with `msi_set.len()` interrupt sources.
        let controller = MsiController::new(msi_set.len());
        msi_set.connect(&controller);
        let controller = Arc::new(controller);

        // Derive BAR0's size from its probed mask: clear the low 4 flag bits, invert, add one.
        let bars = device.probe_bar_masks();
        let bar0_len = !(bars[0] & !0xf) as usize + 1;

        // Program the BARs. BAR0 is not written here; `map_bar` accesses it at address 0.
        // Config offsets 0x20/0x24 are BAR4/BAR5 in the PCI header. Writing 0 and 1 places a
        // 64-bit BAR4 at 0x1_0000_0000, which is where the MSI-X table writes below are aimed.
        // NOTE(review): assumes BAR4 is a 64-bit BAR that hosts the MSI-X table — confirm per device.
        device.pci_cfg_write(0x20, 0).unwrap();
        device.pci_cfg_write(0x24, 0x1).unwrap();
        // Enable MMIO decoding in the PCI command register (config offset 0x4).
        device
            .pci_cfg_write(
                0x4,
                pci_core::spec::cfg_space::Command::new()
                    .with_mmio_enabled(true)
                    .into_bits() as u32,
            )
            .unwrap();

        // Enable MSIX.
        // Program 64 MSI-X table entries (16 bytes each) starting at 0x1_0000_0000. Entry `i` gets
        // message address `i` (8 bytes at entry offset 0) and is unmasked by clearing its vector
        // control dword (entry offset 12). `MsiController` uses that address as the interrupt index.
        // NOTE(review): always programs 64 entries, independent of `msi_set.len()`.
        for i in 0u64..64 {
            device
                .mmio_write((0x1 << 32) + i * 16, &i.to_ne_bytes())
                .unwrap();
            device
                .mmio_write((0x1 << 32) + i * 16 + 12, &0u32.to_ne_bytes())
                .unwrap();
        }
        // Set bit 31 of config dword 0x40. Per the PCI spec, this is the MSI-X Enable bit (bit 15
        // of Message Control) when the MSI-X capability is located at 0x40.
        // NOTE(review): assumes the device's MSI-X capability sits at config offset 0x40.
        device.pci_cfg_write(0x40, 0x80000000).unwrap();

        Self {
            device: Arc::new(Mutex::new(device)),
            controller,
            dma_client,
            bar0_len,
        }
    }
}
131
/// A memory mapping for an [`EmulatedDevice`].
///
/// Register accesses made through [`DeviceRegisterIo`] are forwarded as MMIO reads/writes on the
/// wrapped device at `addr + offset`.
#[derive(Inspect)]
pub struct Mapping<T> {
    // The wrapped device, shared with the owning `EmulatedDevice`; omitted from inspect output.
    #[inspect(skip)]
    device: Arc<Mutex<T>>,
    // Base MMIO address of the mapped BAR on the device.
    addr: u64,
    // Length of the mapping in bytes, reported by `DeviceRegisterIo::len`.
    len: usize,
}
140
141impl<T: 'static + Send + InspectMut + MmioIntercept, U: 'static + Send + DmaClient> DeviceBacking
142    for EmulatedDevice<T, U>
143{
144    type Registers = Mapping<T>;
145
146    fn id(&self) -> &str {
147        "emulated"
148    }
149
150    fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
151        if n != 0 {
152            anyhow::bail!("invalid bar {n}");
153        }
154        Ok(Mapping {
155            device: self.device.clone(),
156            addr: (n as u64) << 32,
157            len: self.bar0_len,
158        })
159    }
160
161    fn dma_client(&self) -> Arc<dyn DmaClient> {
162        self.dma_client.clone()
163    }
164
165    fn max_interrupt_count(&self) -> u32 {
166        self.controller.events.len() as u32
167    }
168
169    fn map_interrupt(&mut self, msix: u32, _cpu: u32) -> anyhow::Result<DeviceInterrupt> {
170        Ok(self
171            .controller
172            .events
173            .get(msix as usize)
174            .with_context(|| format!("invalid msix index {msix}"))?
175            .new_target())
176    }
177}
178
179impl<T: MmioIntercept + Send> DeviceRegisterIo for Mapping<T> {
180    fn len(&self) -> usize {
181        self.len
182    }
183
184    fn read_u32(&self, offset: usize) -> u32 {
185        let mut n = [0; 4];
186        self.device
187            .lock()
188            .mmio_read(self.addr + offset as u64, &mut n)
189            .unwrap();
190        u32::from_ne_bytes(n)
191    }
192
193    fn read_u64(&self, offset: usize) -> u64 {
194        let mut n = [0; 8];
195        self.device
196            .lock()
197            .mmio_read(self.addr + offset as u64, &mut n)
198            .unwrap();
199        u64::from_ne_bytes(n)
200    }
201
202    fn write_u32(&self, offset: usize, data: u32) {
203        self.device
204            .lock()
205            .mmio_write(self.addr + offset as u64, &data.to_ne_bytes())
206            .unwrap();
207    }
208
209    fn write_u64(&self, offset: usize, data: u64) {
210        self.device
211            .lock()
212            .mmio_write(self.addr + offset as u64, &data.to_ne_bytes())
213            .unwrap();
214    }
215}
216
/// A wrapper around the [`TestMapper`] that generates both [`GuestMemory`] and [`PagePoolAllocator`]
/// accessors backed by the same underlying memory. Meant to provide shared memory for testing
/// devices.
pub struct DeviceTestMemory {
    // View over the entire backing range.
    guest_mem: GuestMemory,
    // View over the second half of the backing range.
    payload_mem: GuestMemory,
    // Held so the page pool is not dropped while this struct exists; not otherwise accessed.
    _pool: PagePool,
    // Allocator over the first half of the backing range.
    allocator: Arc<PagePoolAllocator>,
}
225
226impl DeviceTestMemory {
227    /// Creates test memory that leverages the [`TestMapper`] as the backing. It creates 3 accessors for the underlying memory:
228    /// guest_memory [`GuestMemory`] - Has access to the entire range.
229    /// payload_memory [`GuestMemory`] - Has access to the second half of the range.
230    /// dma_client [`PagePoolAllocator`] - Has access to the first half of the range.
231    /// If the `allow_dma` switch is enabled, both guest_memory and payload_memory will report a base_iova of 0.
232    pub fn new(num_pages: u64, allow_dma: bool, pool_name: &str) -> Self {
233        let test_mapper = TestMapper::new(num_pages).unwrap();
234        let sparse_mmap = test_mapper.sparse_mapping();
235        let guest_mem = GuestMemoryAccessWrapper::create_test_guest_memory(sparse_mmap, allow_dma);
236        let pool = PagePool::new(
237            &[MemoryRange::from_4k_gpn_range(0..num_pages / 2)],
238            test_mapper,
239        )
240        .unwrap();
241
242        // Save page pool so that it is not dropped.
243        let allocator = pool.allocator(pool_name.into()).unwrap();
244        let range_half = num_pages / 2 * PAGE_SIZE64;
245        Self {
246            guest_mem: guest_mem.clone(),
247            payload_mem: guest_mem.subrange(range_half, range_half, false).unwrap(),
248            _pool: pool,
249            allocator: Arc::new(allocator),
250        }
251    }
252
253    /// Returns [`GuestMemory`] accessor to the underlying memory. Reports base_iova as 0 if `allow_dma` switch is enabled.
254    pub fn guest_memory(&self) -> GuestMemory {
255        self.guest_mem.clone()
256    }
257
258    /// Returns [`GuestMemory`] accessor to the second half of underlying memory. Reports base_iova as 0 if `allow_dma` switch is enabled.
259    pub fn payload_mem(&self) -> GuestMemory {
260        self.payload_mem.clone()
261    }
262
263    /// Returns [`PagePoolAllocator`] with access to the first half of the underlying memory.
264    pub fn dma_client(&self) -> Arc<PagePoolAllocator> {
265        self.allocator.clone()
266    }
267}
268
269/// Callbacks for the [`DeviceTestDmaClient`]. Tests supply these to customize the behaviour of the dma client.
270pub trait DeviceTestDmaClientCallbacks: Sync + Send {
271    /// Called when the DMA client needs to allocate a new DMA buffer.
272    fn allocate_dma_buffer(
273        &self,
274        allocator: &PagePoolAllocator,
275        total_size: usize,
276    ) -> anyhow::Result<user_driver::memory::MemoryBlock>;
277
278    /// Called when the DMA client needs to attach pending buffers.
279    fn attach_pending_buffers(
280        &self,
281        inner: &PagePoolAllocator,
282    ) -> anyhow::Result<Vec<user_driver::memory::MemoryBlock>>;
283}
284
285/// A DMA client that uses a [`PagePoolAllocator`] as the backing. It can be customized through the use of
286/// [`DeviceTestDmaClientCallbacks`] to modify its behaviour for testing purposes.
287///
288/// # Example
289/// ```rust
290/// use std::sync::Arc;
291/// use user_driver::DmaClient;
292/// use user_driver_emulated_mock::DeviceTestDmaClient;
293/// use page_pool_alloc::PagePoolAllocator;
294///
295/// struct MyCallbacks;
296/// impl user_driver_emulated_mock::DeviceTestDmaClientCallbacks for MyCallbacks {
297///     fn allocate_dma_buffer(
298///         &self,
299///         allocator: &page_pool_alloc::PagePoolAllocator,
300///         total_size: usize,
301///     ) -> anyhow::Result<user_driver::memory::MemoryBlock> {
302///         // Custom test logic here, for example:
303///         anyhow::bail!("allocation failed for testing");
304///     }
305///
306///     fn attach_pending_buffers(
307///         &self,
308///         allocator: &page_pool_alloc::PagePoolAllocator,
309///     ) -> anyhow::Result<Vec<user_driver::memory::MemoryBlock>> {
310///         // Custom test logic here, for example:
311///         anyhow::bail!("attachment failed for testing");
312///     }
313/// }
314///
315/// // Use the above in a test ...
316/// fn test_dma_client() {
317///     let pages = 1000;
318///     let device_test_memory = user_driver_emulated_mock::DeviceTestMemory::new(
319///         pages,
320///         true,
321///         "test_dma_client",
322///     );
323///     let page_pool_allocator = device_test_memory.dma_client();
324///     let dma_client = DeviceTestDmaClient::new(page_pool_allocator, MyCallbacks);
325///
326///     // Use dma_client in tests...
327///     assert!(dma_client.allocate_dma_buffer(4096).is_err());
328/// }
329/// ```
330#[derive(Inspect)]
331#[inspect(transparent)]
332pub struct DeviceTestDmaClient<C>
333where
334    C: DeviceTestDmaClientCallbacks,
335{
336    inner: Arc<PagePoolAllocator>,
337    #[inspect(skip)]
338    callbacks: C,
339}
340
341impl<C: DeviceTestDmaClientCallbacks> DeviceTestDmaClient<C> {
342    /// Creates a new [`DeviceTestDmaClient`] with the given inner allocator.
343    pub fn new(inner: Arc<PagePoolAllocator>, callbacks: C) -> Self {
344        Self { inner, callbacks }
345    }
346}
347
348impl<C: DeviceTestDmaClientCallbacks> DmaClient for DeviceTestDmaClient<C> {
349    fn allocate_dma_buffer(
350        &self,
351        total_size: usize,
352    ) -> anyhow::Result<user_driver::memory::MemoryBlock> {
353        self.callbacks.allocate_dma_buffer(&self.inner, total_size)
354    }
355
356    fn attach_pending_buffers(&self) -> anyhow::Result<Vec<user_driver::memory::MemoryBlock>> {
357        self.callbacks.attach_pending_buffers(&self.inner)
358    }
359}