mesh_protobuf/
buffer.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Types to support writing to a contiguous byte buffer.
5//!
6//! This is different from `bytes::BufMut` in that the buffer is required to be
7//! contiguous, which allows for more efficient use with type erasure.
8
9use alloc::vec::Vec;
10use core::mem::MaybeUninit;
11
/// Models a partially written, contiguous byte buffer.
///
/// Implementors expose the not-yet-written tail of their storage through
/// [`Buffer::unwritten`] and commit bytes written there with
/// [`Buffer::extend_written`].
pub trait Buffer {
    /// Returns the unwritten portion of the buffer. The returned data may or
    /// may not be initialized.
    ///
    /// # Safety
    /// The caller must ensure that no uninitialized bytes are written to the
    /// slice.
    ///
    /// An astute reader might note that the `Vec<u8>` implementation does not
    /// need this function to be unsafe, because the bytes it returns (the
    /// vector's spare capacity) are truly `MaybeUninit`. However, depending on
    /// the backing storage of a [`Buffer`], this is not always the case.
    ///
    /// For example, the implementation for `Cursor<&mut [u8]>` hands out a
    /// view of storage that is always initialized. Without this requirement,
    /// it could be used to _uninitialize_ a portion of that slice:
    ///
    /// ```ignore
    /// // `some_cursor` is a `Cursor<&mut [u8]>`, whose backing storage is
    /// // always initialized.
    /// let foo = unsafe { some_cursor.unwritten() };
    /// foo[0] = MaybeUninit::uninit(); // This is UB!! ⚠️
    /// ```
    ///
    /// Thus the caller must ensure that uninitialized bytes are _never_
    /// written to the returned slice, which is why this function is unsafe.
    unsafe fn unwritten(&mut self) -> &mut [MaybeUninit<u8>];

    /// Extends the initialized region of the buffer by `len` bytes.
    ///
    /// # Safety
    /// The caller must ensure that the next `len` bytes (the first `len`
    /// bytes of the slice returned by [`Buffer::unwritten`]) have been
    /// initialized. `len` must therefore not exceed the length of that slice.
    unsafe fn extend_written(&mut self, len: usize);
}
46
47impl Buffer for Vec<u8> {
48    unsafe fn unwritten(&mut self) -> &mut [MaybeUninit<u8>] {
49        self.spare_capacity_mut()
50    }
51
52    unsafe fn extend_written(&mut self, len: usize) {
53        // SAFETY: The caller guarantees that `len` bytes have been written.
54        unsafe {
55            self.set_len(self.len() + len);
56        }
57    }
58}
59
60impl Buffer for Buf<'_> {
61    unsafe fn unwritten(&mut self) -> &mut [MaybeUninit<u8>] {
62        &mut self.buf[*self.filled..]
63    }
64
65    unsafe fn extend_written(&mut self, len: usize) {
66        *self.filled += len;
67    }
68}
69
#[cfg(feature = "std")]
impl Buffer for std::io::Cursor<&mut [u8]> {
    /// Returns the bytes from the cursor's current position to the end of the
    /// backing slice, so writes land where `std::io::Write` for this cursor
    /// would put them.
    ///
    /// The backing slice is always initialized; the `MaybeUninit` view exists
    /// only to satisfy the trait signature. If the position is at or past the
    /// end of the slice, the returned slice is empty.
    unsafe fn unwritten(&mut self) -> &mut [MaybeUninit<u8>] {
        let len = self.get_ref().len();
        // A cursor's position may exceed the slice length. Clamping first
        // also makes the narrowing cast to `usize` lossless.
        let pos = self.position().min(len as u64) as usize;
        let tail = &mut self.get_mut()[pos..];
        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, and the
        // caller promises not to uninitialize any initialized data.
        unsafe { core::slice::from_raw_parts_mut(tail.as_mut_ptr().cast(), tail.len()) }
    }

    /// Advances the cursor past the `len` bytes just written.
    unsafe fn extend_written(&mut self, len: usize) {
        self.set_position(self.position() + len as u64);
    }
}
82
/// An accessor for writing to a partially-initialized byte buffer.
///
/// The first `*filled` bytes of `buf` have been initialized through this
/// accessor; everything after that is writable space. `filled` is held by
/// reference so whoever created the `Buf` can read the final count back after
/// handing the `Buf` off by value (as [`write_with`] does).
pub struct Buf<'a> {
    buf: &'a mut [MaybeUninit<u8>],
    filled: &'a mut usize,
}

