mesh_protobuf/
buffer.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Types to support writing to a contiguous byte buffer.
5//!
6//! This is different from `bytes::BufMut` in that the buffer is required to be
7//! contiguous, which allows for more efficient use with type erasure.
8
9use alloc::vec::Vec;
10use core::mem::MaybeUninit;
11
/// Models a partially written, contiguous byte buffer.
///
/// A buffer has an initialized ("written") prefix followed by an unwritten
/// region. Writers obtain the unwritten region via [`Buffer::unwritten`], fill
/// some prefix of it, and then commit those bytes with
/// [`Buffer::extend_written`].
pub trait Buffer {
    /// Returns the unwritten portion of the buffer. The returned data may or
    /// may not be initialized.
    ///
    /// # Safety
    /// The caller must ensure that no uninitialized bytes are written to the
    /// slice.
    ///
    /// An astute reader might note that the `Vec<u8>` implementation does not
    /// need this function to be unsafe, since the bytes it returns (the
    /// vector's spare capacity) are truly `MaybeUninit`. However, depending on
    /// the backing storage of a [`Buffer`], this is not always the case.
    ///
    /// For example, the `Buffer` implementation for `Cursor<&mut [u8]>` is
    /// backed by a `&mut [u8]`, which is always initialized. Without this
    /// requirement it could be used to _uninitialize_ a portion of that slice:
    ///
    /// ```ignore
    /// // `cursor` is a `Cursor<&mut [u8]>`, whose backing storage is always
    /// // initialized.
    /// let unwritten = unsafe { cursor.unwritten() };
    /// unwritten[0] = MaybeUninit::uninit(); // This is UB!! ⚠️
    /// ```
    ///
    /// Thus the caller must ensure that uninitialized bytes are _never_
    /// written to the returned slice, which is why this function is unsafe.
    unsafe fn unwritten(&mut self) -> &mut [MaybeUninit<u8>];

    /// Extends the initialized region of the buffer by `len` bytes.
    ///
    /// # Safety
    /// The caller must ensure that the next `len` bytes (that is, the first
    /// `len` bytes of the slice returned by [`Buffer::unwritten`]) have been
    /// initialized. In particular, `len` must not exceed that slice's length.
    unsafe fn extend_written(&mut self, len: usize);
}
46
47impl Buffer for Vec<u8> {
48    unsafe fn unwritten(&mut self) -> &mut [MaybeUninit<u8>] {
49        self.spare_capacity_mut()
50    }
51
52    unsafe fn extend_written(&mut self, len: usize) {
53        // SAFETY: The caller guarantees that `len` bytes have been written.
54        unsafe {
55            self.set_len(self.len() + len);
56        }
57    }
58}
59
60impl Buffer for Buf<'_> {
61    unsafe fn unwritten(&mut self) -> &mut [MaybeUninit<u8>] {
62        &mut self.buf[*self.filled..]
63    }
64
65    unsafe fn extend_written(&mut self, len: usize) {
66        *self.filled += len;
67    }
68}
69
#[cfg(feature = "std")]
impl Buffer for std::io::Cursor<&mut [u8]> {
    /// Returns the part of the backing slice at or after the cursor position.
    ///
    /// A cursor may be positioned past the end of its slice. In that case the
    /// position is clamped, so the unwritten region is empty.
    unsafe fn unwritten(&mut self) -> &mut [MaybeUninit<u8>] {
        let end = self.get_ref().len();
        let start = self.position().min(end as u64) as usize;
        let tail: &mut [u8] = &mut self.get_mut()[start..];
        let (ptr, count) = (tail.as_mut_ptr(), tail.len());
        // SAFETY: `u8` and `MaybeUninit<u8>` have the same layout, and the
        // caller promises not to uninitialize any initialized data.
        unsafe { core::slice::from_raw_parts_mut(ptr.cast::<MaybeUninit<u8>>(), count) }
    }

