headervec/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
//! This crate implements the `HeaderVec` type for constructing dynamically
//! sized values that consist of a fixed-size header followed by a variable
//! number of trailing elements. This is a common pattern in IOCTL input buffers.
7
8// UNSAFETY: Implementing a custom data structure that requires manual memory
9// management and pointer manipulation.
10#![expect(unsafe_code)]
11#![no_std]
12
13extern crate alloc;
14
15use alloc::alloc::Layout;
16use alloc::alloc::alloc;
17use alloc::alloc::handle_alloc_error;
18use alloc::boxed::Box;
19use core::cmp;
20use core::mem::MaybeUninit;
21use core::ops::Deref;
22use core::ops::DerefMut;
23use core::ptr::NonNull;
24
/// A value made of a fixed-size header immediately followed (after any
/// padding required by `#[repr(C)]`) by a tail that may be dynamically sized.
#[derive(Debug)]
#[repr(C)]
pub struct HeaderSlice<T, U: ?Sized> {
    /// The header, stored first.
    pub head: T,
    /// The trailing data, typically an array or slice of elements.
    pub tail: U,
}
35
36impl<T, U> HeaderSlice<T, [U]> {
37    fn ptr_from_raw_parts(ptr: *const T, len: usize) -> *const Self {
38        // Create a [T] (the inner type doesn't actually matter) with `len`
39        // elements, then cast it to a HeaderSlice<T, [U]>. The cast via `as`
40        // preserves the element count.
41        //
42        // FUTURE: use [`core::ptr::from_raw_parts`] once it is stable.
43        core::ptr::slice_from_raw_parts(ptr, len) as *const Self
44    }
45
46    fn ptr_from_raw_parts_mut(ptr: *mut T, len: usize) -> *mut Self {
47        // Create a [T] (the inner type doesn't actually matter) with `len`
48        // elements, then cast it to a HeaderSlice<T, [U]>. The cast via `as`
49        // preserves the element count.
50        //
51        // FUTURE: use [`core::ptr::from_raw_parts_mut`] once it is stable.
52        core::ptr::slice_from_raw_parts_mut(ptr, len) as *mut Self
53    }
54
55    /// # Safety
56    /// The caller must ensure that `ptr` points to a `T` followed by `len`
57    /// elements of `U`, valid for lifetime `'a`.
58    unsafe fn from_raw_parts<'a>(ptr: *const T, len: usize) -> &'a Self {
59        // SAFETY: the caller ensures that the resulting pointer is valid for
60        // lifetime `'a`.
61        unsafe { &*Self::ptr_from_raw_parts(ptr, len) }
62    }
63
64    /// # Safety
65    /// The caller must ensure that `ptr` points to a `T` followed by `len`
66    /// elements of `U`, valid for lifetime `'a`.
67    unsafe fn from_raw_parts_mut<'a>(ptr: *mut T, len: usize) -> &'a mut Self {
68        // SAFETY: the caller ensures that the resulting pointer is valid for
69        // lifetime `'a`.
70        unsafe { &mut *Self::ptr_from_raw_parts_mut(ptr, len) }
71    }
72}
73
74#[derive(Debug)]
75enum Data<T, U, const N: usize> {
76    Fixed(HeaderSlice<T, [MaybeUninit<U>; N]>),
77    Alloc(Box<HeaderSlice<T, [MaybeUninit<U>]>>),
78}
79
80impl<T, U, const N: usize> Data<T, U, N> {
81    /// # Safety
82    ///
83    /// The caller must ensure that the first `len` elements have been initialized.
84    unsafe fn valid(&self, len: usize) -> &HeaderSlice<T, [U]> {
85        // SAFETY: the caller has ensured that the first `len` elements have been
86        // initialized.
87        unsafe { HeaderSlice::from_raw_parts(core::ptr::from_ref(self.storage()).cast(), len) }
88    }
89
90    /// # Safety
91    ///
92    /// The caller must ensure that the first `len` elements have been initialized.
93    unsafe fn valid_mut(&mut self, len: usize) -> &mut HeaderSlice<T, [U]> {
94        // SAFETY: the caller has ensured that the first `len` elements have been
95        // initialized.
96        unsafe {
97            HeaderSlice::from_raw_parts_mut(core::ptr::from_mut(self.storage_mut()).cast(), len)
98        }
99    }
100
101    fn storage(&self) -> &HeaderSlice<T, [MaybeUninit<U>]> {
102        let p: &HeaderSlice<T, [MaybeUninit<U>]> = match self {
103            Data::Fixed(p) => p,
104            Data::Alloc(p) => p,
105        };
106        if size_of::<U>() == 0 {
107            // SAFETY: the tail element is a ZST so its slice is valid for any
108            // length.
109            unsafe { HeaderSlice::from_raw_parts(&raw const p.head, usize::MAX) }
110        } else {
111            p
112        }
113    }
114
115    fn storage_mut(&mut self) -> &mut HeaderSlice<T, [MaybeUninit<U>]> {
116        let p: &mut HeaderSlice<T, [MaybeUninit<U>]> = match self {
117            Data::Fixed(p) => p,
118            Data::Alloc(p) => p,
119        };
120        if size_of::<U>() == 0 {
121            // SAFETY: the tail element is a ZST so its slice is valid for any
122            // length.
123            unsafe { HeaderSlice::from_raw_parts_mut(&raw mut p.head, usize::MAX) }
124        } else {
125            p
126        }
127    }
128}
129
130/// Implements a `Vec`-like type for building structures with a fixed-sized
131/// prefix before a dynamic number of elements.
132///
133/// To avoid allocations in common cases, the header and elements are stored
134/// internally without allocating until the element count would exceed the
135/// statically determined capacity.
136///
137/// Only a small portion of the `Vec` interface is supported. Additional methods
138/// can be added as needed.
139///
140/// The data managed by this type must be `Copy`. This simplifies the resource
141/// management and should be sufficient for most use cases.
142///
143/// # Example
144/// ```
145/// # use headervec::HeaderVec;
146/// #[derive(Copy, Clone)]
147/// struct Header { x: u32 }
148/// let mut v = HeaderVec::<Header, u8, 10>::new(Header{ x: 1234 });
149/// v.push_tail(5);
150/// v.push_tail(6);
151/// assert_eq!(v.head.x, 1234);
152/// assert_eq!(&v.tail, &[5, 6]);
153/// ```
154#[derive(Debug)]
155pub struct HeaderVec<T, U, const N: usize> {
156    data: Data<T, U, N>,
157    len: usize,
158}
159
160impl<T: Copy + Default, U: Copy, const N: usize> Default for HeaderVec<T, U, N> {
161    fn default() -> Self {
162        Self::new(Default::default())
163    }
164}
165
166impl<T: Copy, U: Copy, const N: usize> HeaderVec<T, U, N> {
167    /// Constructs a new `HeaderVec` with a header of `head` and no tail
168    /// elements.
169    pub fn new(head: T) -> Self {
170        Self {
171            data: Data::Fixed(HeaderSlice {
172                head,
173                tail: [const { MaybeUninit::uninit() }; N],
174            }),
175            len: 0,
176        }
177    }
178
179    /// Constructs a new `HeaderVec` with a header of `head` and no tail
180    /// elements, but with a dynamically allocated capacity for `cap` elements.
181    pub fn with_capacity(head: T, cap: usize) -> Self {
182        let mut vec = Self::new(head);
183        if cap > vec.tail_capacity() {
184            vec.realloc(cap);
185        }
186        vec
187    }
188
189    fn realloc(&mut self, cap: usize) {
190        assert!(cap > self.len);
191        assert!(size_of::<U>() > 0);
192
193        let base_layout = Layout::new::<HeaderSlice<T, [MaybeUninit<U>; 0]>>();
194        let layout = Layout::from_size_align(
195            base_layout
196                .size()
197                .checked_add(size_of::<U>().checked_mul(cap).unwrap())
198                .unwrap(),
199            base_layout.align(),
200        )
201        .unwrap();
202
203        // SAFETY: `layout` is correctly constructed and is non-empty.
204        let alloc = unsafe { alloc(layout) };
205        let Some(alloc) = NonNull::new(alloc) else {
206            handle_alloc_error(layout);
207        };
208        // Copy the head.
209        // SAFETY: `alloc` starts with `T`.
210        unsafe {
211            alloc.cast::<T>().write(self.data.storage_mut().head);
212        }
213        // Build the fat pointer to the DST.
214        let alloc =
215            HeaderSlice::<T, [MaybeUninit<U>]>::ptr_from_raw_parts_mut(alloc.as_ptr().cast(), cap);
216        // SAFETY: `head` has been initialized and `tail` is `MaybeUninit`.
217        // `alloc` was allocated with the same layout `Box::new` would use.
218        let mut alloc = unsafe { Box::from_raw(alloc) };
219        // Copy the initialized portion of the tail.
220        alloc.tail[..self.len].copy_from_slice(&self.data.storage_mut().tail[..self.len]);
221        self.data = Data::Alloc(alloc);
222    }
223
224    fn extend_tail(&mut self, n: usize) -> &mut [MaybeUninit<U>] {
225        let cap = self.tail_capacity();
226        if cap - self.len < n {
227            assert!(size_of::<U>() > 0, "ZST tail slice overflow");
228            // Double the current capacity to ensure a geometric progression
229            // (avoiding O(n^2) allocations).
230            let new_cap = cmp::max(
231                cmp::max(8, cap.checked_mul(2).unwrap()),
232                self.len.checked_add(n).unwrap(),
233            );
234            self.realloc(new_cap);
235        }
236        &mut self.spare_tail_capacity_mut()[..n]
237    }
238
239    /// Reserves capacity for at least `n` additional tail elements.
240    pub fn reserve_tail(&mut self, n: usize) {
241        self.extend_tail(n);
242    }
243
244    /// Returns the remaining spare capacity of the tail as a slice of
245    /// `MaybeUninit<U>`.
246    ///
247    /// The returned slice can be used to fill the tail with data before marking
248    /// the data as initialized using [`Self::set_tail_len`].
249    pub fn spare_tail_capacity_mut(&mut self) -> &mut [MaybeUninit<U>] {
250        &mut self.data.storage_mut().tail[self.len..]
251    }
252
253    /// Pushes a tail element, reallocating if necessary.
254    pub fn push_tail(&mut self, val: U) {
255        // For zero-sized types (unlikely to be useful but hard to prohibit),
256        // just increment len.
257        if size_of_val(&val) > 0 {
258            self.extend_tail(1)[0].write(val);
259        }
260        self.len += 1;
261    }
262
263    /// Extends the tail elements from the given slice.
264    pub fn extend_tail_from_slice(&mut self, other: &[U]) {
265        // SAFETY: `[MaybeUninit<U>]` and `[U]` have the same layout.
266        let other = unsafe { core::mem::transmute::<&[U], &[MaybeUninit<U>]>(other) };
267        self.extend_tail(other.len()).copy_from_slice(other);
268        self.len += other.len();
269    }
270
271    /// Retrieves a pointer to the head. The tail is guaranteed to immediately
272    /// after the head (with appropriate padding).
273    pub fn as_ptr(&self) -> *const T {
274        &self.head
275    }
276
277    /// Retrieves a mutable pointer to the head. The tail is guaranteed to
278    /// immediately after the head (with appropriate padding).
279    pub fn as_mut_ptr(&mut self) -> *mut T {
280        &mut self.head
281    }
282
283    /// Returns the number of tail elements that can be stored without
284    /// reallocating.
285    pub fn tail_capacity(&self) -> usize {
286        self.data.storage().tail.len()
287    }
288
289    /// Sets the number of tail elements to 0.
290    pub fn clear_tail(&mut self) {
291        self.len = 0;
292    }
293
294    /// Truncates the tail to `len` elements. Has no effect if there are already
295    /// fewer than `len` tail elements.
296    pub fn truncate_tail(&mut self, len: usize) {
297        if len < self.len {
298            self.len = len;
299        }
300    }
301
302    /// Sets the number of tail elements.
303    ///
304    /// Panics if `len` is greater than the capacity.
305    ///
306    /// # Safety
307    ///
308    /// The caller must ensure that all `len` elements have been initialized.
309    pub unsafe fn set_tail_len(&mut self, len: usize) {
310        assert!(len <= self.tail_capacity());
311        self.len = len;
312    }
313
314    /// Returns the total contiguous byte length of the structure, including
315    /// both the head and tail elements.
316    pub fn total_byte_len(&self) -> usize {
317        size_of_val(&**self)
318    }
319
320    /// Returns the total contiguous byte length of the structure, including
321    /// both the head and tail elements, including the tail's capacity.
322    pub fn total_byte_capacity(&self) -> usize {
323        size_of_val(self.data.storage())
324    }
325}
326
327impl<T, U, const N: usize> Deref for HeaderVec<T, U, N> {
328    type Target = HeaderSlice<T, [U]>;
329    fn deref(&self) -> &Self::Target {
330        // SAFETY: `self.len` tail elements have been initialized.
331        unsafe { self.data.valid(self.len) }
332    }
333}
334
335impl<T, U, const N: usize> DerefMut for HeaderVec<T, U, N> {
336    fn deref_mut(&mut self) -> &mut Self::Target {
337        // SAFETY: `self.len` tail elements have been initialized.
338        unsafe { self.data.valid_mut(self.len) }
339    }
340}
341
342impl<T: Copy, U: Copy, const N: usize> Extend<U> for HeaderVec<T, U, N> {
343    fn extend<I: IntoIterator<Item = U>>(&mut self, iter: I) {
344        for item in iter {
345            self.push_tail(item);
346        }
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::HeaderVec;
    use alloc::vec::Vec;
    use core::fmt::Debug;

    /// Builds a `HeaderVec` from `vals` through each construction path (push,
    /// slice extend, iterator extend, reserve + `set_tail_len`). Checks that
    /// every path yields `head` followed by exactly `vals`.
    fn test<T: Copy + Eq + Debug, U: Copy + Eq + Debug, const N: usize>(
        head: T,
        vals: impl IntoIterator<Item = U>,
    ) {
        let vals = Vec::from_iter(vals);
        // Push
        {
            let mut v: HeaderVec<T, U, N> = HeaderVec::new(head);
            for &i in &vals {
                v.push_tail(i);
            }
            assert_eq!(v.head, head);
            assert_eq!(&v.tail, vals.as_slice());
        }
        // Extend from slice
        {
            let mut v: HeaderVec<T, U, N> = HeaderVec::new(head);
            v.extend_tail_from_slice(&vals);
            assert_eq!(v.head, head);
            assert_eq!(&v.tail, vals.as_slice());
        }
        // Extend
        {
            let mut v: HeaderVec<T, U, N> = HeaderVec::new(head);
            v.extend(vals.iter().copied());
            assert_eq!(v.head, head);
            assert_eq!(&v.tail, vals.as_slice());
        }
        // Reserve + set_len
        {
            let mut v: HeaderVec<T, U, N> = HeaderVec::new(head);
            v.reserve_tail(vals.len());
            // A ZST tail always reports unbounded capacity. The exact non-ZST
            // expectation assumes `vals.len()` is at least the first growth
            // step, `max(8, 2 * N)`.
            let expected_cap = if size_of::<U>() == 0 {
                usize::MAX
            } else {
                vals.len()
            };
            assert_eq!(v.tail_capacity(), expected_cap);
            for (s, d) in vals.iter().copied().zip(v.spare_tail_capacity_mut()) {
                d.write(s);
            }
            // SAFETY: all elements are initialized.
            unsafe { v.set_tail_len(vals.len()) };
            assert_eq!(v.head, head);
            assert_eq!(&v.tail, vals.as_slice());
        }
    }

    #[test]
    fn test_push() {
        test::<u8, u32, 3>(0x10, 0..200);
    }

    #[test]
    fn test_zero_array() {
        test::<u8, u32, 0>(0x10, 0..200);
    }

    #[test]
    fn test_zst_head() {
        test::<(), u32, 3>((), 0..200);
    }

    #[test]
    fn test_zst_tail() {
        test::<u8, (), 0>(0x10, (0..200).map(|_| ()));
    }

    #[test]
    fn test_zst_both() {
        test::<(), (), 0>((), (0..200).map(|_| ()));
    }

    #[test]
    #[should_panic(expected = "ZST tail slice overflow")]
    fn test_zst_overflow() {
        let mut v: HeaderVec<u8, (), 0> = HeaderVec::new(0);
        v.push_tail(());
        v.extend_tail_from_slice(&[(); usize::MAX]);
    }

    #[test]
    fn test_with_capacity_unpadded_tail() {
        // The header alignment (8) exceeds the tail element size (1), so three
        // elements give 11 bytes of payload that must be padded to 16. Under
        // Miri, this catches an allocation that is not padded to alignment.
        let mut v: HeaderVec<u64, u8, 0> = HeaderVec::with_capacity(0x1234, 3);
        assert_eq!(v.tail_capacity(), 3);
        assert_eq!(v.total_byte_capacity(), 16);
        v.extend_tail_from_slice(&[1, 2, 3]);
        assert_eq!(v.head, 0x1234);
        assert_eq!(&v.tail, &[1, 2, 3]);
        assert_eq!(v.total_byte_len(), 16);
    }

    #[test]
    fn test_with_capacity_inline() {
        // A request that fits inline keeps the inline capacity.
        let v: HeaderVec<u8, u32, 4> = HeaderVec::with_capacity(1, 2);
        assert_eq!(v.tail_capacity(), 4);
        assert!(v.tail.is_empty());
    }

    #[test]
    fn test_as_ptr_covers_tail() {
        let mut v: HeaderVec<u32, u32, 1> = HeaderVec::new(9);
        v.extend_tail_from_slice(&[10, 11, 12]);
        let p = v.as_ptr();
        // SAFETY: with `#[repr(C)]` and two `u32` fields, the tail starts one
        // `u32` after the header. `as_ptr` is documented as a pointer to the
        // whole structure, so reading the three tail elements is in bounds.
        let (head, first, last) = unsafe { (p.read(), p.add(1).read(), p.add(3).read()) };
        assert_eq!((head, first, last), (9, 10, 12));
    }

    #[test]
    fn test_truncate_and_clear() {
        let mut v: HeaderVec<u8, u16, 2> = HeaderVec::new(7);
        v.extend([1, 2, 3, 4]);
        v.truncate_tail(10);
        assert_eq!(&v.tail, &[1, 2, 3, 4]);
        v.truncate_tail(2);
        assert_eq!(&v.tail, &[1, 2]);
        v.clear_tail();
        assert!(v.tail.is_empty());
        assert_eq!(v.head, 7);
    }

    #[test]
    fn test_default() {
        let v: HeaderVec<u32, u8, 4> = HeaderVec::default();
        assert_eq!(v.head, 0);
        assert!(v.tail.is_empty());
        assert_eq!(v.tail_capacity(), 4);
    }
}