mesh_protobuf/table/
decode.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Table-based decoding.
5
6#![expect(clippy::missing_safety_doc)]
7
8use super::StructMetadata;
9use super::TableEncoder;
10use crate::Error;
11use crate::FieldDecode;
12use crate::MessageDecode;
13use crate::inplace::InplaceOption;
14use crate::protobuf::FieldReader;
15use crate::protobuf::MessageReader;
16use alloc::slice;
17use alloc::vec;
18use core::marker::PhantomData;
19use core::mem::MaybeUninit;
20
21/// Calls `f` on `item`, splitting the pointer and initialized flag out.
22///
23/// # Safety
24///
25/// The caller must ensure that on the return, the bool specifies the
26/// initialized state of the item.
27unsafe fn run_inplace<T, R>(
28    item: &mut InplaceOption<'_, T>,
29    f: impl FnOnce(*mut u8, &mut bool) -> R,
30) -> R {
31    let mut initialized = item.forget();
32    let base = item.as_mut_ptr().cast::<u8>();
33    let r = f(base, &mut initialized);
34    if initialized {
35        // SAFETY: the caller ensures that `item` is initialized.
36        unsafe { item.set_init_unchecked() };
37    }
38    r
39}
40
41impl<'de, T, R> MessageDecode<'de, T, R> for TableEncoder
42where
43    T: StructDecodeMetadata<'de, R>,
44{
45    fn read_message(
46        item: &mut InplaceOption<'_, T>,
47        reader: MessageReader<'de, '_, R>,
48    ) -> crate::Result<()> {
49        // SAFETY: T guarantees that the metadata is valid.
50        unsafe {
51            run_inplace(item, |base, initialized| {
52                read_fields(
53                    T::NUMBERS,
54                    T::DECODERS,
55                    T::OFFSETS,
56                    base,
57                    initialized,
58                    reader,
59                )
60            })
61        }
62    }
63}
64
65impl<'de, T, R> FieldDecode<'de, T, R> for TableEncoder
66where
67    T: StructDecodeMetadata<'de, R>,
68{
69    // Override the default implementation to use the table decoder directly.
70    // This saves code size by avoiding extra stub functions and vtables.
71    const ENTRY: DecoderEntry<'de, T, R> = DecoderEntry::table();
72
73    fn read_field(
74        item: &mut InplaceOption<'_, T>,
75        reader: FieldReader<'de, '_, R>,
76    ) -> crate::Result<()> {
77        // SAFETY: T guarantees that the metadata is valid.
78        unsafe {
79            run_inplace(item, |base, initialized| {
80                read_message(
81                    T::NUMBERS,
82                    T::DECODERS,
83                    T::OFFSETS,
84                    base,
85                    initialized,
86                    reader,
87                )
88            })
89        }
90    }
91
92    fn default_field(item: &mut InplaceOption<'_, T>) -> crate::Result<()> {
93        // SAFETY: T guarantees that the metadata is valid.
94        unsafe {
95            run_inplace(item, |base, initialized| {
96                default_fields(T::DECODERS, T::OFFSETS, base, initialized)
97            })
98        }
99    }
100}
101
102/// Read a field as a message from the provided field metadata.
103///
104/// # Safety
105///
106/// The caller must ensure that `base` points to a location that can be written
107/// to, that `struct_initialized` is set correctly, and that the metadata is
108/// correct and complete for the type of the object pointed to by `base`.
109#[doc(hidden)] // only used publicly in mesh_derive
110pub unsafe fn read_message<R>(
111    numbers: &[u32],
112    decoders: &[ErasedDecoderEntry],
113    offsets: &[usize],
114    base: *mut u8,
115    struct_initialized: &mut bool,
116    reader: FieldReader<'_, '_, R>,
117) -> Result<(), Error> {
118    assert_eq!(numbers.len(), decoders.len());
119    assert_eq!(numbers.len(), offsets.len());
120    // SAFETY: guaranteed by caller and by the assertions above.
121    unsafe {
122        // Convert the slices to pointers and a single length to shrink
123        // code size.
124        read_message_by_ptr(
125            numbers.len(),
126            numbers.as_ptr(),
127            decoders.as_ptr(),
128            offsets.as_ptr(),
129            base,
130            struct_initialized,
131            reader,
132        )
133    }
134}
135
136// Don't inline this since it is used by every table decoder instantiation.
137#[inline(never)]
138unsafe fn read_message_by_ptr<R>(
139    count: usize,
140    numbers: *const u32,
141    decoders: *const ErasedDecoderEntry,
142    offsets: *const usize,
143    base: *mut u8,
144    struct_initialized: &mut bool,
145    reader: FieldReader<'_, '_, R>,
146) -> Result<(), Error> {
147    // SAFETY: guaranteed by caller.
148    unsafe {
149        read_fields_inline(
150            slice::from_raw_parts(numbers, count),
151            slice::from_raw_parts(decoders, count),
152            slice::from_raw_parts(offsets, count),
153            base,
154            struct_initialized,
155            reader.message()?,
156        )
157    }
158}
159
160/// Read a message from the provided field metadata.
161///
162/// # Safety
163///
164/// The caller must ensure that `base` points to a location that can be written
165/// to, that `struct_initialized` is set correctly, and that the metadata is
166/// correct and complete for the type of the object pointed to by `base`.
167unsafe fn read_fields<R>(
168    numbers: &[u32],
169    decoders: &[ErasedDecoderEntry],
170    offsets: &[usize],
171    base: *mut u8,
172    struct_initialized: &mut bool,
173    reader: MessageReader<'_, '_, R>,
174) -> Result<(), Error> {
175    assert_eq!(numbers.len(), decoders.len());
176    assert_eq!(numbers.len(), offsets.len());
177    // SAFETY: guaranteed by caller and by the assertions above.
178    unsafe {
179        // Convert the slices to pointers and a single length to shrink
180        // code size.
181        read_fields_by_ptr(
182            numbers.len(),
183            numbers.as_ptr(),
184            decoders.as_ptr(),
185            offsets.as_ptr(),
186            base,
187            struct_initialized,
188            reader,
189        )
190    }
191}
192
193// Don't inline this since it is used by every table decoder instantiation.
194#[inline(never)]
195unsafe fn read_fields_by_ptr<R>(
196    count: usize,
197    numbers: *const u32,
198    decoders: *const ErasedDecoderEntry,
199    offsets: *const usize,
200    base: *mut u8,
201    struct_initialized: &mut bool,
202    reader: MessageReader<'_, '_, R>,
203) -> Result<(), Error> {
204    // SAFETY: guaranteed by caller.
205    unsafe {
206        read_fields_inline(
207            slice::from_raw_parts(numbers, count),
208            slice::from_raw_parts(decoders, count),
209            slice::from_raw_parts(offsets, count),
210            base,
211            struct_initialized,
212            reader,
213        )
214    }
215}
216
217/// Reads fields from the provided field metadata.
218///
219/// # Safety
220///
221/// The caller must ensure that `base` points to a location that can be written
222/// to, that `initialized` is set correctly, and that the metadata is correct
223/// and complete for the type of the object pointed to by `base`.
224unsafe fn read_fields_inline<R>(
225    numbers: &[u32],
226    decoders: &[ErasedDecoderEntry],
227    offsets: &[usize],
228    base: *mut u8,
229    struct_initialized: &mut bool,
230    reader: MessageReader<'_, '_, R>,
231) -> Result<(), Error> {
232    const STACK_LIMIT: usize = 32;
233    let mut field_init_static;
234    let mut field_init_dynamic;
235    let field_inits = if numbers.len() <= STACK_LIMIT {
236        field_init_static = [false; STACK_LIMIT];
237        field_init_static[..numbers.len()].fill(*struct_initialized);
238        &mut field_init_static[..numbers.len()]
239    } else {
240        field_init_dynamic = vec![*struct_initialized; numbers.len()];
241        &mut field_init_dynamic[..]
242    };
243
244    // SAFETY: guaranteed by caller.
245    let r = unsafe { read_fields_inner(numbers, decoders, offsets, base, field_inits, reader) };
246    *struct_initialized = true;
247    if r.is_err() && !field_inits.iter().all(|&b| b) {
248        // Drop any initialized fields.
249        for ((field_init, &offset), decoder) in field_inits.iter_mut().zip(offsets).zip(decoders) {
250            if *field_init {
251                // SAFETY: guaranteed by the caller.
252                unsafe {
253                    decoder.drop_field(base.add(offset));
254                }
255            }
256        }
257        *struct_initialized = false;
258    }
259    r
260}
261
262/// Reads fields from the provided field metadata, but does not drop any fields
263/// of a partially initialized message on failure.
264unsafe fn read_fields_inner<R>(
265    numbers: &[u32],
266    decoders: &[ErasedDecoderEntry],
267    offsets: &[usize],
268    base: *mut u8,
269    field_init: &mut [bool],
270    reader: MessageReader<'_, '_, R>,
271) -> Result<(), Error> {
272    let decoders = &decoders[..numbers.len()];
273    let offsets = &offsets[..numbers.len()];
274    let field_init = &mut field_init[..numbers.len()];
275    for field in reader {
276        let (number, reader) = field?;
277        if let Some(index) = numbers.iter().position(|&n| n == number) {
278            let decoder = &decoders[index];
279            // SAFETY: the decoder is valid according to the caller.
280            unsafe {
281                decoder.read_field(base.add(offsets[index]), &mut field_init[index], reader)?;
282            }
283        }
284    }
285    for ((field_init, &offset), decoder) in field_init.iter_mut().zip(offsets).zip(decoders) {
286        if !*field_init {
287            // SAFETY: the decoder is valid according to the caller.
288            unsafe {
289                decoder.default_field(base.add(offset), field_init)?;
290            }
291            assert!(*field_init);
292        }
293    }
294    Ok(())
295}
296
297/// Sets fields to their default values from the provided field metadata.
298///
299/// # Safety
300///
301/// The caller must ensure that `base` points to a location that can be written
302/// to, that `struct_initialized` is set correctly, and that the metadata is
303/// correct and complete for the type of the object pointed to by `base`.
304unsafe fn default_fields(
305    decoders: &[ErasedDecoderEntry],
306    offsets: &[usize],
307    base: *mut u8,
308    struct_initialized: &mut bool,
309) -> Result<(), Error> {
310    assert_eq!(decoders.len(), offsets.len());
311    // SAFETY: guaranteed by caller and by the assertion above.
312    unsafe {
313        default_fields_by_ptr(
314            decoders.len(),
315            decoders.as_ptr(),
316            offsets.as_ptr(),
317            base,
318            struct_initialized,
319        )
320    }
321}
322
323#[inline(never)]
324unsafe fn default_fields_by_ptr(
325    count: usize,
326    decoders: *const ErasedDecoderEntry,
327    offsets: *const usize,
328    base: *mut u8,
329    struct_initialized: &mut bool,
330) -> Result<(), Error> {
331    // SAFETY: guaranteed by caller.
332    unsafe {
333        default_fields_inline(
334            slice::from_raw_parts(decoders, count),
335            slice::from_raw_parts(offsets, count),
336            base,
337            struct_initialized,
338        )
339    }
340}
341
/// Applies each field decoder's `default_field` to its field.
///
/// Every field's initialized flag is seeded from `*struct_initialized`. On
/// success, `*struct_initialized` is set to `true`.
///
/// If a field's decoder fails, there are two cases:
///
/// * The struct started initialized and the failing field is still
///   initialized. The struct is left as is and the error is returned.
/// * Otherwise, every field known to be initialized is dropped,
///   `*struct_initialized` is set to `false`, and the error is returned.
///
/// # Panics
///
/// Panics if a decoder reports success but leaves its field uninitialized.
///
/// # Safety
///
/// The caller must ensure that `base` points to a location that can be written
/// to, that `struct_initialized` is set correctly, and that the metadata is
/// correct and complete for the type of the object pointed to by `base`.
unsafe fn default_fields_inline(
    decoders: &[ErasedDecoderEntry],
    offsets: &[usize],
    base: *mut u8,
    struct_initialized: &mut bool,
) -> Result<(), Error> {
    for (i, (&offset, decoder)) in offsets.iter().zip(decoders).enumerate() {
        // Each field starts out in the same state as the struct as a whole.
        let mut field_initialized = *struct_initialized;
        // SAFETY: the decoder is valid according to the caller.
        let r = unsafe { decoder.default_field(base.add(offset), &mut field_initialized) };
        if let Err(err) = r {
            if !field_initialized || !*struct_initialized {
                // Drop initialized fields.
                let initialized_until = i;
                for (i, (&offset, decoder)) in offsets.iter().zip(decoders).enumerate() {
                    // Earlier fields were initialized (the assert below holds
                    // for every field that succeeded). The failing field is
                    // initialized only if its decoder left it so. Later fields
                    // were not visited yet, so they are initialized only if
                    // the struct was.
                    if i < initialized_until
                        || (i == initialized_until && field_initialized)
                        || (i > initialized_until && *struct_initialized)
                    {
                        // SAFETY: the decoder is valid according to the caller, and the field is initialized.
                        unsafe {
                            decoder.drop_field(base.add(offset));
                        }
                    }
                }
                *struct_initialized = false;
            }
            return Err(err);
        }
        // A successful default must leave the field initialized.
        assert!(field_initialized);
    }
    *struct_initialized = true;
    Ok(())
}
376
/// The struct metadata for decoding a struct.
///
/// `DECODERS` is used as a table parallel to the `NUMBERS` and `OFFSETS` tables
/// of [`StructMetadata`]. The table decoder asserts that all three have the same
/// length, and entry `i` decodes the field numbered `NUMBERS[i]` at byte offset
/// `OFFSETS[i]`.
///
/// # Safety
///
/// The implementor must ensure that the `DECODERS` are correct and complete for
/// `Self`, such that if every field is decoded, then the struct value is valid.
pub unsafe trait StructDecodeMetadata<'de, R>: StructMetadata {
    /// The list of decoder vtables, one per field.
    const DECODERS: &'static [ErasedDecoderEntry];
}
387
/// An entry in the decoder table.
///
/// This contains the metadata necessary to apply a decoder to a field.
///
/// This cannot be instantiated directly; use [`FieldDecode::ENTRY`] to get an
/// instance for a particular decoder.
pub struct DecoderEntry<'a, T, R>(
    // The type-erased entry; see `erase`.
    ErasedDecoderEntry,
    // Ties the entry to `T`, `R`, and `'a` without storing any of them.
    PhantomData<fn(&mut T, &mut R, &'a mut ())>,
);
398
impl<'a, T, R> DecoderEntry<'a, T, R> {
    /// Wraps an already-erased entry as a typed entry without checking it.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the erased entry is a valid entry for `T`
    /// and `R`.
    pub(crate) const unsafe fn new_unchecked(entry: ErasedDecoderEntry) -> Self {
        Self(entry, PhantomData)
    }

