mesh_protobuf/table/
encode.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Table-based encoding.
5
6#![expect(clippy::missing_safety_doc)]
7
8use super::StructMetadata;
9use super::TableEncoder;
10use crate::FieldEncode;
11use crate::MessageEncode;
12use crate::protobuf::FieldSizer;
13use crate::protobuf::FieldWriter;
14use crate::protobuf::MessageSizer;
15use crate::protobuf::MessageWriter;
16use alloc::slice;
17use core::marker::PhantomData;
18use core::mem::MaybeUninit;
19
20impl<T, R> MessageEncode<T, R> for TableEncoder
21where
22    T: StructEncodeMetadata<R>,
23{
24    fn write_message(item: T, writer: MessageWriter<'_, '_, R>) {
25        let mut item = MaybeUninit::new(item);
26        // SAFETY: `T` guarantees that its encoders and offsets are correct for
27        // this type.
28        unsafe {
29            write_fields(
30                T::NUMBERS,
31                T::ENCODERS,
32                T::OFFSETS,
33                item.as_mut_ptr().cast(),
34                writer,
35            );
36        }
37    }
38
39    fn compute_message_size(item: &mut T, sizer: MessageSizer<'_>) {
40        // SAFETY: `T` guarantees that its encoders and offsets are correct for
41        // this type.
42        unsafe {
43            compute_size_fields::<R>(
44                T::NUMBERS,
45                T::ENCODERS,
46                T::OFFSETS,
47                core::ptr::from_mut(item).cast::<u8>(),
48                sizer,
49            );
50        }
51    }
52}
53
54impl<T, R> FieldEncode<T, R> for TableEncoder
55where
56    T: StructEncodeMetadata<R>,
57{
58    // Override the default implementation to use the table encoder directly.
59    // This saves code size by avoiding extra stub functions and vtables.
60    const ENTRY: EncoderEntry<T, R> = EncoderEntry::table();
61
62    fn write_field(item: T, writer: FieldWriter<'_, '_, R>) {
63        let mut item = MaybeUninit::new(item);
64        // SAFETY: `T` guarantees that its encoders and offsets are correct for
65        // this type.
66        unsafe {
67            write_message(
68                T::NUMBERS,
69                T::ENCODERS,
70                T::OFFSETS,
71                item.as_mut_ptr().cast(),
72                writer,
73            );
74        }
75    }
76
77    fn compute_field_size(item: &mut T, sizer: FieldSizer<'_>) {
78        // SAFETY: `T` guarantees that its encoders and offsets are correct for
79        // this type.
80        unsafe {
81            compute_size_message::<R>(
82                T::NUMBERS,
83                T::ENCODERS,
84                T::OFFSETS,
85                core::ptr::from_mut(item).cast::<u8>(),
86                sizer,
87            );
88        }
89    }
90}
91
92unsafe fn write_message<R>(
93    numbers: &[u32],
94    encoders: &[ErasedEncoderEntry],
95    offsets: &[usize],
96    base: *mut u8,
97    writer: FieldWriter<'_, '_, R>,
98) {
99    assert_eq!(numbers.len(), encoders.len());
100    assert_eq!(numbers.len(), offsets.len());
101    // SAFETY: guaranteed by the caller.
102    unsafe {
103        write_message_by_ptr(
104            numbers.len(),
105            numbers.as_ptr(),
106            encoders.as_ptr(),
107            offsets.as_ptr(),
108            base,
109            writer,
110        )
111    }
112}
113
114#[inline(never)]
115unsafe fn write_message_by_ptr<R>(
116    count: usize,
117    numbers: *const u32,
118    encoders: *const ErasedEncoderEntry,
119    offsets: *const usize,
120    base: *mut u8,
121    writer: FieldWriter<'_, '_, R>,
122) {
123    // SAFETY: guaranteed by the caller.
124    writer.message(|writer| unsafe {
125        write_fields_inline(
126            slice::from_raw_parts(numbers, count),
127            slice::from_raw_parts(encoders, count),
128            slice::from_raw_parts(offsets, count),
129            base,
130            writer,
131        )
132    })
133}
134
135/// Writes the fields of a message using the provided metadata.
136///
137/// Note that `base` will no longer contain a valid object after this function
138/// returns.
139///
140/// # Safety
141/// The caller must ensure that the provided encoders and offsets correspond to
142/// fields in the struct at `base`, and that `base` is owned.
143#[doc(hidden)] // only used publicly for `mesh_derive`
144pub unsafe fn write_fields<R>(
145    numbers: &[u32],
146    encoders: &[ErasedEncoderEntry],
147    offsets: &[usize],
148    base: *mut u8,
149    writer: MessageWriter<'_, '_, R>,
150) {
151    assert_eq!(numbers.len(), encoders.len());
152    assert_eq!(numbers.len(), offsets.len());
153    // SAFETY: guaranteed by the caller.
154    unsafe {
155        write_fields_by_ptr(
156            numbers.len(),
157            numbers.as_ptr(),
158            encoders.as_ptr(),
159            offsets.as_ptr(),
160            base,
161            writer,
162        )
163    }
164}
165
166#[inline(never)]
167unsafe fn write_fields_by_ptr<R>(
168    count: usize,
169    numbers: *const u32,
170    encoders: *const ErasedEncoderEntry,
171    offsets: *const usize,
172    base: *mut u8,
173    writer: MessageWriter<'_, '_, R>,
174) {
175    // SAFETY: guaranteed by the caller.
176    unsafe {
177        write_fields_inline(
178            slice::from_raw_parts(numbers, count),
179            slice::from_raw_parts(encoders, count),
180            slice::from_raw_parts(offsets, count),
181            base,
182            writer,
183        )
184    }
185}
186
187unsafe fn write_fields_inline<R>(
188    numbers: &[u32],
189    encoders: &[ErasedEncoderEntry],
190    offsets: &[usize],
191    base: *mut u8,
192    mut writer: MessageWriter<'_, '_, R>,
193) {
194    for ((&number, encoder), &offset) in numbers.iter().zip(encoders).zip(offsets) {
195        let writer = writer.field(number);
196        // SAFETY: the caller guarantees that `base` points to an object
197        // compatible with this encoder and that it will not access the object
198        // through `base` after this returns.
199        unsafe {
200            let ptr = base.add(offset);
201            encoder.write_field(ptr, writer);
202        }
203    }
204}
205
206unsafe fn compute_size_message<R>(
207    numbers: &[u32],
208    encoders: &[ErasedEncoderEntry],
209    offsets: &[usize],
210    base: *mut u8,
211    sizer: FieldSizer<'_>,
212) {
213    assert_eq!(numbers.len(), encoders.len());
214    assert_eq!(numbers.len(), offsets.len());
215    // SAFETY: guaranteed by the caller.
216    unsafe {
217        compute_size_message_by_ptr::<R>(
218            numbers.len(),
219            numbers.as_ptr(),
220            encoders.as_ptr(),
221            offsets.as_ptr(),
222            base,
223            sizer,
224        )
225    }
226}
227
228#[inline(never)]
229unsafe fn compute_size_message_by_ptr<R>(
230    count: usize,
231    numbers: *const u32,
232    encoders: *const ErasedEncoderEntry,
233    offsets: *const usize,
234    base: *mut u8,
235    sizer: FieldSizer<'_>,
236) {
237    // SAFETY: guaranteed by the caller.
238    sizer.message(|sizer| unsafe {
239        compute_size_fields_inline::<R>(
240            slice::from_raw_parts(numbers, count),
241            slice::from_raw_parts(encoders, count),
242            slice::from_raw_parts(offsets, count),
243            base,
244            sizer,
245        )
246    })
247}
248
249/// Computes the size of a message using the provided metadata.
250///
251/// # Safety
252/// The caller must ensure that the provided encoders and offsets correspond to
253/// fields in the struct at `base`, and that `base` is valid for write.
254#[doc(hidden)] // only used publicly for `mesh_derive`
255pub unsafe fn compute_size_fields<R>(
256    numbers: &[u32],
257    encoders: &[ErasedEncoderEntry],
258    offsets: &[usize],
259    base: *mut u8,
260    sizer: MessageSizer<'_>,
261) {
262    assert_eq!(numbers.len(), encoders.len());
263    assert_eq!(numbers.len(), offsets.len());
264    // SAFETY: guaranteed by the caller.
265    unsafe {
266        compute_size_fields_by_ptr::<R>(
267            numbers.len(),
268            numbers.as_ptr(),
269            encoders.as_ptr(),
270            offsets.as_ptr(),
271            base,
272            sizer,
273        )
274    }
275}
276
277#[inline(never)]
278unsafe fn compute_size_fields_by_ptr<R>(
279    count: usize,
280    numbers: *const u32,
281    encoders: *const ErasedEncoderEntry,
282    offsets: *const usize,
283    base: *mut u8,
284    sizer: MessageSizer<'_>,
285) {
286    // SAFETY: guaranteed by the caller.
287    unsafe {
288        compute_size_fields_inline::<R>(
289            slice::from_raw_parts(numbers, count),
290            slice::from_raw_parts(encoders, count),
291            slice::from_raw_parts(offsets, count),
292            base,
293            sizer,
294        )
295    }
296}
297
298unsafe fn compute_size_fields_inline<R>(
299    numbers: &[u32],
300    encoders: &[ErasedEncoderEntry],
301    offsets: &[usize],
302    base: *mut u8,
303    mut sizer: MessageSizer<'_>,
304) {
305    for ((&number, encoder), &offset) in numbers.iter().zip(encoders).zip(offsets) {
306        let sizer = sizer.field(number);
307        // SAFETY: the caller guarantees that `base` points to an object
308        // compatible with this encoder.
309        unsafe {
310            let ptr = base.add(offset);
311            encoder.compute_size_field::<R>(ptr, sizer);
312        }
313    }
314}
315
316/// Metadata for encoding a struct.
317///
318/// # Safety
319///
320/// The implementor must ensure that the `ENCODERS` are correct and complete for
321/// `Self` and `R`.
322pub unsafe trait StructEncodeMetadata<R>: StructMetadata {
323    /// The list of encoder vtables.
324    const ENCODERS: &'static [ErasedEncoderEntry];
325}
326
327/// An entry in the encoder table.
328///
329/// This contains the metadata necessary to apply an encoder to a field.
330///
331/// This cannot be instantiated directly; use [`FieldEncode::ENTRY`] to get an
332/// instance for a particular encoder.
333pub struct EncoderEntry<T, R>(ErasedEncoderEntry, PhantomData<fn(T, &mut R)>);
334
335impl<T, R> EncoderEntry<T, R> {
336    /// # Safety
337    /// The caller must ensure that the erased entry is an valid entry for `T`
338    /// and `R`.
339    pub(crate) const unsafe fn new_unchecked(entry: ErasedEncoderEntry) -> Self {
340        Self(entry, PhantomData)
341    }
342
343    /// Returns an encoder entry that contains a vtable with methods for
344    /// encoding the field.
345    pub(crate) const fn custom<E: FieldEncode<T, R>>() -> Self {
346        Self(
347            ErasedEncoderEntry(
348                core::ptr::from_ref(
349                    const {
350                        &StaticEncoderVtable {
351                            write_fn: write_field_dyn::<T, R, E>,
352                            compute_size_fn: compute_size_field_dyn::<T, R, E>,
353                        }
354                    },
355                )
356                .cast::<()>(),
357            ),
358            PhantomData,
359        )
360    }
361
362    /// Returns an encoder entry that contains an encoder table.
363    const fn table() -> Self
364    where
365        T: StructEncodeMetadata<R>,
366    {
367        Self(
368            ErasedEncoderEntry(
369                core::ptr::from_ref(
370                    const {
371                        &EncoderTable {
372                            count: T::NUMBERS.len(),
373                            numbers: T::NUMBERS.as_ptr(),
374                            encoders: T::ENCODERS.as_ptr(),
375                            offsets: T::OFFSETS.as_ptr(),
376                        }
377                    },
378                )
379                .cast::<()>()
380                .wrapping_byte_add(ENTRY_IS_TABLE),
381            ),
382            PhantomData,
383        )
384    }
385
386    /// Return the type-erased encoder entry.
387    pub const fn erase(&self) -> ErasedEncoderEntry {
388        self.0
389    }
390}
391
/// A type-erased version of [`EncoderEntry`], for use in a
/// [`StructEncodeMetadata::ENCODERS`] table.
//
// Internally, this is a pointer to either a `StaticEncoderVtable` or an
// `EncoderTable`. The low bit (`ENTRY_IS_TABLE`) is set for tables and clear
// for vtables; it is stripped before the table pointer is dereferenced.
#[derive(Copy, Clone, Debug)]
pub struct ErasedEncoderEntry(*const ());
399
400// SAFETY: the entry represents a set of integers and function pointers, which
401// have no cross-thread constraints.
402unsafe impl Send for ErasedEncoderEntry {}
403// SAFETY: the entry represents a set of integers and function pointers, which
404// have no cross-thread constraints.
405unsafe impl Sync for ErasedEncoderEntry {}
406
407const ENTRY_IS_TABLE: usize = 1;
408
409const _: () = assert!(align_of::<ErasedEncoderEntry>() > ENTRY_IS_TABLE);
410const _: () = assert!(align_of::<StaticEncoderVtable<()>>() > ENTRY_IS_TABLE);
411
412impl ErasedEncoderEntry {
413    /// Decodes the entry into either a vtable or a table.
414    ///
415    /// # Safety
416    /// The caller must ensure that the encoder was for resource type `R`.
417    unsafe fn decode<R>(&self) -> Result<&StaticEncoderVtable<R>, &EncoderTable> {
418        // SAFETY: `R` is guaranteed by caller to be the right type.
419        unsafe {
420            if self.0 as usize & ENTRY_IS_TABLE == 0 {
421                Ok(&*self.0.cast::<StaticEncoderVtable<R>>())
422            } else {
423                Err(&*self
424                    .0
425                    .wrapping_byte_sub(ENTRY_IS_TABLE)
426                    .cast::<EncoderTable>())
427            }
428        }
429    }
430
431    /// Writes a value to a field using the encoder, taking ownership of `field`.
432    ///
433    /// # Safety
434    /// The caller must ensure that `field` points to a valid object of type `T`, and
435    /// that the encoder is correct for `T` and `R`.
436    pub unsafe fn write_field<R>(&self, field: *mut u8, writer: FieldWriter<'_, '_, R>) {
437        // SAFETY: caller guarantees this encoder is correct for `T` and `R` and
438        // that `field` points to a valid object of type `T`.
439        unsafe {
440            match self.decode::<R>() {
441                Ok(vtable) => (vtable.write_fn)(field, writer),
442                Err(table) => {
443                    write_message_by_ptr(
444                        table.count,
445                        table.numbers,
446                        table.encoders,
447                        table.offsets,
448                        field,
449                        writer,
450                    );
451                }
452            }
453        }
454    }
455
456    /// Computes the size of a field using the encoder.
457    ///
458    /// # Safety
459    /// The caller must ensure that `field` points to a valid object of type `T`, and
460    /// that the encoder is correct for `T` and `R`.
461    pub unsafe fn compute_size_field<R>(&self, field: *mut u8, sizer: FieldSizer<'_>) {
462        // SAFETY: caller guarantees this encoder is correct for `T` and `R` and
463        // that `field` points to a valid object of type `T`.
464        unsafe {
465            match self.decode::<R>() {
466                Ok(vtable) => (vtable.compute_size_fn)(field, sizer),
467                Err(table) => {
468                    compute_size_message_by_ptr::<R>(
469                        table.count,
470                        table.numbers,
471                        table.encoders,
472                        table.offsets,
473                        field,
474                        sizer,
475                    );
476                }
477            }
478        }
479    }
480}
481
482struct EncoderTable {
483    count: usize,
484    numbers: *const u32,
485    encoders: *const ErasedEncoderEntry,
486    offsets: *const usize,
487}
488
489/// A vtable for encoding a message.
490struct StaticEncoderVtable<R> {
491    write_fn: unsafe fn(*mut u8, FieldWriter<'_, '_, R>),
492    compute_size_fn: unsafe fn(*mut u8, FieldSizer<'_>),
493}
494
495unsafe fn write_field_dyn<T, R, E: FieldEncode<T, R>>(
496    field: *mut u8,
497    writer: FieldWriter<'_, '_, R>,
498) {
499    // SAFETY: caller guarantees that `field` points to a `T`, and this function
500    // takes ownership of it.
501    let field = unsafe { field.cast::<T>().read() };
502    E::write_field(field, writer);
503}
504
505unsafe fn compute_size_field_dyn<T, R, E: FieldEncode<T, R>>(
506    field: *mut u8,
507    sizer: FieldSizer<'_>,
508) {
509    // SAFETY: caller guarantees that `field` points to a `T`.
510    let field = unsafe { &mut *field.cast::<T>() };
511    E::compute_field_size(field, sizer);
512}