mesh_protobuf/
inplace.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Provides an `Option`-like type for constructing values in place.
5
6use alloc::boxed::Box;
7use alloc::sync::Arc;
8use core::mem::MaybeUninit;
9
/// A type with methods like `Option` but that operates on a mutable reference
/// to possibly-initialized data.
///
/// This is used to initialize data in place without copying to/from `Option`
/// types.
///
/// The option does not own the storage it refers to; it only tracks whether
/// the referenced `MaybeUninit<T>` currently holds a live `T`.
pub struct InplaceOption<'a, T> {
    /// Caller-provided storage that may or may not hold a `T`.
    val: &'a mut MaybeUninit<T>,
    /// `true` exactly while `val` holds an initialized `T`.
    init: bool,
}

impl<'a, T> InplaceOption<'a, T> {
    /// Creates an option in the uninitialized state.
    ///
    /// Whatever `val` currently contains is treated as uninitialized: it is
    /// never read or dropped through the returned option.
    pub fn uninit(val: &'a mut MaybeUninit<T>) -> Self {
        Self { val, init: false }
    }

    /// Creates an option in the initialized state.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that the value referenced by `val` is
    /// initialized.
    pub unsafe fn new_init_unchecked(val: &'a mut MaybeUninit<T>) -> Self {
        Self { val, init: true }
    }

    /// Sets the value to the initialized state and returns a mutable
    /// reference to it.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that the underlying data has been fully
    /// initialized.
    pub unsafe fn set_init_unchecked(&mut self) -> &mut T {
        self.init = true;
        // SAFETY: the caller guarantees val is initialized.
        unsafe { self.val.assume_init_mut() }
    }

    /// Takes the value, returning `Some(_)` if the value is initialized and
    /// `None` otherwise.
    ///
    /// The option is left in the uninitialized state.
    pub fn take(&mut self) -> Option<T> {
        // Clear the flag as part of the check so the value that is moved out
        // below can never be observed or dropped through `self` again.
        if core::mem::take(&mut self.init) {
            // SAFETY: `init` was true, so val holds an initialized value, and
            // the flag has been cleared so this is the only read of it.
            Some(unsafe { self.val.assume_init_read() })
        } else {
            None
        }
    }

    /// Returns a reference to the data if it's initialized.
    pub fn as_ref(&self) -> Option<&T> {
        if self.init {
            // SAFETY: We have just checked that val is initialized
            Some(unsafe { self.val.assume_init_ref() })
        } else {
            None
        }
    }

    /// Returns a mutable reference to the data if it's initialized.
    pub fn as_mut(&mut self) -> Option<&mut T> {
        if self.init {
            // SAFETY: val is initialized
            Some(unsafe { self.val.assume_init_mut() })
        } else {
            None
        }
    }

    /// Clears the data to the uninitialized state, dropping the value in place
    /// if there is one.
    pub fn clear(&mut self) {
        if self.init {
            // Clear the flag first so that a panicking destructor cannot lead
            // to the value being dropped a second time.
            self.init = false;
            // SAFETY: val is initialized
            unsafe { self.val.assume_init_drop() };
        }
    }

    /// Resets the data to the uninitialized state without dropping any
    /// initialized value.
    ///
    /// Returns whether the value was initialized before the call. When this
    /// returns `true`, the value is still present in the underlying storage
    /// and the caller is responsible for it.
    pub fn forget(&mut self) -> bool {
        core::mem::take(&mut self.init)
    }

    /// Initializes the value to `v`, dropping any existing value first.
    pub fn set(&mut self, v: T) -> &mut T {
        self.clear();
        self.init = true;
        self.val.write(v)
    }

    /// Gets a mutable reference to the value, setting it to `v` first if it's
    /// not initialized.
    pub fn get_or_insert(&mut self, v: T) -> &mut T {
        self.get_or_insert_with(|| v)
    }

    /// Gets a mutable reference to the value, setting it to `f()` first if it's
    /// not initialized.
    ///
    /// `f` is only called when the option is uninitialized. If `f` panics, the
    /// option is left uninitialized.
    pub fn get_or_insert_with(&mut self, f: impl FnOnce() -> T) -> &mut T {
        if !self.init {
            // Produce the value *before* marking the option as initialized. If
            // `f` unwinds, `init` must still be false; otherwise the storage
            // would later be treated (and dropped) as a live `T` even though
            // nothing was ever written to it.
            let value = f();
            self.val.write(value);
            self.init = true;
        }
        // SAFETY: val was either already initialized or was written just above.
        unsafe { self.val.assume_init_mut() }
    }

