// user_driver/memory.rs
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Traits and types for sharing host memory with the device.

use safeatomic::AtomicSliceOps;
use std::sync::Arc;
use std::sync::atomic::AtomicU8;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

/// The 4KB page size used by user-mode devices.
pub const PAGE_SIZE: usize = 4096;
/// [`PAGE_SIZE`] as a `u32`, derived from it so the two can never disagree.
pub const PAGE_SIZE32: u32 = PAGE_SIZE as u32;
/// [`PAGE_SIZE`] as a `u64`, derived from it so the two can never disagree.
pub const PAGE_SIZE64: u64 = PAGE_SIZE as u64;

/// A mapped buffer that can be accessed by the host or the device.
///
/// # Safety
/// The implementor must ensure that the VA region from `base()..base() + len()`
/// remains mapped for the lifetime of the implementing value.
pub unsafe trait MappedDmaTarget: Send + Sync {
    /// The virtual address of the mapped memory.
    fn base(&self) -> *const u8;

    /// The length of the buffer in bytes.
    fn len(&self) -> usize;

    /// 4KB page numbers used to refer to the memory when communicating with the
    /// device.
    fn pfns(&self) -> &[u64];

    /// The pfn_bias on confidential platforms (aka vTOM) applied to PFNs in
    /// [`Self::pfns()`].
    fn pfn_bias(&self) -> u64;

    /// Returns a view of a subset of the buffer.
    ///
    /// Returns `None` if the default implementation should be used; the
    /// provided body always does so.
    ///
    /// This should not be implemented except by internal implementations.
    #[doc(hidden)]
    fn view(&self, offset: usize, len: usize) -> Option<MemoryBlock> {
        // Touch the parameters so the default body does not produce
        // unused-variable warnings.
        let _ = (offset, len);
        None
    }
}

/// A [`MappedDmaTarget`] that exposes only the byte range
/// `offset..offset + len` of another target.
struct RestrictedView {
    /// The underlying target; the `Arc` keeps it alive while this view exists.
    mem: Arc<dyn MappedDmaTarget>,
    /// Length of the view in bytes.
    len: usize,
    /// Byte offset of the view from the start of `mem`.
    offset: usize,
}

impl RestrictedView {
    /// Wraps `mem` and provides a restricted view of it covering the byte
    /// range `offset..offset + len`.
    ///
    /// # Panics
    /// Panics if the requested range does not lie within `0..mem.len()`.
    fn new(mem: Arc<dyn MappedDmaTarget>, offset: usize, len: usize) -> Self {
        let mem_len = mem.len();
        // Compare against the remaining length rather than `offset + len` so
        // the check itself cannot overflow.
        assert!(
            mem_len >= offset && mem_len - offset >= len,
            "restricted view out of bounds: offset {offset:#x}, len {len:#x}, \
             target length {mem_len:#x}"
        );
        Self { len, offset, mem }
    }
}

// SAFETY: Passing through to the underlying impl after restricting the bounds
// (which were validated in `new`).
unsafe impl MappedDmaTarget for RestrictedView {
    /// Address of the first byte of the view.
    fn base(&self) -> *const u8 {
        // SAFETY: verified in `new` to be in bounds.
        unsafe { self.mem.base().add(self.offset) }
    }

    /// Length of the view in bytes.
    fn len(&self) -> usize {
        self.len
    }

    /// PFNs of the pages that the view touches, taken from the underlying
    /// target's PFN list.
    ///
    /// # Panics
    /// Panics if the underlying target reports fewer PFNs than the view spans.
    fn pfns(&self) -> &[u64] {
        // The underlying mapping is not required to start on a page boundary.
        // This assumes its first PFN describes the page containing its base
        // address, so the view's position must be measured from the start of
        // that page, not from `mem.base()`.
        let base_in_page = self.mem.base() as usize % PAGE_SIZE;
        let first_byte = base_in_page + self.offset;
        let start = first_byte / PAGE_SIZE;
        // Pages covered by `len` bytes beginning `first_byte % PAGE_SIZE`
        // bytes into the first page of the view.
        let count = (first_byte % PAGE_SIZE + self.len).div_ceil(PAGE_SIZE);
        &self.mem.pfns()[start..][..count]
    }

    /// The PFN bias of the underlying target, unchanged.
    fn pfn_bias(&self) -> u64 {
        self.mem.pfn_bias()
    }

    /// Returns a view of `offset..offset + len` relative to this view.
    ///
    /// The result wraps the underlying target directly rather than nesting
    /// views, with the offsets combined.
    ///
    /// # Panics
    /// Panics if the requested range does not lie within this view.
    fn view(&self, offset: usize, len: usize) -> Option<MemoryBlock> {
        // `RestrictedView::new` only validates against the whole underlying
        // target, so check here that a sub-view cannot reach outside its
        // parent view.
        assert!(
            offset <= self.len && len <= self.len - offset,
            "sub-view out of bounds: offset {offset:#x}, len {len:#x}, \
             parent view length {:#x}",
            self.len
        );
        Some(MemoryBlock::new(RestrictedView::new(
            self.mem.clone(),
            self.offset.checked_add(offset).unwrap(),
            len,
        )))
    }
}

