// user_driver/memory.rs
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
//! Traits and types for sharing host memory with the device.
use safeatomic::AtomicSliceOps;
use std::sync::Arc;
use std::sync::atomic::AtomicU8;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;
/// The 4KB page size used by user-mode devices.
pub const PAGE_SIZE: usize = 4096;
/// [`PAGE_SIZE`] as a `u32`, for use in 32-bit arithmetic.
pub const PAGE_SIZE32: u32 = 4096;
/// [`PAGE_SIZE`] as a `u64`, for use in 64-bit arithmetic.
pub const PAGE_SIZE64: u64 = PAGE_SIZE as u64;
/// A mapped buffer that can be accessed by the host or the device.
///
/// # Safety
/// The implementor must ensure that the VA region from `base()..base() + len()`
/// remains mapped for the lifetime.
pub unsafe trait MappedDmaTarget: Send + Sync {
    /// The virtual address of the mapped memory.
    fn base(&self) -> *const u8;
    /// The length of the buffer in bytes.
    fn len(&self) -> usize;
    /// 4KB page numbers used to refer to the memory when communicating with the
    /// device.
    fn pfns(&self) -> &[u64];
    /// The pfn_bias on confidential platforms (aka vTOM) applied to PFNs in
    /// [`Self::pfns()`].
    fn pfn_bias(&self) -> u64;
    /// Returns a view of a subset of the buffer.
    ///
    /// Returns `None` if the default implementation should be used.
    ///
    /// This should not be implemented except by internal implementations.
    #[doc(hidden)]
    fn view(&self, offset: usize, len: usize) -> Option<MemoryBlock> {
        // Default: signal the caller to construct a generic restricted view.
        let _ = (offset, len);
        None
    }
}
/// A [`MappedDmaTarget`] exposing only a sub-range of an underlying target.
struct RestrictedView {
    /// The full underlying mapping.
    mem: Arc<dyn MappedDmaTarget>,
    /// Length of the view in bytes.
    len: usize,
    /// Byte offset of the view within `mem`.
    offset: usize,
}
impl RestrictedView {
    /// Creates a view of `mem` covering bytes `offset..offset + len`.
    ///
    /// Panics if that range does not lie entirely within `mem`.
    fn new(mem: Arc<dyn MappedDmaTarget>, offset: usize, len: usize) -> Self {
        let available = mem.len();
        // Validate `offset` first so the subtraction below cannot underflow.
        assert!(offset <= available && len <= available - offset);
        Self { mem, offset, len }
    }
}
// SAFETY: Passing through to the underlying impl after restricting the bounds
// (which were validated in `new`).
unsafe impl MappedDmaTarget for RestrictedView {
    fn base(&self) -> *const u8 {
        // SAFETY: verified in `new` to be in bounds.
        unsafe { self.mem.base().add(self.offset) }
    }

    fn len(&self) -> usize {
        self.len
    }

    fn pfns(&self) -> &[u64] {
        // Index of the first page the view touches.
        let start = self.offset / PAGE_SIZE;
        // Number of pages the view spans: the view's offset within its first
        // page plus its length, rounded up to whole pages. `div_ceil` replaces
        // the old `(x + 0xfff) / PAGE_SIZE`, which hard-coded the page mask and
        // could overflow for lengths near `usize::MAX`.
        let count = (self.base() as usize % PAGE_SIZE + self.len).div_ceil(PAGE_SIZE);
        let pages = self.mem.pfns();
        &pages[start..][..count]
    }

    fn pfn_bias(&self) -> u64 {
        self.mem.pfn_bias()
    }