    /// Advances the cursor over the bytes that were just written.
    unsafe fn extend_written(&mut self, len: usize) {
        let advanced = self.position() + len as u64;
        self.set_position(advanced);
    }
}
84
/// A safe writer over a partially-initialized byte buffer.
///
/// `buf` is the writable region, and its first `*filled` bytes have been
/// initialized. The safe methods below only ever store initialized bytes,
/// and they only advance `filled` past bytes they have just written.
pub struct Buf<'a> {
    buf: &'a mut [MaybeUninit<u8>],
    filled: &'a mut usize,
}

impl Buf<'_> {
    /// Returns how many more bytes can be written before the buffer is full.
    #[inline(always)]
    pub fn remaining(&self) -> usize {
        let capacity = self.buf.len();
        capacity - *self.filled
    }

    /// Returns how many bytes have been written so far.
    #[inline(always)]
    pub fn len(&self) -> usize {
        *self.filled
    }

    /// Writes the single byte `b` at the current position.
    ///
    /// # Panics
    /// Panics if the buffer is already full.
    #[inline(always)]
    pub fn push(&mut self, b: u8) {
        let pos = *self.filled;
        self.buf[pos] = MaybeUninit::new(b);
        *self.filled = pos + 1;
    }

    /// Writes all of `buf` at the current position.
    ///
    /// # Panics
    /// Panics if `buf` is longer than [`Buf::remaining`].
    #[inline(always)]
    pub fn append(&mut self, buf: &[u8]) {
        assert!(buf.len() <= self.remaining());
        let start = *self.filled;
        let count = buf.len();
        // SAFETY: the assertion above guarantees that `start + count` does not
        // exceed `self.buf.len()`, so the destination range is in bounds.
        // `u8` and `MaybeUninit<u8>` share a layout. `buf` cannot overlap
        // `self.buf`, because the latter is uniquely borrowed.
        unsafe {
            let dst = self.buf.as_mut_ptr().add(start).cast::<u8>();
            dst.copy_from_nonoverlapping(buf.as_ptr(), count);
        }
        *self.filled = start + count;
    }

    /// Writes `len` copies of `val` at the current position.
    ///
    /// # Panics
    /// Panics if `len` is greater than [`Buf::remaining`].
    #[inline(always)]
    pub fn fill(&mut self, val: u8, len: usize) {
        let start = *self.filled;
        let dst = &mut self.buf[start..][..len];
        dst.fill(MaybeUninit::new(val));
        *self.filled = start + len;
    }

    /// Splits the unwritten portion of this buffer into two independent
    /// buffers and passes both to `f`.
    ///
    /// The left buffer covers the next `split_at` unwritten bytes, and the
    /// right buffer covers everything after them. After `f` returns, the bytes
    /// written to the left buffer are committed. If any bytes were written to
    /// the right buffer, the unwritten remainder of the left buffer is
    /// zero-initialized so that the committed region stays contiguous. The
    /// right buffer's bytes are then committed after it.
    ///
    /// Returns whatever `f` returns.
    ///
    /// # Panics
    /// Panics if `split_at` is greater than [`Buf::remaining`].
    #[track_caller]
    pub fn write_split<R>(&mut self, split_at: usize, f: impl FnOnce(Buf<'_>, Buf<'_>) -> R) -> R {
        let mut left_filled = 0;
        let mut right_filled = 0;
        let (left, right) = self.buf[*self.filled..].split_at_mut(split_at);
        let result = f(
            Buf {
                buf: left,
                filled: &mut left_filled,
            },
            Buf {
                buf: right,
                filled: &mut right_filled,
            },
        );
        assert!(left_filled <= left.len());
        assert!(right_filled <= right.len());
        let gap = left.len() - left_filled;
        *self.filled += left_filled;
        if right_filled != 0 {
            // Zero the unused tail of the left half, so no uninitialized hole
            // sits in front of the right half's data.
            self.fill(0, gap);
            *self.filled += right_filled;
        }
        result
    }
}
168
169/// Calls `f` with a [`Buf`], which provides safe methods for
170/// extending the initialized portion of the buffer.
171pub fn write_with<T, F, R>(buffer: &mut T, f: F) -> R
172where
173    T: Buffer + ?Sized,
174    F: FnOnce(Buf<'_>) -> R,
175{
176    let mut filled = 0;
177    // SAFETY: Buf will only write initialized bytes to the buffer.
178    let buf = unsafe { buffer.unwritten() };
179
180    let r = f(Buf {
181        buf,
182        filled: &mut filled,
183    });
184    // SAFETY: `filled` bytes are known to have been written.
185    unsafe {
186        buffer.extend_written(filled);
187    }
188    r
189}
190
#[cfg(test)]
mod tests {
    use super::write_with;
    use alloc::vec;

    #[test]
    #[should_panic]
    fn test_append_vec_panic() {
        let mut v = vec![1, 2, 3];
        write_with(&mut v, |mut buf| {
            buf.append(&vec![0; buf.remaining() + 1]);
        });
    }

    #[test]
    fn test_append_vec() {
        let mut v = vec![1, 2, 3, 4];
        v.reserve(3);

        write_with(&mut v, |mut buf| {
            buf.append(&[5, 6]);
            buf.push(7);
        });
        assert_eq!(&v, &[1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_fill_vec() {
        let mut v = vec![1u8];
        v.reserve(3);

        write_with(&mut v, |mut buf| {
            buf.fill(0xff, 3);
        });
        assert_eq!(&v, &[1, 0xff, 0xff, 0xff]);
    }

    #[test]
    fn test_write_split_zero_fills_gap() {
        let mut v = vec![9u8];
        v.reserve(8);

        let total = write_with(&mut v, |mut buf| {
            buf.write_split(4, |mut left, mut right| {
                left.push(1);
                right.append(&[2, 3]);
                left.len() + right.len()
            })
        });
        assert_eq!(total, 3);
        // The unwritten tail of the left half is zeroed because the right
        // half was written.
        assert_eq!(&v, &[9, 1, 0, 0, 0, 2, 3]);
    }

    #[test]
    fn test_write_split_left_only() {
        let mut v = vec![9u8];
        v.reserve(8);

        write_with(&mut v, |mut buf| {
            buf.write_split(4, |mut left, _right| {
                left.push(1);
            });
            buf.push(2);
        });
        // Nothing was written to the right half, so no zero padding is added.
        assert_eq!(&v, &[9, 1, 2]);
    }

    #[test]
    #[should_panic]
    fn test_write_split_past_end_panics() {
        let mut v = vec![1, 2, 3];
        write_with(&mut v, |mut buf| {
            let past_end = buf.remaining() + 1;
            buf.write_split(past_end, |_, _| ());
        });
    }

    #[test]
    fn test_nested_write_with_on_buf() {
        let mut v = vec![9u8];
        v.reserve(4);

        // `Buf` itself implements `Buffer`, so writes can be nested.
        write_with(&mut v, |mut outer| {
            outer.push(1);
            write_with(&mut outer, |mut inner| {
                inner.append(&[2, 3]);
            });
            outer.push(4);
        });
        assert_eq!(&v, &[9, 1, 2, 3, 4]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_cursor_multiple_writes() {
        let mut backing = [0u8; 8];
        let mut cursor = std::io::Cursor::new(&mut backing[..]);

        // First write: fills positions 0..3.
        write_with(&mut cursor, |mut buf| {
            buf.append(&[1, 2, 3]);
        });

        // Second write: must continue at position 3, not overwrite from 0.
        write_with(&mut cursor, |mut buf| {
            buf.append(&[4, 5]);
        });

        assert_eq!(cursor.position(), 5);
        assert_eq!(&backing[..5], &[1, 2, 3, 4, 5]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_cursor_position_beyond_slice() {
        let mut backing = [0u8; 4];
        let mut cursor = std::io::Cursor::new(&mut backing[..]);
        cursor.set_position(100); // way past the end

        // Should get an empty unwritten region, not panic.
        write_with(&mut cursor, |buf| {
            assert_eq!(buf.remaining(), 0);
        });
    }
}