    /// Builds an entry backed by a vtable of type-erased trampolines that
    /// forward to the field decoder `E`.
    ///
    /// The vtable is built in a `const` block, so the entry refers to `'static`
    /// data. `drop_fn` is `None` when `T` does not need dropping.
    pub(crate) const fn custom<E: FieldDecode<'a, T, R>>() -> Self {
        Self(
            ErasedDecoderEntry(
                core::ptr::from_ref(
                    const {
                        &StaticDecoderVtable {
                            read_fn: read_field_dyn::<T, R, E>,
                            default_fn: default_field_dyn::<T, R, E>,
                            drop_fn: if core::mem::needs_drop::<T>() {
                                Some(drop_field_dyn::<T>)
                            } else {
                                None
                            },
                        }
                    },
                )
                .cast(),
            ),
            PhantomData,
        )
    }

    /// Builds an entry that refers directly to `T`'s decoder tables.
    ///
    /// The [`DecoderTable`] pointer is tagged by adding [`ENTRY_IS_TABLE`], and
    /// [`ErasedDecoderEntry`] checks this tag to tell it apart from a vtable
    /// pointer. Referring to the tables directly avoids generating per-type
    /// stub functions and vtables.
    const fn table() -> Self
    where
        T: StructDecodeMetadata<'a, R>,
    {
        Self(
            ErasedDecoderEntry(
                core::ptr::from_ref(
                    const {
                        &DecoderTable {
                            count: T::NUMBERS.len(),
                            numbers: T::NUMBERS.as_ptr(),
                            decoders: T::DECODERS.as_ptr(),
                            offsets: T::OFFSETS.as_ptr(),
                        }
                    },
                )
                .cast::<()>()
                // The tagged pointer is never dereferenced directly; `decode`
                // subtracts the tag first, so wrapping arithmetic suffices.
                .wrapping_byte_add(ENTRY_IS_TABLE),
            ),
            PhantomData,
        )
    }

    /// Erases the type of the decoder entry.
    pub const fn erase(&self) -> ErasedDecoderEntry {
        self.0
    }
}
457
/// An entry in a [`StructDecodeMetadata::DECODERS`] table.
//
// Internally, this is a pointer to either a vtable or a table.
// The low bit is used to distinguish between the two: it is clear for a
// `StaticDecoderVtable` pointer and set (`ENTRY_IS_TABLE`) for a
// `DecoderTable` pointer.
#[derive(Copy, Clone, Debug)]
pub struct ErasedDecoderEntry(*const ());