/// A DMA target.
///
/// Stores the base address and length reported by the backing
/// [`MappedDmaTarget`] alongside a shared reference to that target, so clones
/// refer to the same backing target.
#[derive(Clone)]
pub struct MemoryBlock {
    /// Base address of the buffer, as reported by `mem` at construction.
    base: *const u8,
    /// Length of the buffer in bytes, as reported by `mem` at construction.
    len: usize,
    /// The backing target, shared between clones.
    mem: Arc<dyn MappedDmaTarget>,
}

impl std::fmt::Debug for MemoryBlock {
    /// Formats the block as a `MemoryBlock` struct listing its base address,
    /// length, PFNs and PFN bias.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryBlock");
        out.field("base", &self.base);
        out.field("len", &self.len);
        out.field("pfns", &self.pfns());
        out.field("pfn_bias", &self.pfn_bias());
        out.finish()
    }
}

// SAFETY: The raw `base` pointer is the only field that prevents `Send` from
// being derived automatically. It is obtained from the `MappedDmaTarget` held
// in `mem`, and that trait requires `Send + Sync`, so a view of it is too.
unsafe impl Send for MemoryBlock {}
// SAFETY: The raw `base` pointer is the only field that prevents `Sync` from
// being derived automatically. It is obtained from the `MappedDmaTarget` held
// in `mem`, and that trait requires `Send + Sync`, so a view of it is too.
unsafe impl Sync for MemoryBlock {}

impl MemoryBlock {
    /// Creates a new memory block backed by `mem`.
    ///
    /// The base address and length are queried from `mem` once, here, before
    /// it is moved behind an [`Arc`].
    pub fn new<T: 'static + MappedDmaTarget>(mem: T) -> Self {
        let base = mem.base();
        let len = mem.len();
        Self {
            base,
            len,
            mem: Arc::new(mem),
        }
    }

    /// Returns a view of a subset of the buffer.
    ///
    /// The backing target is asked for a view first; if it declines by
    /// returning `None`, the target is wrapped in a `RestrictedView` instead.
    pub fn subblock(&self, offset: usize, len: usize) -> Self {
        self.mem.view(offset, len).unwrap_or_else(|| {
            Self::new(RestrictedView::new(self.mem.clone(), offset, len))
        })
    }

    /// Get the base address of the buffer.
    pub fn base(&self) -> *const u8 {
        let Self { base, .. } = self;
        *base
    }

    /// Gets the length of the buffer in bytes.
    pub fn len(&self) -> usize {
        let Self { len, .. } = self;
        *len
    }

    /// Gets the PFNs of the underlying memory.
    pub fn pfns(&self) -> &[u64] {
        let target = &self.mem;
        target.pfns()
    }

    /// Gets the pfn_bias of the underlying memory.
    pub fn pfn_bias(&self) -> u64 {
        let target = &self.mem;
        target.pfn_bias()
    }

    /// Gets the buffer as an atomic slice.
    pub fn as_slice(&self) -> &[AtomicU8] {
        let first = self.base.cast::<AtomicU8>();
        // SAFETY: the underlying memory is valid for the lifetime of `mem`.
        unsafe { std::slice::from_raw_parts(first, self.len) }
    }

    /// Reads from the buffer into `data`, starting at byte `offset`.
    ///
    /// # Panics
    /// Panics if `data.len()` bytes starting at `offset` do not fit in the
    /// buffer.
    pub fn read_at(&self, offset: usize, data: &mut [u8]) {
        let src = &self.as_slice()[offset..][..data.len()];
        src.atomic_read(data);
    }

    /// Reads an object from the buffer at `offset`.
    ///
    /// # Panics
    /// Panics if `size_of::<T>()` bytes starting at `offset` do not fit in the
    /// buffer.
    pub fn read_obj<T: FromBytes + Immutable + KnownLayout>(&self, offset: usize) -> T {
        let src = &self.as_slice()[offset..][..size_of::<T>()];
        src.atomic_read_obj()
    }

    /// Writes into the buffer from `data`, starting at byte `offset`.
    ///
    /// # Panics
    /// Panics if `data.len()` bytes starting at `offset` do not fit in the
    /// buffer.
    pub fn write_at(&self, offset: usize, data: &[u8]) {
        let dst = &self.as_slice()[offset..][..data.len()];
        dst.atomic_write(data);
    }

    /// Writes an object into the buffer at `offset`.
    ///
    /// # Panics
    /// Panics if `size_of::<T>()` bytes starting at `offset` do not fit in the
    /// buffer.
    pub fn write_obj<T: IntoBytes + Immutable + KnownLayout>(&self, offset: usize, data: &T) {
        let dst = &self.as_slice()[offset..][..size_of::<T>()];
        dst.atomic_write_obj(data);
    }

    /// Returns the offset of the beginning of the buffer in the first page
    /// returned by [`Self::pfns`].
    pub fn offset_in_page(&self) -> u32 {
        let page_size = PAGE_SIZE as u32;
        // Truncating the address to 32 bits keeps the low bits that determine
        // the in-page offset.
        let low_bits = self.base as u32;
        low_bits % page_size
    }
}