scsi_buffers/lib.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Functionality for referencing locked memory buffers for the lifetime of an
//! IO.

// UNSAFETY: Handling raw pointers and transmuting between types for different use cases.
#![expect(unsafe_code)]

use guestmem::AccessError;
use guestmem::GuestMemory;
use guestmem::LockedRange;
use guestmem::LockedRangeImpl;
use guestmem::MemoryRead;
use guestmem::MemoryWrite;
use guestmem::ranges::PagedRange;
use guestmem::ranges::PagedRangeWriter;
use safeatomic::AsAtomicBytes;
use smallvec::SmallVec;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::AtomicU8;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

/// A pointer/length pair that is ABI compatible with the iovec type on Linux.
#[derive(Debug, Copy, Clone)]
#[repr(C)]
pub struct AtomicIoVec {
    /// The address of the buffer.
    pub address: *const AtomicU8,
    /// The length of the buffer in bytes.
    pub len: usize,
}

impl Default for AtomicIoVec {
    fn default() -> Self {
        Self {
            address: std::ptr::null(),
            len: 0,
        }
    }
}

impl From<&'_ [AtomicU8]> for AtomicIoVec {
    fn from(p: &'_ [AtomicU8]) -> Self {
        Self {
            address: p.as_ptr(),
            len: p.len(),
        }
    }
}

impl AtomicIoVec {
    /// Returns a pointer to a slice backed by the buffer.
    ///
    /// # Safety
    /// The caller must ensure this iovec points to [valid](std::ptr#Safety)
    /// data.
    pub unsafe fn as_slice_unchecked(&self) -> &[AtomicU8] {
        // SAFETY: guaranteed by caller.
        unsafe { std::slice::from_raw_parts(self.address, self.len) }
    }
}

// SAFETY: AtomicIoVec just represents a pointer and length and can be
// sent/accessed anywhere freely.
unsafe impl Send for AtomicIoVec {}
// SAFETY: see above comment
unsafe impl Sync for AtomicIoVec {}

/// Wrapper around an &[AtomicU8] guaranteed to be ABI compatible with the
/// `iovec` type on Linux.
#[derive(Debug, Copy, Clone, Default)]
#[repr(transparent)]
pub struct IoBuffer<'a> {
    io_vec: AtomicIoVec,
    phantom: PhantomData<&'a AtomicU8>,
}

impl<'a> IoBuffer<'a> {
    /// Wraps `buffer` and returns it.
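    ///
    /// # Example
    ///
    /// A minimal sketch of wrapping plain atomic bytes (the buffer contents
    /// here are illustrative):
    ///
    /// ```
    /// # use std::sync::atomic::AtomicU8;
    /// # use scsi_buffers::IoBuffer;
    /// let data: Vec<AtomicU8> = (0..16).map(AtomicU8::new).collect();
    /// let buf = IoBuffer::new(&data);
    /// assert_eq!(buf.len(), 16);
    /// ```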
    pub fn new(buffer: &'a [AtomicU8]) -> Self {
        Self {
            io_vec: AtomicIoVec {
                address: buffer.as_ptr(),
                len: buffer.len(),
            },
            phantom: PhantomData,
        }
    }

    /// Reinterprets `io_vec` as `IoBuffer`.
    ///
    /// # Safety
    /// `io_vec` must reference a valid buffer for the lifetime of `Self`.
    pub unsafe fn from_io_vec(io_vec: &AtomicIoVec) -> &Self {
        // SAFETY: IoBuffer is #[repr(transparent)] over AtomicIoVec
        unsafe { std::mem::transmute(io_vec) }
    }

    /// Reinterprets the `io_vecs` slice as `[IoBuffer]`.
    ///
    /// # Safety
    /// `io_vecs` must reference valid buffers for the lifetime of `Self`.
    pub unsafe fn from_io_vecs(io_vecs: &[AtomicIoVec]) -> &[Self] {
        // SAFETY: IoBuffer is #[repr(transparent)] over AtomicIoVec
        unsafe { std::mem::transmute(io_vecs) }
    }

    /// Returns a pointer to the beginning of the buffer.
    pub fn as_ptr(&self) -> *const AtomicU8 {
        self.io_vec.address
    }

    /// Returns the buffer's length in bytes.
    pub fn len(&self) -> usize {
        self.io_vec.len
    }
}

impl Deref for IoBuffer<'_> {
    type Target = [AtomicU8];

    fn deref(&self) -> &Self::Target {
        // SAFETY: the buffer is guaranteed to be valid for the lifetime of
        // self.
        unsafe { self.io_vec.as_slice_unchecked() }
    }
}

const PAGE_SIZE: usize = 4096;

#[repr(C, align(4096))]
#[derive(Clone, IntoBytes, Immutable, KnownLayout, FromBytes)]
struct Page([u8; PAGE_SIZE]);

const ZERO_PAGE: Page = Page([0; PAGE_SIZE]);

/// A page-aligned buffer used to double-buffer IO data.
pub struct BounceBuffer {
    pages: Vec<Page>,
    io_vec: AtomicIoVec,
}

impl BounceBuffer {
    /// Allocates a new bounce buffer of `size` bytes.
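    ///
    /// # Example
    ///
    /// A minimal usage sketch (the size and fill value are illustrative):
    ///
    /// ```
    /// # use scsi_buffers::BounceBuffer;
    /// let mut bb = BounceBuffer::new(512);
    /// bb.as_mut_bytes().fill(0xa5); // stage data to be written
    /// assert_eq!(bb.io_vecs()[0].len(), 512);
    /// ```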
    pub fn new(size: usize) -> Self {
        let mut pages = vec![ZERO_PAGE; size.div_ceil(PAGE_SIZE)];
        let io_vec = pages.as_mut_bytes()[..size].as_atomic_bytes().into();
        BounceBuffer { pages, io_vec }
    }

    fn len(&self) -> usize {
        self.io_vec.len
    }

    /// Returns the bounce buffer memory.
    pub fn as_mut_bytes(&mut self) -> &mut [u8] {
        // SAFETY: while there are no concurrent references (e.g., via io_vec),
        // the buffer in pages is exclusively owned, and it is accessible as a
        // byte array.
        unsafe { std::slice::from_raw_parts_mut(self.pages.as_mut_ptr().cast::<u8>(), self.len()) }
    }

    /// Returns a reference to the underlying buffer.
    ///
    /// This is returned in a form convenient for using with IO functions.
    pub fn io_vecs(&self) -> &[IoBuffer<'_>] {
        std::slice::from_ref({
            // SAFETY: io_vec contains a pointer to the live data in pages.
            unsafe { IoBuffer::from_io_vec(&self.io_vec) }
        })
    }
}

/// A set of locked memory ranges, represented by [`IoBuffer`]s.
pub struct LockedIoBuffers(LockedRangeImpl<LockedIoVecs>);

impl LockedIoBuffers {
    /// Returns the slice of IO buffers.
    pub fn io_vecs(&self) -> &[IoBuffer<'_>] {
        // SAFETY: the LockedRangeImpl passed to new guarantees that only
        // vectors with valid lifetimes were passed to
        // LockedIoVecs::push_sub_range.
        unsafe { IoBuffer::from_io_vecs(&self.0.get().0) }
    }
}

struct LockedIoVecs(SmallVec<[AtomicIoVec; 64]>);

impl LockedIoVecs {
    fn new() -> Self {
        Self(Default::default())
    }
}

impl LockedRange for LockedIoVecs {
    fn push_sub_range(&mut self, sub_range: &[AtomicU8]) {
        self.0.push(sub_range.into());
    }
}

/// An implementation of [`MemoryWrite`] that enforces the buffers' write
/// permission: when the underlying buffers are read-only, attempts to write
/// or fill return a `ReadOnly` error.
struct PermissionedMemoryWriter<'a> {
    range: PagedRange<'a>,
    writer: PagedRangeWriter<'a>,
    is_write: bool,
}

impl PermissionedMemoryWriter<'_> {
    /// Creates a new memory writer with the given range and guest memory.
    fn new<'a>(
        range: PagedRange<'a>,
        guest_memory: &'a GuestMemory,
        is_write: bool,
    ) -> PermissionedMemoryWriter<'a> {
        // Simply create an empty range here to avoid branching on hot paths (`write`, `fill`, etc.)
        let range = if is_write { range } else { PagedRange::empty() };
        PermissionedMemoryWriter {
            range,
            writer: range.writer(guest_memory),
            is_write,
        }
    }
}

impl MemoryWrite for PermissionedMemoryWriter<'_> {
    fn write(&mut self, data: &[u8]) -> Result<(), AccessError> {
        self.writer.write(data).map_err(|e| {
            if self.is_write {
                e
            } else {
                AccessError::ReadOnly
            }
        })
    }

    fn fill(&mut self, val: u8, len: usize) -> Result<(), AccessError> {
        self.writer.fill(val, len).map_err(|e| {
            if self.is_write {
                e
            } else {
                AccessError::ReadOnly
            }
        })
    }

    fn len(&self) -> usize {
        self.range.len()
    }
}

/// An accessor for the memory associated with an IO request.
#[derive(Clone, Debug)]
pub struct RequestBuffers<'a> {
    range: PagedRange<'a>,
    guest_memory: &'a GuestMemory,
    is_write: bool,
}

impl<'a> RequestBuffers<'a> {
    /// Creates a new request buffer from the given memory ranges.
    pub fn new(guest_memory: &'a GuestMemory, range: PagedRange<'a>, is_write: bool) -> Self {
        Self {
            range,
            guest_memory,
            is_write,
        }
    }

    /// Returns true if the buffer is empty.
    pub fn is_empty(&self) -> bool {
        self.range.is_empty()
    }

    /// Return the total length of the buffers in bytes.
    pub fn len(&self) -> usize {
        self.range.len()
    }

    /// Returns the guest memory accessor.
    pub fn guest_memory(&self) -> &GuestMemory {
        self.guest_memory
    }

    /// Return the internal paged range.
    pub fn range(&self) -> PagedRange<'_> {
        self.range
    }

    /// Returns whether the buffers are all aligned to at least `alignment`
    /// bytes.
    ///
    /// `alignment` must be a power of two.
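    ///
    /// For example, a range with a 512-byte starting offset and a 1024-byte
    /// length is 512-byte aligned but not 4096-byte aligned.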
    pub fn is_aligned(&self, alignment: usize) -> bool {
        assert!(alignment.is_power_of_two());
        ((self.range.offset() | self.range.len() | PAGE_SIZE) & (alignment - 1)) == 0
    }

    /// Gets a memory writer for the buffers.
    ///
    /// Returns an empty writer if the buffers are only available for read access.
    pub fn writer(&self) -> impl MemoryWrite + '_ {
        PermissionedMemoryWriter::new(self.range, self.guest_memory, self.is_write)
    }

    /// Gets a memory reader for the buffers.
    pub fn reader(&self) -> impl MemoryRead + '_ {
        self.range.reader(self.guest_memory)
    }

    /// Locks the guest memory ranges described by this buffer and returns an
    /// object containing [`IoBuffer`]s, suitable for executing asynchronous I/O
    /// operations.
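    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `gm` is a [`GuestMemory`] that backs GPN 0:
    ///
    /// ```no_run
    /// # use guestmem::GuestMemory;
    /// # use guestmem::ranges::PagedRange;
    /// # use scsi_buffers::RequestBuffers;
    /// # fn example(gm: &GuestMemory) {
    /// let range = PagedRange::new(0, 4096, &[0]).unwrap();
    /// let buffers = RequestBuffers::new(gm, range, true);
    /// let locked = buffers.lock(true).unwrap();
    /// let _io_vecs = locked.io_vecs(); // pass these to the async IO submission
    /// # }
    /// ```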
    pub fn lock(&self, for_write: bool) -> Result<LockedIoBuffers, AccessError> {
        if for_write && !self.is_write {
            return Err(AccessError::ReadOnly);
        }
        Ok(LockedIoBuffers(
            self.guest_memory
                .lock_range(self.range, LockedIoVecs::new())?,
        ))
    }

    /// Returns a subrange of this set of buffers.
    ///
    /// Panics if `offset + len > self.len()`.
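    ///
    /// For example, `subrange(4096, 512)` yields a 512-byte view starting one
    /// page into the original range.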
    pub fn subrange(&self, offset: usize, len: usize) -> Self {
        Self {
            range: self.range.subrange(offset, len),
            guest_memory: self.guest_memory,
            is_write: self.is_write,
        }
    }
}

/// A memory range.
#[derive(Debug, Clone)]
pub struct OwnedRequestBuffers {
    gpns: Vec<u64>,
    offset: usize,
    len: usize,
    is_write: bool,
}

impl OwnedRequestBuffers {
    /// A new memory range with the given guest page numbers.
    pub fn new(gpns: &[u64]) -> Self {
        Self::new_unaligned(gpns, 0, gpns.len() * PAGE_SIZE)
    }

    /// A new memory range with the given guest page numbers, offset by
    /// `offset` bytes, and `len` bytes long.
    pub fn new_unaligned(gpns: &[u64], offset: usize, len: usize) -> Self {
        Self {
            gpns: gpns.to_vec(),
            offset,
            len,
            is_write: true,
        }
    }

    /// A new memory range containing the linear address range from
    /// `offset..offset+len`.
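    ///
    /// For example, `linear(4096 + 512, 8192, true)` covers GPNs 1 through 3,
    /// with a starting offset of 512 bytes into the first page.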
    pub fn linear(offset: u64, len: usize, is_write: bool) -> Self {
        let start_page = offset / PAGE_SIZE as u64;
        let end_page = (offset + len as u64).div_ceil(PAGE_SIZE as u64);
        let gpns: Vec<u64> = (start_page..end_page).collect();
        Self {
            gpns,
            offset: (offset % PAGE_SIZE as u64) as usize,
            len,
            is_write,
        }
    }

    /// A [`RequestBuffers`] referencing this memory range.
    pub fn buffer<'a>(&'a self, guest_memory: &'a GuestMemory) -> RequestBuffers<'a> {
        RequestBuffers::new(
            guest_memory,
            PagedRange::new(self.offset, self.len, &self.gpns).unwrap(),
            self.is_write,
        )
    }

    /// The length of the range in bytes.
    pub fn len(&self) -> usize {
        self.len
    }
}

/// Tracks an active bounce buffer, signaling to the bounce buffer tracker
/// upon drop that pages can be reclaimed.
pub struct TrackedBounceBuffer<'a> {
    /// The active bounce buffer being tracked.
    pub buffer: BounceBuffer,
    /// Reference to the free page counter for the current IO thread.
    free_pages: &'a AtomicUsize,
    /// Used to signal pending bounce buffer requests of newly freed pages.
    event: &'a event_listener::Event,
}

impl Drop for TrackedBounceBuffer<'_> {
    fn drop(&mut self) {
        let pages = self.buffer.len().div_ceil(PAGE_SIZE);
        self.free_pages.fetch_add(pages, Ordering::SeqCst);
        self.event.notify(usize::MAX);
    }
}

/// Tracks active bounce buffers against a set limit of pages. If no limit is
/// specified, a default of 8 MiB is applied. The limit is tracked per thread
/// of the backing `AffinitizedThreadpool`.
#[derive(Debug)]
pub struct BounceBufferTracker {
    /// Active bounce buffer pages on a given thread.
    free_pages: Vec<AtomicUsize>,
    /// Event used by TrackedBounceBuffer to signal pages have been dropped.
    event: Vec<event_listener::Event>,
}

impl BounceBufferTracker {
    /// Create a new bounce buffer tracker.
    pub fn new(max_bounce_buffer_pages: usize, threads: usize) -> Self {
        let mut free_pages = Vec::with_capacity(threads);
        let mut event = Vec::with_capacity(threads);

        (0..threads).for_each(|_| {
            event.push(event_listener::Event::new());
            free_pages.push(AtomicUsize::new(max_bounce_buffer_pages));
        });

        Self { free_pages, event }
    }

    /// Acquires bounce buffers from the tracker, proceeding immediately if
    /// enough pages are available, or otherwise waiting until a tracked bounce
    /// buffer is dropped, which triggers the per-thread event to signal newly
    /// freed pages.
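    ///
    /// # Example
    ///
    /// A minimal sketch with illustrative limits (2048 pages shared across 4
    /// IO threads):
    ///
    /// ```no_run
    /// # use scsi_buffers::BounceBufferTracker;
    /// # async fn example() {
    /// let tracker = BounceBufferTracker::new(2048, 4);
    /// let buffer = tracker.acquire_bounce_buffers(8192, 0).await;
    /// // Dropping `buffer` returns its pages to thread 0's budget.
    /// drop(buffer);
    /// # }
    /// ```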
    pub async fn acquire_bounce_buffers<'a, 'b>(
        &'b self,
        size: usize,
        thread: usize,
    ) -> Box<TrackedBounceBuffer<'a>>
    where
        'b: 'a,
    {
        let pages = size.div_ceil(PAGE_SIZE);
        let event = self.event.get(thread).unwrap();
        let free_pages = self.free_pages.get(thread).unwrap();

        loop {
            let listener = event.listen();
            if free_pages
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| x.checked_sub(pages))
                .is_ok()
            {
                break;
            }
            listener.await;
        }

        Box::new(TrackedBounceBuffer {
            buffer: BounceBuffer::new(size),
            free_pages,
            event,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use sparse_mmap::SparseMapping;

    const SIZE_1MB: usize = 1048576;

    #[test]
    fn correct_read_only_behavior() {
        let mapping = SparseMapping::new(SIZE_1MB * 4).unwrap();
        let guest_memory = GuestMemory::new("test-scsi-buffers", mapping);
        let range = PagedRange::new(0, 4096, &[0]).unwrap();
        let buffers = RequestBuffers::new(&guest_memory, range, false);

        let r = buffers.writer().write(&[1; 4096]);
        assert!(
            matches!(r, Err(AccessError::ReadOnly)),
            "Expected read-only error, got {:?}",
            r
        );

        let r = buffers.writer().fill(1, 4096);
        assert!(
            matches!(r, Err(AccessError::ReadOnly)),
            "Expected read-only error, got {:?}",
            r
        );

        assert!(
            buffers.writer().len() == 0,
            "Length should be 0 for read-only writer"
        );
    }
}