    /// Returns whether the value is initialized.
    pub fn is_some(&self) -> bool {
        self.init
    }

    /// Returns whether the value is uninitialized.
    pub fn is_none(&self) -> bool {
        !self.init
    }

    /// Returns a const pointer to the underlying value (initialized or not).
    pub fn as_ptr(&self) -> *const T {
        self.val.as_ptr()
    }

    /// Returns a mut pointer to the underlying value (initialized or not).
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.val.as_mut_ptr()
    }
}
143
144impl<T> InplaceOption<'_, Box<T>> {
145    /// Updates a boxed value in place.
146    ///
147    /// N.B. This will allocate space for a value if one is not already present,
148    ///      which is wasteful if `f` does not actually initialize the value.
149    pub fn update_box<F, R>(&mut self, f: F) -> R
150    where
151        F: FnOnce(&mut InplaceOption<'_, T>) -> R,
152    {
153        let mut boxed;
154        let mut inplace;
155
156        if let Some(b) = self.take() {
157            // SAFETY: MaybeUninit<T> has the same layout as T.
158            boxed = unsafe { Box::from_raw(Box::into_raw(b).cast::<MaybeUninit<T>>()) };
159            // SAFETY: the value is known to be initialized.
160            inplace = unsafe { InplaceOption::new_init_unchecked(&mut *boxed) };
161        } else {
162            boxed = Box::new(MaybeUninit::uninit());
163            inplace = InplaceOption::uninit(&mut *boxed);
164        }
165
166        let r = f(&mut inplace);
167        if inplace.forget() {
168            drop(inplace);
169            // SAFETY: T has the same layout as MaybeUninit<T>, and the value is
170            // known to be initialized.
171            let b = unsafe { Box::from_raw(Box::into_raw(boxed).cast::<T>()) };
172            self.set(b);
173        }
174        r
175    }
176}
177
178impl<T: Clone> InplaceOption<'_, Arc<T>> {
179    /// Updates a reference counted value in place.
180    ///
181    /// N.B. This will allocate space for a value if one is not already present,
182    ///      which is wasteful if `f` does not actually initialize the value.
183    pub fn update_arc<F, R>(&mut self, f: F) -> R
184    where
185        F: FnOnce(&mut InplaceOption<'_, T>) -> R,
186    {
187        let mut arced;
188        let mut inplace;
189
190        if let Some(mut a) = self.take() {
191            // Ensure there is only a single reference.
192            Arc::make_mut(&mut a);
193            // SAFETY: MaybeUninit<T> has the same layout as T.
194            arced = unsafe { Arc::from_raw(Arc::into_raw(a).cast::<MaybeUninit<T>>()) };
195            // SAFETY: the value is known to be initialized.
196            unsafe {
197                inplace = InplaceOption::new_init_unchecked(Arc::get_mut(&mut arced).unwrap())
198            };
199        } else {
200            arced = Arc::new(MaybeUninit::uninit());
201            inplace = InplaceOption::uninit(Arc::get_mut(&mut arced).unwrap());
202        }
203
204        let r = f(&mut inplace);
205        if inplace.forget() {
206            drop(inplace);
207            // SAFETY: T has the same layout as MaybeUninit<T>, and the value is
208            // known to be initialized.
209            let a = unsafe { Arc::from_raw(Arc::into_raw(arced).cast::<T>()) };
210            self.set(a);
211        }
212        r
213    }
214}
215
216impl<T> Drop for InplaceOption<'_, T> {
217    fn drop(&mut self) {
218        self.clear();
219    }
220}
221
/// Constructs a possibly-initialized [`crate::inplace::InplaceOption`] on the stack
/// from an `Option<T>`.
///
/// `$v` must name a binding that holds an `Option<T>`. The macro consumes that
/// binding and shadows the identifier twice: first with hidden `MaybeUninit<T>`
/// storage, then with an `InplaceOption` borrowing that storage. After the
/// macro, `$v` refers to the `InplaceOption`.
///
/// The macro expands to `let` statements, so it must be used in statement
/// position.
#[macro_export]
macro_rules! inplace {
    ($v:ident) => {
        let source = $v;
        // The storage is declared before the `match` so that it outlives the
        // `InplaceOption` that borrows it.
        let mut $v;
        // `::core` is spelled as an absolute path so the expansion does not
        // depend on what `core` happens to name at the call site.
        let mut $v = match source {
            Some(value) => {
                $v = ::core::mem::MaybeUninit::new(value);
                // SAFETY: We just initialized the value.
                unsafe { $crate::inplace::InplaceOption::new_init_unchecked(&mut $v) }
            }
            None => {
                $v = ::core::mem::MaybeUninit::uninit();
                $crate::inplace::InplaceOption::uninit(&mut $v)
            }
        };
    };
}
242
/// Constructs an initialized [`crate::inplace::InplaceOption`] on the stack from a `T`.
///
/// `$v` must name a binding that holds the value. The macro moves the value
/// into hidden `MaybeUninit<T>` storage and rebinds `$v` to an `InplaceOption`
/// borrowing that storage.
///
/// The macro expands to `let` statements, so it must be used in statement
/// position.
#[macro_export]
macro_rules! inplace_some {
    ($v:ident) => {
        // `::core` is spelled as an absolute path so the expansion does not
        // depend on what `core` happens to name at the call site.
        let mut $v = ::core::mem::MaybeUninit::new($v);
        // SAFETY: We just initialized the value.
        let mut $v = unsafe { $crate::inplace::InplaceOption::new_init_unchecked(&mut $v) };
    };
}
252
/// Constructs an uninitialized [`crate::inplace::InplaceOption`] on the stack.
///
/// Two forms are accepted:
///
/// * `inplace_none!(v)` leaves the value type to be inferred from later use.
/// * `inplace_none!(v: T)` fixes the value type to `T`.
///
/// In both forms `$v` is first bound to hidden `MaybeUninit` storage and then
/// rebound to an `InplaceOption` borrowing that storage.
///
/// The macro expands to `let` statements, so it must be used in statement
/// position.
#[macro_export]
macro_rules! inplace_none {
    // `::core` is spelled as an absolute path in both arms so the expansion
    // does not depend on what `core` happens to name at the call site.
    ($v:ident) => {
        let mut $v = ::core::mem::MaybeUninit::uninit();
        let mut $v = $crate::inplace::InplaceOption::uninit(&mut $v);
    };
    ($v:ident : $t:ty) => {
        let mut $v = ::core::mem::MaybeUninit::<$t>::uninit();
        let mut $v = $crate::inplace::InplaceOption::uninit(&mut $v);
    };
}
265
#[cfg(test)]
mod tests {
    use alloc::boxed::Box;
    use alloc::string::String;
    use alloc::string::ToString;
    use alloc::sync::Arc;