    fn view(&self, offset: usize, len: usize) -> Option<MemoryBlock> {
        // Rebase onto the underlying target so nested views do not stack up a
        // chain of `RestrictedView`s. Bounds are re-validated by `new`.
        Some(MemoryBlock::new(RestrictedView::new(
            self.mem.clone(),
            self.offset.checked_add(offset).unwrap(),
            len,
        )))
    }
}
/// A DMA target.
#[derive(Clone)]
pub struct MemoryBlock {
    /// Cached copy of `mem.base()`, captured at construction.
    base: *const u8,
    /// Cached copy of `mem.len()`, captured at construction.
    len: usize,
    /// The backing mapping. The `Arc` keeps it (and therefore `base`) alive
    /// for as long as any clone of this block exists.
    mem: Arc<dyn MappedDmaTarget>,
}
impl std::fmt::Debug for MemoryBlock {
    /// Formats the block's base pointer, length, PFN list, and PFN bias.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("MemoryBlock");
        dbg.field("base", &self.base);
        dbg.field("len", &self.len);
        dbg.field("pfns", &self.pfns());
        dbg.field("pfn_bias", &self.pfn_bias());
        dbg.finish()
    }
}
// SAFETY: `MemoryBlock` is a view (raw pointer + length) into its inner
// `MappedDmaTarget`, which is itself `Send + Sync`, so the view may be sent
// across threads.
unsafe impl Send for MemoryBlock {}
// SAFETY: All access through the view is via `&self` methods over atomic or
// read-only data of the inner `Send + Sync` target, so shared references may
// be used from multiple threads.
unsafe impl Sync for MemoryBlock {}
impl MemoryBlock {
    /// Creates a new memory block backed by `mem`.
    pub fn new<T: 'static + MappedDmaTarget>(mem: T) -> Self {
        Self {
            // Cache base/len so accessors avoid a virtual call.
            base: mem.base(),
            len: mem.len(),
            mem: Arc::new(mem),
        }
    }

    /// Returns a view of a subset of the buffer.
    ///
    /// # Panics
    /// Panics if `offset..offset + len` is out of the buffer's bounds.
    pub fn subblock(&self, offset: usize, len: usize) -> Self {
        match self.mem.view(offset, len) {
            // Let the backing target provide a specialized view if it can.
            Some(view) => view,
            None => Self::new(RestrictedView::new(self.mem.clone(), offset, len)),
        }
    }

    /// Get the base address of the buffer.
    pub fn base(&self) -> *const u8 {
        self.base
    }

    /// Gets the length of the buffer in bytes.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the buffer is zero bytes long.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Gets the PFNs of the underlying memory.
    pub fn pfns(&self) -> &[u64] {
        self.mem.pfns()
    }

    /// Gets the pfn_bias of the underlying memory.
    pub fn pfn_bias(&self) -> u64 {
        self.mem.pfn_bias()
    }

    /// Gets the buffer as an atomic slice.
    pub fn as_slice(&self) -> &[AtomicU8] {
        // SAFETY: the underlying memory is valid for the lifetime of `mem`.
        unsafe { std::slice::from_raw_parts(self.base.cast(), self.len) }
    }

    /// Reads from the buffer into `data`.
    ///
    /// # Panics
    /// Panics if `offset + data.len()` exceeds the buffer length.
    pub fn read_at(&self, offset: usize, data: &mut [u8]) {
        self.as_slice()[offset..][..data.len()].atomic_read(data);
    }

    /// Reads an object from the buffer at `offset`.
    ///
    /// # Panics
    /// Panics if `offset + size_of::<T>()` exceeds the buffer length.
    pub fn read_obj<T: FromBytes + Immutable + KnownLayout>(&self, offset: usize) -> T {
        self.as_slice()[offset..][..size_of::<T>()].atomic_read_obj()
    }

    /// Writes into the buffer from `data`.
    ///
    /// # Panics
    /// Panics if `offset + data.len()` exceeds the buffer length.
    pub fn write_at(&self, offset: usize, data: &[u8]) {
        self.as_slice()[offset..][..data.len()].atomic_write(data);
    }

    /// Writes an object into the buffer at `offset`.
    ///
    /// # Panics
    /// Panics if `offset + size_of::<T>()` exceeds the buffer length.
    pub fn write_obj<T: IntoBytes + Immutable + KnownLayout>(&self, offset: usize, data: &T) {
        self.as_slice()[offset..][..size_of::<T>()].atomic_write_obj(data);
    }

    /// Returns the offset of the beginning of the buffer in the first page
    /// returned by [`Self::pfns`].
    pub fn offset_in_page(&self) -> u32 {
        // Truncation is fine: only the low 12 bits of the address matter.
        self.base as u32 % PAGE_SIZE as u32
    }
}