impl Buf<'_> {
    /// Returns the number of bytes that can still be written.
    #[inline(always)]
    pub fn remaining(&self) -> usize {
        self.buf.len() - *self.filled
    }

    /// Returns the number of bytes that have been written.
    #[inline(always)]
    pub fn len(&self) -> usize {
        *self.filled
    }

    /// Returns `true` if nothing has been written yet (equivalent to
    /// `self.len() == 0`).
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        *self.filled == 0
    }

    /// Extends the initialized portion of the buffer with `b`.
    ///
    /// # Panics
    /// Panics if the buffer is already full.
    #[inline(always)]
    pub fn push(&mut self, b: u8) {
        self.buf[*self.filled] = MaybeUninit::new(b);
        *self.filled += 1;
    }

    /// Extends the initialized portion of the buffer with `buf`.
    ///
    /// # Panics
    /// Panics if `buf.len()` exceeds [`Buf::remaining`].
    #[inline(always)]
    pub fn append(&mut self, buf: &[u8]) {
        assert!(buf.len() <= self.remaining());
        // SAFETY: the assertion guarantees that `buf.len()` bytes fit after
        // `filled`. The source is a shared borrow and the destination is
        // reached through `&mut self`, so the regions cannot overlap, and
        // `MaybeUninit<u8>` has the same layout as `u8`.
        unsafe {
            self.buf
                .as_mut_ptr()
                .add(*self.filled)
                .cast::<u8>()
                .copy_from_nonoverlapping(buf.as_ptr(), buf.len());
        }
        *self.filled += buf.len();
    }

    /// Extends the initialized portion of the buffer with `len` copies of
    /// `val`.
    ///
    /// # Panics
    /// Panics if `len` exceeds [`Buf::remaining`].
    #[inline(always)]
    pub fn fill(&mut self, val: u8, len: usize) {
        self.buf[*self.filled..][..len].fill(MaybeUninit::new(val));
        *self.filled += len;
    }

    /// Splits the remaining space into two at `split_at` and calls `f` to
    /// fill out each part, returning whatever `f` returns.
    ///
    /// The left part covers the next `split_at` bytes and the right part
    /// covers the rest of the remaining space.
    ///
    /// If `f` writes anything to the right part, the unwritten remainder of
    /// the left part is zero-initialized so the written region stays
    /// contiguous. If the right part is left untouched, only the bytes written
    /// to the left part are committed.
    ///
    /// # Panics
    /// Panics if `split_at` exceeds [`Buf::remaining`].
    #[track_caller]
    pub fn write_split<R>(&mut self, split_at: usize, f: impl FnOnce(Buf<'_>, Buf<'_>) -> R) -> R {
        let (left, right) = self.buf[*self.filled..].split_at_mut(split_at);
        let mut left_filled = 0;
        let mut right_filled = 0;
        let r = f(
            Buf {
                buf: left,
                filled: &mut left_filled,
            },
            Buf {
                buf: right,
                filled: &mut right_filled,
            },
        );
        assert!(left_filled <= left.len());
        assert!(right_filled <= right.len());
        *self.filled += left_filled;
        if right_filled > 0 {
            // Zero the gap between the left part's data and the start of the
            // right part so every byte up to the end of the right part's data
            // is initialized. `fill` advances `filled` past the gap.
            let to_zero = left.len() - left_filled;
            self.fill(0, to_zero);
            *self.filled += right_filled;
        }
        r
    }
}
166
167/// Calls `f` with a [`Buf`], which provides safe methods for
168/// extending the initialized portion of the buffer.
169pub fn write_with<T, F, R>(buffer: &mut T, f: F) -> R
170where
171    T: Buffer + ?Sized,
172    F: FnOnce(Buf<'_>) -> R,
173{
174    let mut filled = 0;
175    // SAFETY: Buf will only write initialized bytes to the buffer.
176    let buf = unsafe { buffer.unwritten() };
177
178    let r = f(Buf {
179        buf,
180        filled: &mut filled,
181    });
182    // SAFETY: `filled` bytes are known to have been written.
183    unsafe {
184        buffer.extend_written(filled);
185    }
186    r
187}
188
#[cfg(test)]
mod tests {
    use super::write_with;
    use alloc::vec;

    #[test]
    #[should_panic]
    fn test_append_vec_panic() {
        let mut v = vec![1, 2, 3];
        write_with(&mut v, |mut buf| {
            buf.append(&vec![0; buf.remaining() + 1]);
        });
    }

    #[test]
    fn test_append_vec() {
        let mut v = vec![1, 2, 3, 4];
        v.reserve(3);

        write_with(&mut v, |mut buf| {
            buf.append(&[5, 6]);
            buf.push(7);
        });
        assert_eq!(&v, &[1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_fill_vec() {
        let mut v = vec![1u8];
        v.reserve(3);

        write_with(&mut v, |mut buf| {
            buf.fill(9, 3);
            assert_eq!(buf.len(), 3);
        });
        assert_eq!(&v, &[1, 9, 9, 9]);
    }

    #[test]
    fn test_write_split_zero_pads_left() {
        let mut v = vec![0xffu8];
        v.reserve(5);

        write_with(&mut v, |mut buf| {
            buf.write_split(3, |mut left, mut right| {
                left.push(1);
                right.append(&[8, 9]);
            });
        });
        // The unwritten tail of the left half is zeroed so the right half's
        // bytes stay contiguous with the written prefix.
        assert_eq!(&v, &[0xff, 1, 0, 0, 8, 9]);
    }

    #[test]
    fn test_write_split_left_only() {
        let mut v = vec![0xffu8];
        v.reserve(5);

        write_with(&mut v, |mut buf| {
            buf.write_split(3, |mut left, _right| {
                left.push(1);
            });
        });
        // Nothing was written on the right, so no padding is committed.
        assert_eq!(&v, &[0xff, 1]);
    }
}