    // `inplace_some!` wraps an existing value, which can then be taken back out.
    #[test]
    fn test_inplace_some() {
        let slot = "test".to_string();
        inplace_some!(slot);
        let taken = slot.take().unwrap();
        assert_eq!(taken, "test");
    }

    // `inplace_none!` starts empty and can be filled afterwards.
    #[test]
    fn test_inplace_none() {
        inplace_none!(slot: String);
        slot.set("test".to_string());
        let taken = slot.take().unwrap();
        assert_eq!(taken, "test");
    }

    // `inplace!` converts `Some(_)` into an initialized option.
    #[test]
    fn test_inplace() {
        let slot = Some("test".to_string());
        inplace!(slot);
        let taken = slot.take().unwrap();
        assert_eq!(taken, "test");
    }

    // `set` on an initialized option replaces the previous value.
    #[test]
    fn test_inplace_replace() {
        let slot = "old".to_string();
        inplace_some!(slot);
        slot.set("new".to_string());
        let taken = slot.take().unwrap();
        assert_eq!(taken, "new");
    }

    // Nested in-place updates reach through an `Arc` and then a `Box`.
    #[test]
    fn test_updates() {
        let slot = Arc::new(Box::new(1234));
        inplace_some!(slot);
        slot.update_arc(|arc_slot| {
            arc_slot.update_box(|box_slot| {
                box_slot.set(5678);
            });
        });
        let updated = slot.take().unwrap();
        assert_eq!(**updated, 5678);
    }
}
313}