user_driver/
memory.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Traits and types for sharing host memory with the device.
5
6use safeatomic::AtomicSliceOps;
7use std::sync::Arc;
8use std::sync::atomic::AtomicU8;
9use zerocopy::FromBytes;
10use zerocopy::Immutable;
11use zerocopy::IntoBytes;
12use zerocopy::KnownLayout;
13
/// The 4KB page size used by user-mode devices.
pub const PAGE_SIZE: usize = 4096;
/// [`PAGE_SIZE`] as a `u32`.
pub const PAGE_SIZE32: u32 = PAGE_SIZE as u32;
/// [`PAGE_SIZE`] as a `u64`.
pub const PAGE_SIZE64: u64 = PAGE_SIZE as u64;
18
19/// A mapped buffer that can be accessed by the host or the device.
20///
21/// # Safety
22/// The implementor must ensure that the VA region from `base()..base() + len()`
23/// remains mapped for the lifetime.
24pub unsafe trait MappedDmaTarget: Send + Sync {
25    /// The virtual address of the mapped memory.
26    fn base(&self) -> *const u8;
27
28    /// The length of the buffer in bytes.
29    fn len(&self) -> usize;
30
31    /// 4KB page numbers used to refer to the memory when communicating with the
32    /// device.
33    fn pfns(&self) -> &[u64];
34
35    /// The pfn_bias on confidential platforms (aka vTOM) applied to PFNs in [`Self::pfns()`],
36    fn pfn_bias(&self) -> u64;
37
38    /// Returns a view of a subset of the buffer.
39    ///
40    /// Returns `None` if the default implementation should be used.
41    ///
42    /// This should not be implemented except by internal implementations.
43    #[doc(hidden)]
44    fn view(&self, offset: usize, len: usize) -> Option<MemoryBlock> {
45        let _ = (offset, len);
46        None
47    }
48}
49
50struct RestrictedView {
51    mem: Arc<dyn MappedDmaTarget>,
52    len: usize,
53    offset: usize,
54}
55
56impl RestrictedView {
57    /// Wraps `mem` and provides a restricted view of it.
58    fn new(mem: Arc<dyn MappedDmaTarget>, offset: usize, len: usize) -> Self {
59        let mem_len = mem.len();
60        assert!(mem_len >= offset && mem_len - offset >= len);
61        Self { len, offset, mem }
62    }
63}
64
65// SAFETY: Passing through to the underlying impl after restricting the bounds
66// (which were validated in `new`).
67unsafe impl MappedDmaTarget for RestrictedView {
68    fn base(&self) -> *const u8 {
69        // SAFETY: verified in `new` to be in bounds.
70        unsafe { self.mem.base().add(self.offset) }
71    }
72
73    fn len(&self) -> usize {
74        self.len
75    }
76
77    fn pfns(&self) -> &[u64] {
78        let start = self.offset / PAGE_SIZE;
79        let count = (self.base() as usize % PAGE_SIZE + self.len + 0xfff) / PAGE_SIZE;
80        let pages = self.mem.pfns();
81        &pages[start..][..count]
82    }
83
84    fn pfn_bias(&self) -> u64 {
85        self.mem.pfn_bias()
86    }
87
88    fn view(&self, offset: usize, len: usize) -> Option<MemoryBlock> {
89        Some(MemoryBlock::new(RestrictedView::new(
90            self.mem.clone(),
91            self.offset.checked_add(offset).unwrap(),
92            len,
93        )))
94    }
95}
96
97/// A DMA target.
98#[derive(Clone)]
99pub struct MemoryBlock {
100    base: *const u8,
101    len: usize,
102    mem: Arc<dyn MappedDmaTarget>,
103}
104
105impl std::fmt::Debug for MemoryBlock {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        f.debug_struct("MemoryBlock")
108            .field("base", &self.base)
109            .field("len", &self.len)
110            .field("pfns", &self.pfns())
111            .field("pfn_bias", &self.pfn_bias())
112            .finish()
113    }
114}
115
// SAFETY: The only field that is not `Send` is the raw `base` pointer, which
// points into memory owned by `mem`. The inner MappedDmaTarget is Send + Sync,
// so a view of it is too.
unsafe impl Send for MemoryBlock {}
// SAFETY: As above, `base` only refers to memory owned by `mem`. The inner
// MappedDmaTarget is Send + Sync, so a view of it is too.
unsafe impl Sync for MemoryBlock {}
120
121impl MemoryBlock {
122    /// Creates a new memory block backed by `mem`.
123    pub fn new<T: 'static + MappedDmaTarget>(mem: T) -> Self {
124        Self {
125            base: mem.base(),
126            len: mem.len(),
127            mem: Arc::new(mem),
128        }
129    }
130
131    /// Returns a view of a subset of the buffer.
132    pub fn subblock(&self, offset: usize, len: usize) -> Self {
133        match self.mem.view(offset, len) {
134            Some(view) => view,
135            None => Self::new(RestrictedView::new(self.mem.clone(), offset, len)),
136        }
137    }
138
139    /// Get the base address of the buffer.
140    pub fn base(&self) -> *const u8 {
141        self.base
142    }
143
144    /// Gets the length of the buffer in bytes.
145    pub fn len(&self) -> usize {
146        self.len
147    }
148
149    /// Gets the PFNs of the underlying memory.
150    pub fn pfns(&self) -> &[u64] {
151        self.mem.pfns()
152    }
153
154    /// Gets the pfn_bias of the underlying memory.
155    pub fn pfn_bias(&self) -> u64 {
156        self.mem.pfn_bias()
157    }
158
159    /// Gets the buffer as an atomic slice.
160    pub fn as_slice(&self) -> &[AtomicU8] {
161        // SAFETY: the underlying memory is valid for the lifetime of `mem`.
162        unsafe { std::slice::from_raw_parts(self.base.cast(), self.len) }
163    }
164
165    /// Reads from the buffer into `data`.
166    pub fn read_at(&self, offset: usize, data: &mut [u8]) {
167        self.as_slice()[offset..][..data.len()].atomic_read(data);
168    }
169
170    /// Reads an object from the buffer at `offset`.
171    pub fn read_obj<T: FromBytes + Immutable + KnownLayout>(&self, offset: usize) -> T {
172        self.as_slice()[offset..][..size_of::<T>()].atomic_read_obj()
173    }
174
175    /// Writes into the buffer from `data`.
176    pub fn write_at(&self, offset: usize, data: &[u8]) {
177        self.as_slice()[offset..][..data.len()].atomic_write(data);
178    }
179
180    /// Writes an object into the buffer at `offset`.
181    pub fn write_obj<T: IntoBytes + Immutable + KnownLayout>(&self, offset: usize, data: &T) {
182        self.as_slice()[offset..][..size_of::<T>()].atomic_write_obj(data);
183    }
184
185    /// Returns the offset of the beginning of the buffer in the first page
186    /// returned by [`Self::pfns`].
187    pub fn offset_in_page(&self) -> u32 {
188        self.base as u32 % PAGE_SIZE as u32
189    }
190}