user_driver_emulated_mock/
lib.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! This crate provides a collection of wrapper structs around things like devices and memory. Through the wrappers, it
//! provides functionality to emulate devices such as NVMe and MANA, and gives additional control over things like
//! [`GuestMemory`] to make testing devices easier.
//! Everything in this crate is meant for TESTING PURPOSES ONLY. It should only ever be added as a dev-dependency (with
//! a few exceptions, such as using this crate for fuzzing).

mod guest_memory_access_wrapper;

use crate::guest_memory_access_wrapper::GuestMemoryAccessWrapper;

use anyhow::Context;
use chipset_device::mmio::MmioIntercept;
use chipset_device::pci::PciConfigSpace;
use guestmem::GuestMemory;
use inspect::Inspect;
use inspect::InspectMut;
use memory_range::MemoryRange;
use page_pool_alloc::PagePool;
use page_pool_alloc::PagePoolAllocator;
use page_pool_alloc::TestMapper;
use parking_lot::Mutex;
use pci_core::chipset_device_ext::PciChipsetDeviceExt;
use pci_core::msi::MsiControl;
use pci_core::msi::MsiInterruptSet;
use pci_core::msi::MsiInterruptTarget;
use std::sync::Arc;
use user_driver::DeviceBacking;
use user_driver::DeviceRegisterIo;
use user_driver::DmaClient;
use user_driver::interrupt::DeviceInterrupt;
use user_driver::interrupt::DeviceInterruptSource;
use user_driver::memory::PAGE_SIZE64;

/// A wrapper around any user_driver device T. It provides device emulation by exposing the memory shared with the
/// device, which lets the user control device behaviour to a certain extent. Can be used with devices such as the
/// `NvmeController`.
pub struct EmulatedDevice<T, U> {
    device: Arc<Mutex<T>>,
    controller: Arc<MsiController>,
    dma_client: Arc<U>,
    bar0_len: usize,
}

impl<T: InspectMut, U> Inspect for EmulatedDevice<T, U> {
    fn inspect(&self, req: inspect::Request<'_>) {
        self.device.lock().inspect_mut(req);
    }
}

struct MsiController {
    events: Arc<[DeviceInterruptSource]>,
}

impl MsiController {
    fn new(n: usize) -> Self {
        Self {
            events: (0..n).map(|_| DeviceInterruptSource::new()).collect(),
        }
    }
}

impl MsiInterruptTarget for MsiController {
    fn new_interrupt(&self) -> Box<dyn MsiControl> {
        let events = self.events.clone();
        Box::new(move |address, _data| {
            let index = address as usize;
            if let Some(event) = events.get(index) {
                tracing::debug!(index, "signaling interrupt");
                event.signal_uncached();
            } else {
                tracing::info!("interrupt ignored");
            }
        })
    }
}

impl<T: PciConfigSpace + MmioIntercept, U: DmaClient> Clone for EmulatedDevice<T, U> {
    fn clone(&self) -> Self {
        Self {
            device: self.device.clone(),
            controller: self.controller.clone(),
            dma_client: self.dma_client.clone(),
            bar0_len: self.bar0_len,
        }
    }
}

impl<T: PciConfigSpace + MmioIntercept, U: DmaClient> EmulatedDevice<T, U> {
    /// Creates a new emulated device wrapping `device` of type T, using the provided MSI interrupt set. `dma_client`
    /// should point to memory shared with the device.
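    ///
    /// # Example
    ///
    /// A rough sketch of wiring an emulated NVMe controller to a user-mode driver under test. The
    /// `NvmeController` constructor below is a placeholder (consult the NVMe emulator crate for the real
    /// signature), so the snippet is not compiled.
    ///
    /// ```ignore
    /// let mem = user_driver_emulated_mock::DeviceTestMemory::new(1024, false, "emulated-dev");
    /// let mut msi_set = pci_core::msi::MsiInterruptSet::new();
    /// // Placeholder: construct the device emulator, registering its interrupts with `msi_set`
    /// // and giving it access to the shared guest memory.
    /// let nvme = NvmeController::new(/* ..., mem.guest_memory(), &mut msi_set, ... */);
    /// let device = user_driver_emulated_mock::EmulatedDevice::new(nvme, msi_set, mem.dma_client());
    /// // `device` implements `DeviceBacking` and can now back the driver being tested.
    /// ```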
    pub fn new(mut device: T, msi_set: MsiInterruptSet, dma_client: Arc<U>) -> Self {
        // Connect an interrupt controller.
        let controller = MsiController::new(msi_set.len());
        msi_set.connect(&controller);
        let controller = Arc::new(controller);

        let bars = device.probe_bar_masks();
        let bar0_len = !(bars[0] & !0xf) as usize + 1;

        // Map BAR0 at address 0 (the reset default) and BAR4 at 1 << 32, then enable MMIO decoding.
        device.pci_cfg_write(0x20, 0).unwrap();
        device.pci_cfg_write(0x24, 0x1).unwrap();
        device
            .pci_cfg_write(
                0x4,
                pci_core::spec::cfg_space::Command::new()
                    .with_mmio_enabled(true)
                    .into_bits() as u32,
            )
            .unwrap();

        // Program the MSI-X table (message address = vector index, vector unmasked), then set the
        // MSI-X enable bit in the capability at config offset 0x40.
        for i in 0u64..64 {
            device
                .mmio_write((0x1 << 32) + i * 16, &i.to_ne_bytes())
                .unwrap();
            device
                .mmio_write((0x1 << 32) + i * 16 + 12, &0u32.to_ne_bytes())
                .unwrap();
        }
        device.pci_cfg_write(0x40, 0x80000000).unwrap();

        Self {
            device: Arc::new(Mutex::new(device)),
            controller,
            dma_client,
            bar0_len,
        }
    }
}

/// A memory mapping for an [`EmulatedDevice`].
#[derive(Inspect)]
pub struct Mapping<T> {
    #[inspect(skip)]
    device: Arc<Mutex<T>>,
    addr: u64,
    len: usize,
}

impl<T: 'static + Send + InspectMut + MmioIntercept, U: 'static + Send + DmaClient> DeviceBacking
    for EmulatedDevice<T, U>
{
    type Registers = Mapping<T>;

    fn id(&self) -> &str {
        "emulated"
    }

    fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
        if n != 0 {
            anyhow::bail!("invalid bar {n}");
        }
        Ok(Mapping {
            device: self.device.clone(),
            addr: (n as u64) << 32,
            len: self.bar0_len,
        })
    }

    fn dma_client(&self) -> Arc<dyn DmaClient> {
        self.dma_client.clone()
    }

    fn dma_client_for(&self, _pool: user_driver::DmaPool) -> anyhow::Result<Arc<dyn DmaClient>> {
        // In the emulated device, we only have one dma client.
        Ok(self.dma_client.clone())
    }

    fn max_interrupt_count(&self) -> u32 {
        self.controller.events.len() as u32
    }

    fn map_interrupt(&mut self, msix: u32, _cpu: u32) -> anyhow::Result<DeviceInterrupt> {
        Ok(self
            .controller
            .events
            .get(msix as usize)
            .with_context(|| format!("invalid msix index {msix}"))?
            .new_target())
    }
}

impl<T: MmioIntercept + Send> DeviceRegisterIo for Mapping<T> {
    fn len(&self) -> usize {
        self.len
    }

    fn read_u32(&self, offset: usize) -> u32 {
        let mut n = [0; 4];
        self.device
            .lock()
            .mmio_read(self.addr + offset as u64, &mut n)
            .unwrap();
        u32::from_ne_bytes(n)
    }

    fn read_u64(&self, offset: usize) -> u64 {
        let mut n = [0; 8];
        self.device
            .lock()
            .mmio_read(self.addr + offset as u64, &mut n)
            .unwrap();
        u64::from_ne_bytes(n)
    }

    fn write_u32(&self, offset: usize, data: u32) {
        self.device
            .lock()
            .mmio_write(self.addr + offset as u64, &data.to_ne_bytes())
            .unwrap();
    }

    fn write_u64(&self, offset: usize, data: u64) {
        self.device
            .lock()
            .mmio_write(self.addr + offset as u64, &data.to_ne_bytes())
            .unwrap();
    }
}

/// A wrapper around the [`TestMapper`] that generates both [`GuestMemory`] and [`PagePoolAllocator`] backed
/// by the same underlying memory. Meant to provide shared memory for testing devices.
pub struct DeviceTestMemory {
    guest_mem: GuestMemory,
    payload_mem: GuestMemory,
    _pool: PagePool,
    allocator: Arc<PagePoolAllocator>,
}

impl DeviceTestMemory {
    /// Creates test memory that leverages the [`TestMapper`] as the backing. It creates 3 accessors for the
    /// underlying memory:
    /// - `guest_memory` ([`GuestMemory`]): has access to the entire range.
    /// - `payload_mem` ([`GuestMemory`]): has access to the second half of the range.
    /// - `dma_client` ([`PagePoolAllocator`]): has access to the first half of the range.
    ///
    /// If `allow_dma` is enabled, both `guest_memory` and `payload_mem` will report a base_iova of 0.
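    ///
    /// # Example
    ///
    /// A minimal sketch of carving up the test memory; the page count and pool name below are arbitrary.
    ///
    /// ```rust
    /// let mem = user_driver_emulated_mock::DeviceTestMemory::new(8, false, "example-pool");
    /// let _guest_mem = mem.guest_memory(); // all 8 pages
    /// let _payload_mem = mem.payload_mem(); // the last 4 pages
    /// let _dma_client = mem.dma_client(); // allocates from the first 4 pages
    /// ```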
    pub fn new(num_pages: u64, allow_dma: bool, pool_name: &str) -> Self {
        let test_mapper = TestMapper::new(num_pages).unwrap();
        let sparse_mmap = test_mapper.sparse_mapping();
        let guest_mem = GuestMemoryAccessWrapper::create_test_guest_memory(sparse_mmap, allow_dma);
        let pool = PagePool::new(
            &[MemoryRange::from_4k_gpn_range(0..num_pages / 2)],
            test_mapper,
        )
        .unwrap();

        // Save page pool so that it is not dropped.
        let allocator = pool.allocator(pool_name.into()).unwrap();
        let range_half = num_pages / 2 * PAGE_SIZE64;
        Self {
            guest_mem: guest_mem.clone(),
            payload_mem: guest_mem.subrange(range_half, range_half, false).unwrap(),
            _pool: pool,
            allocator: Arc::new(allocator),
        }
    }

    /// Returns a [`GuestMemory`] accessor to the underlying memory. Reports base_iova as 0 if the `allow_dma` switch
    /// is enabled.
    pub fn guest_memory(&self) -> GuestMemory {
        self.guest_mem.clone()
    }

    /// Returns a [`GuestMemory`] accessor to the second half of the underlying memory. Reports base_iova as 0 if the
    /// `allow_dma` switch is enabled.
    pub fn payload_mem(&self) -> GuestMemory {
        self.payload_mem.clone()
    }

    /// Returns a [`PagePoolAllocator`] with access to the first half of the underlying memory.
    pub fn dma_client(&self) -> Arc<PagePoolAllocator> {
        self.allocator.clone()
    }
}

/// Callbacks for the [`DeviceTestDmaClient`]. Tests supply these to customize the behaviour of the dma client.
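///
/// # Example
///
/// A sketch of a pass-through implementation that simply delegates to the underlying allocator. It assumes
/// [`PagePoolAllocator`] itself implements [`DmaClient`], so the snippet is marked `ignore` rather than compiled as a
/// doctest.
///
/// ```ignore
/// use user_driver::DmaClient;
///
/// struct PassThrough;
///
/// impl user_driver_emulated_mock::DeviceTestDmaClientCallbacks for PassThrough {
///     fn allocate_dma_buffer(
///         &self,
///         allocator: &page_pool_alloc::PagePoolAllocator,
///         total_size: usize,
///     ) -> anyhow::Result<user_driver::memory::MemoryBlock> {
///         // Delegate straight to the page pool allocator.
///         allocator.allocate_dma_buffer(total_size)
///     }
///
///     fn attach_pending_buffers(
///         &self,
///         allocator: &page_pool_alloc::PagePoolAllocator,
///     ) -> anyhow::Result<Vec<user_driver::memory::MemoryBlock>> {
///         allocator.attach_pending_buffers()
///     }
/// }
/// ```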
pub trait DeviceTestDmaClientCallbacks: Sync + Send {
    /// Called when the DMA client needs to allocate a new DMA buffer.
    fn allocate_dma_buffer(
        &self,
        allocator: &PagePoolAllocator,
        total_size: usize,
    ) -> anyhow::Result<user_driver::memory::MemoryBlock>;

    /// Called when the DMA client needs to attach pending buffers.
    fn attach_pending_buffers(
        &self,
        allocator: &PagePoolAllocator,
    ) -> anyhow::Result<Vec<user_driver::memory::MemoryBlock>>;
}

/// A DMA client that uses a [`PagePoolAllocator`] as the backing. It can be customized through the use of
/// [`DeviceTestDmaClientCallbacks`] to modify its behaviour for testing purposes.
///
/// # Example
/// ```rust
/// use std::sync::Arc;
/// use user_driver::DmaClient;
/// use user_driver_emulated_mock::DeviceTestDmaClient;
/// use page_pool_alloc::PagePoolAllocator;
///
/// struct MyCallbacks;
/// impl user_driver_emulated_mock::DeviceTestDmaClientCallbacks for MyCallbacks {
///     fn allocate_dma_buffer(
///         &self,
///         allocator: &PagePoolAllocator,
///         total_size: usize,
///     ) -> anyhow::Result<user_driver::memory::MemoryBlock> {
///         // Custom test logic here, for example:
///         anyhow::bail!("allocation failed for testing");
///     }
///
///     fn attach_pending_buffers(
///         &self,
///         allocator: &PagePoolAllocator,
///     ) -> anyhow::Result<Vec<user_driver::memory::MemoryBlock>> {
///         // Custom test logic here, for example:
///         anyhow::bail!("attachment failed for testing");
///     }
/// }
///
/// // Use the above in a test ...
/// fn test_dma_client() {
///     let pages = 1000;
///     let device_test_memory = user_driver_emulated_mock::DeviceTestMemory::new(
///         pages,
///         true,
///         "test_dma_client",
///     );
///     let page_pool_allocator = device_test_memory.dma_client();
///     let dma_client = DeviceTestDmaClient::new(page_pool_allocator, MyCallbacks);
///
///     // Use dma_client in tests...
///     assert!(dma_client.allocate_dma_buffer(4096).is_err());
/// }
/// ```
#[derive(Inspect)]
#[inspect(transparent)]
pub struct DeviceTestDmaClient<C>
where
    C: DeviceTestDmaClientCallbacks,
{
    inner: Arc<PagePoolAllocator>,
    #[inspect(skip)]
    callbacks: C,
}

impl<C: DeviceTestDmaClientCallbacks> DeviceTestDmaClient<C> {
    /// Creates a new [`DeviceTestDmaClient`] with the given inner allocator and callbacks.
    pub fn new(inner: Arc<PagePoolAllocator>, callbacks: C) -> Self {
        Self { inner, callbacks }
    }
}

impl<C: DeviceTestDmaClientCallbacks> DmaClient for DeviceTestDmaClient<C> {
    fn allocate_dma_buffer(
        &self,
        total_size: usize,
    ) -> anyhow::Result<user_driver::memory::MemoryBlock> {
        self.callbacks.allocate_dma_buffer(&self.inner, total_size)
    }

    fn attach_pending_buffers(&self) -> anyhow::Result<Vec<user_driver::memory::MemoryBlock>> {
        self.callbacks.attach_pending_buffers(&self.inner)
    }
}