// SAFETY: the entry represents a set of integers and function pointers, which
// have no cross-thread constraints. The pointee is only ever read through
// this pointer.
unsafe impl Send for ErasedDecoderEntry {}
// SAFETY: the entry represents a set of integers and function pointers, which
// have no cross-thread constraints. The pointee is only ever read through
// this pointer.
unsafe impl Sync for ErasedDecoderEntry {}
471
472const ENTRY_IS_TABLE: usize = 1;
473
474const _: () = assert!(align_of::<ErasedDecoderEntry>() > ENTRY_IS_TABLE);
475const _: () = assert!(align_of::<StaticDecoderVtable<'_, ()>>() > ENTRY_IS_TABLE);
476
impl ErasedDecoderEntry {
    /// Decodes the entry into either a vtable or a table.
    ///
    /// The `Result` is used as a two-way choice, not to signal failure.
    /// `Ok` holds the vtable when the [`ENTRY_IS_TABLE`] tag bit is clear.
    /// `Err` holds the [`DecoderTable`] (with the tag removed) when the bit is
    /// set.
    ///
    /// # Safety
    /// The caller must ensure that the decoder was for resource type `R`.
    unsafe fn decode<'de, R>(&self) -> Result<&StaticDecoderVtable<'de, R>, &DecoderTable> {
        // SAFETY: guaranteed by caller.
        unsafe {
            // `&` binds tighter than `==` in Rust, so this tests the tag bit.
            if self.0 as usize & ENTRY_IS_TABLE == 0 {
                Ok(&*self.0.cast::<StaticDecoderVtable<'_, R>>())
            } else {
                Err(&*self
                    .0
                    .wrapping_byte_sub(ENTRY_IS_TABLE)
                    .cast::<DecoderTable>())
            }
        }
    }

    /// Reads a field using the decoder metadata.
    ///
    /// A vtable entry dispatches to its `read_fn`. A table entry reads the
    /// field as a nested message via `read_message_by_ptr`.
    ///
    /// # Safety
    /// The caller must ensure that the decoder was for resource type `R` and
    /// the object type matches what `ptr` is pointing to. `*init` must be set
    /// if and only if the field is initialized.
    pub unsafe fn read_field<R>(
        &self,
        ptr: *mut u8,
        init: &mut bool,
        reader: FieldReader<'_, '_, R>,
    ) -> Result<(), Error> {
        // SAFETY: guaranteed by caller.
        unsafe {
            match self.decode::<R>() {
                Ok(vtable) => (vtable.read_fn)(ptr, init, reader),
                Err(table) => read_message_by_ptr(
                    table.count,
                    table.numbers,
                    table.decoders,
                    table.offsets,
                    ptr,
                    init,
                    reader,
                ),
            }
        }
    }

    /// Initializes a value to its default state using the decoder metadata.
    ///
    /// The entry is decoded with `R = ()` because only `default_fn` (whose
    /// signature does not mention `R`) or the table is used. The vtable is
    /// `#[repr(C)]`, so its layout does not depend on `R`.
    ///
    /// # Safety
    /// The caller must ensure that the decoder was for the object type matching
    /// what `ptr` is pointing to. `*init` must be set if and only if the field
    /// is initialized.
    pub unsafe fn default_field(&self, ptr: *mut u8, init: &mut bool) -> Result<(), Error> {
        // SAFETY: guaranteed by caller.
        unsafe {
            match self.decode::<()>() {
                Ok(vtable) => (vtable.default_fn)(ptr, init),
                Err(table) => {
                    default_fields_by_ptr(table.count, table.decoders, table.offsets, ptr, init)
                }
            }
        }
    }

    /// Drops a value in place using the decoder metadata.
    ///
    /// For a table entry, every field is dropped recursively, so the whole
    /// value must be initialized.
    ///
    /// # Safety
    /// The caller must ensure that the decoder was for the object type matching
    /// what `ptr` is pointing to, and that `ptr` is ready to be dropped.
    pub unsafe fn drop_field(&self, ptr: *mut u8) {
        // SAFETY: guaranteed by caller.
        unsafe {
            match self.decode::<()>() {
                Ok(vtable) => {
                    // `drop_fn` is `None` for types without drop glue.
                    if let Some(drop_fn) = vtable.drop_fn {
                        drop_fn(ptr);
                    }
                }
                Err(table) => {
                    for i in 0..table.count {
                        let offset = *table.offsets.add(i);
                        let decoder = &*table.decoders.add(i);
                        decoder.drop_field(ptr.add(offset));
                    }
                }
            }
        }
    }
}
567
/// A type-erased view of a struct's decoder tables, referenced by a tagged
/// [`ErasedDecoderEntry`].
///
/// `numbers`, `decoders`, and `offsets` each point to `count` elements, taken
/// from `StructDecodeMetadata::NUMBERS`, `DECODERS`, and `OFFSETS`.
struct DecoderTable {
    /// Number of fields, which is also the length of each array below.
    count: usize,
    /// Pointer to the field numbers.
    numbers: *const u32,
    /// Pointer to the per-field decoder entries.
    decoders: *const ErasedDecoderEntry,
    /// Pointer to the per-field byte offsets within the struct.
    offsets: *const usize,
}
574
/// A vtable for decoding a message.
///
/// Each function takes the field as a type-erased `*mut u8`. `init` points to
/// the field's initialized flag, which the function updates to reflect the
/// field's state on return.
#[repr(C)] // to ensure the layout is the same regardless of R
struct StaticDecoderVtable<'de, R> {
    /// Reads the field from a `FieldReader`.
    read_fn: unsafe fn(*mut u8, init: *mut bool, FieldReader<'de, '_, R>) -> Result<(), Error>,
    /// Applies the decoder's default to the field.
    default_fn: unsafe fn(*mut u8, init: *mut bool) -> Result<(), Error>,
    /// Drops the field in place; `None` if the field type needs no drop.
    drop_fn: Option<unsafe fn(*mut u8)>,
}
582
583unsafe fn read_field_dyn<'a, T, R, E: FieldDecode<'a, T, R>>(
584    field: *mut u8,
585    init: *mut bool,
586    reader: FieldReader<'a, '_, R>,
587) -> Result<(), Error> {
588    // SAFETY: `init` is valid according to the caller.
589    let init = unsafe { &mut *init };
590    // SAFETY: `field` is valid and points to a valid `MaybeUninit<T>` according
591    // to the caller.
592    let field = unsafe { &mut *field.cast::<MaybeUninit<T>>() };
593    let mut field = if *init {
594        // SAFETY: the caller attests that the field is initialized.
595        unsafe { InplaceOption::new_init_unchecked(field) }
596    } else {
597        InplaceOption::uninit(field)
598    };
599    let r = E::read_field(&mut field, reader);
600    *init = field.forget();
601    r
602}
603
604unsafe fn default_field_dyn<'a, T, R, E: FieldDecode<'a, T, R>>(
605    field: *mut u8,
606    init: *mut bool,
607) -> Result<(), Error> {
608    // SAFETY: `init` is valid according to the caller.
609    let init = unsafe { &mut *init };
610    // SAFETY: `field` is valid and points to a valid `MaybeUninit<T>` according
611    // to the caller.
612    let field = unsafe { &mut *field.cast::<MaybeUninit<T>>() };
613    let mut field = if *init {
614        // SAFETY: the caller attests that the field is initialized.
615        unsafe { InplaceOption::new_init_unchecked(field) }
616    } else {
617        InplaceOption::uninit(field)
618    };
619    let r = E::default_field(&mut field);
620    *init = field.forget();
621    r
622}
623
/// Type-erased trampoline that drops the `T` stored at `field` in place.
///
/// # Safety
///
/// `field` must point to a valid, initialized `T`, which must not be used
/// again after this call.
unsafe fn drop_field_dyn<T>(field: *mut u8) {
    // SAFETY: the caller guarantees `field` points to a live `T`.
    unsafe { core::ptr::drop_in_place(field.cast::<T>()) }
}