1#![expect(clippy::missing_safety_doc)]
7
8use super::StructMetadata;
9use super::TableEncoder;
10use crate::Error;
11use crate::FieldDecode;
12use crate::MessageDecode;
13use crate::inplace::InplaceOption;
14use crate::protobuf::FieldReader;
15use crate::protobuf::MessageReader;
16use alloc::slice;
17use alloc::vec;
18use core::marker::PhantomData;
19use core::mem::MaybeUninit;
20
21unsafe fn run_inplace<T, R>(
28 item: &mut InplaceOption<'_, T>,
29 f: impl FnOnce(*mut u8, &mut bool) -> R,
30) -> R {
31 let mut initialized = item.forget();
32 let base = item.as_mut_ptr().cast::<u8>();
33 let r = f(base, &mut initialized);
34 if initialized {
35 unsafe { item.set_init_unchecked() };
37 }
38 r
39}
40
41impl<'de, T, R> MessageDecode<'de, T, R> for TableEncoder
42where
43 T: StructDecodeMetadata<'de, R>,
44{
45 fn read_message(
46 item: &mut InplaceOption<'_, T>,
47 reader: MessageReader<'de, '_, R>,
48 ) -> crate::Result<()> {
49 unsafe {
51 run_inplace(item, |base, initialized| {
52 read_fields(
53 T::NUMBERS,
54 T::DECODERS,
55 T::OFFSETS,
56 base,
57 initialized,
58 reader,
59 )
60 })
61 }
62 }
63}
64
65impl<'de, T, R> FieldDecode<'de, T, R> for TableEncoder
66where
67 T: StructDecodeMetadata<'de, R>,
68{
69 const ENTRY: DecoderEntry<'de, T, R> = DecoderEntry::table();
72
73 fn read_field(
74 item: &mut InplaceOption<'_, T>,
75 reader: FieldReader<'de, '_, R>,
76 ) -> crate::Result<()> {
77 unsafe {
79 run_inplace(item, |base, initialized| {
80 read_message(
81 T::NUMBERS,
82 T::DECODERS,
83 T::OFFSETS,
84 base,
85 initialized,
86 reader,
87 )
88 })
89 }
90 }
91
92 fn default_field(item: &mut InplaceOption<'_, T>) -> crate::Result<()> {
93 unsafe {
95 run_inplace(item, |base, initialized| {
96 default_fields(T::DECODERS, T::OFFSETS, base, initialized)
97 })
98 }
99 }
100}
101
102#[doc(hidden)] pub unsafe fn read_message<R>(
111 numbers: &[u32],
112 decoders: &[ErasedDecoderEntry],
113 offsets: &[usize],
114 base: *mut u8,
115 struct_initialized: &mut bool,
116 reader: FieldReader<'_, '_, R>,
117) -> Result<(), Error> {
118 assert_eq!(numbers.len(), decoders.len());
119 assert_eq!(numbers.len(), offsets.len());
120 unsafe {
122 read_message_by_ptr(
125 numbers.len(),
126 numbers.as_ptr(),
127 decoders.as_ptr(),
128 offsets.as_ptr(),
129 base,
130 struct_initialized,
131 reader,
132 )
133 }
134}
135
136#[inline(never)]
138unsafe fn read_message_by_ptr<R>(
139 count: usize,
140 numbers: *const u32,
141 decoders: *const ErasedDecoderEntry,
142 offsets: *const usize,
143 base: *mut u8,
144 struct_initialized: &mut bool,
145 reader: FieldReader<'_, '_, R>,
146) -> Result<(), Error> {
147 unsafe {
149 read_fields_inline(
150 slice::from_raw_parts(numbers, count),
151 slice::from_raw_parts(decoders, count),
152 slice::from_raw_parts(offsets, count),
153 base,
154 struct_initialized,
155 reader.message()?,
156 )
157 }
158}
159
160unsafe fn read_fields<R>(
168 numbers: &[u32],
169 decoders: &[ErasedDecoderEntry],
170 offsets: &[usize],
171 base: *mut u8,
172 struct_initialized: &mut bool,
173 reader: MessageReader<'_, '_, R>,
174) -> Result<(), Error> {
175 assert_eq!(numbers.len(), decoders.len());
176 assert_eq!(numbers.len(), offsets.len());
177 unsafe {
179 read_fields_by_ptr(
182 numbers.len(),
183 numbers.as_ptr(),
184 decoders.as_ptr(),
185 offsets.as_ptr(),
186 base,
187 struct_initialized,
188 reader,
189 )
190 }
191}
192
193#[inline(never)]
195unsafe fn read_fields_by_ptr<R>(
196 count: usize,
197 numbers: *const u32,
198 decoders: *const ErasedDecoderEntry,
199 offsets: *const usize,
200 base: *mut u8,
201 struct_initialized: &mut bool,
202 reader: MessageReader<'_, '_, R>,
203) -> Result<(), Error> {
204 unsafe {
206 read_fields_inline(
207 slice::from_raw_parts(numbers, count),
208 slice::from_raw_parts(decoders, count),
209 slice::from_raw_parts(offsets, count),
210 base,
211 struct_initialized,
212 reader,
213 )
214 }
215}
216
217unsafe fn read_fields_inline<R>(
225 numbers: &[u32],
226 decoders: &[ErasedDecoderEntry],
227 offsets: &[usize],
228 base: *mut u8,
229 struct_initialized: &mut bool,
230 reader: MessageReader<'_, '_, R>,
231) -> Result<(), Error> {
232 const STACK_LIMIT: usize = 32;
233 let mut field_init_static;
234 let mut field_init_dynamic;
235 let field_inits = if numbers.len() <= STACK_LIMIT {
236 field_init_static = [false; STACK_LIMIT];
237 field_init_static[..numbers.len()].fill(*struct_initialized);
238 &mut field_init_static[..numbers.len()]
239 } else {
240 field_init_dynamic = vec![*struct_initialized; numbers.len()];
241 &mut field_init_dynamic[..]
242 };
243
244 let r = unsafe { read_fields_inner(numbers, decoders, offsets, base, field_inits, reader) };
246 *struct_initialized = true;
247 if r.is_err() && !field_inits.iter().all(|&b| b) {
248 for ((field_init, &offset), decoder) in field_inits.iter_mut().zip(offsets).zip(decoders) {
250 if *field_init {
251 unsafe {
253 decoder.drop_field(base.add(offset));
254 }
255 }
256 }
257 *struct_initialized = false;
258 }
259 r
260}
261
262unsafe fn read_fields_inner<R>(
265 numbers: &[u32],
266 decoders: &[ErasedDecoderEntry],
267 offsets: &[usize],
268 base: *mut u8,
269 field_init: &mut [bool],
270 reader: MessageReader<'_, '_, R>,
271) -> Result<(), Error> {
272 let decoders = &decoders[..numbers.len()];
273 let offsets = &offsets[..numbers.len()];
274 let field_init = &mut field_init[..numbers.len()];
275 for field in reader {
276 let (number, reader) = field?;
277 if let Some(index) = numbers.iter().position(|&n| n == number) {
278 let decoder = &decoders[index];
279 unsafe {
281 decoder.read_field(base.add(offsets[index]), &mut field_init[index], reader)?;
282 }
283 }
284 }
285 for ((field_init, &offset), decoder) in field_init.iter_mut().zip(offsets).zip(decoders) {
286 if !*field_init {
287 unsafe {
289 decoder.default_field(base.add(offset), field_init)?;
290 }
291 assert!(*field_init);
292 }
293 }
294 Ok(())
295}
296
297unsafe fn default_fields(
305 decoders: &[ErasedDecoderEntry],
306 offsets: &[usize],
307 base: *mut u8,
308 struct_initialized: &mut bool,
309) -> Result<(), Error> {
310 assert_eq!(decoders.len(), offsets.len());
311 unsafe {
313 default_fields_by_ptr(
314 decoders.len(),
315 decoders.as_ptr(),
316 offsets.as_ptr(),
317 base,
318 struct_initialized,
319 )
320 }
321}
322
323#[inline(never)]
324unsafe fn default_fields_by_ptr(
325 count: usize,
326 decoders: *const ErasedDecoderEntry,
327 offsets: *const usize,
328 base: *mut u8,
329 struct_initialized: &mut bool,
330) -> Result<(), Error> {
331 unsafe {
333 default_fields_inline(
334 slice::from_raw_parts(decoders, count),
335 slice::from_raw_parts(offsets, count),
336 base,
337 struct_initialized,
338 )
339 }
340}
341
342unsafe fn default_fields_inline(
343 decoders: &[ErasedDecoderEntry],
344 offsets: &[usize],
345 base: *mut u8,
346 struct_initialized: &mut bool,
347) -> Result<(), Error> {
348 for (i, (&offset, decoder)) in offsets.iter().zip(decoders).enumerate() {
349 let mut field_initialized = *struct_initialized;
350 let r = unsafe { decoder.default_field(base.add(offset), &mut field_initialized) };
352 if let Err(err) = r {
353 if !field_initialized || !*struct_initialized {
354 let initialized_until = i;
356 for (i, (&offset, decoder)) in offsets.iter().zip(decoders).enumerate() {
357 if i < initialized_until
358 || (i == initialized_until && field_initialized)
359 || (i > initialized_until && *struct_initialized)
360 {
361 unsafe {
363 decoder.drop_field(base.add(offset));
364 }
365 }
366 }
367 *struct_initialized = false;
368 }
369 return Err(err);
370 }
371 assert!(field_initialized);
372 }
373 *struct_initialized = true;
374 Ok(())
375}
376
/// Per-struct metadata that lets `TableEncoder` decode a struct field by
/// field.
///
/// `DECODERS[i]` decodes the field whose number is `NUMBERS[i]` and whose
/// storage starts `OFFSETS[i]` bytes into the struct. `NUMBERS` and `OFFSETS`
/// come from the `StructMetadata` supertrait.
///
/// NOTE(review): the decoding code in this module trusts these tables when
/// doing raw pointer arithmetic. It assumes the three tables have equal
/// length, and that each offset/decoder pair describes a real field of the
/// implementing type. The exact contract of this `unsafe` trait is not
/// visible here — confirm against its original documentation.
pub unsafe trait StructDecodeMetadata<'de, R>: StructMetadata {
    /// Type-erased decoder for each field, parallel to `NUMBERS` and
    /// `OFFSETS`.
    const DECODERS: &'static [ErasedDecoderEntry];
}
387
/// Typed wrapper around an `ErasedDecoderEntry` for decoding a `T` from a
/// reader of type `R` with input lifetime `'a`.
///
/// The `PhantomData<fn(&mut T, &mut R, &'a mut ())>` ties `T`, `R` and `'a`
/// to the entry without storing them. Because `&mut T` / `&mut R` appear in
/// argument position, the wrapper is invariant in `T` and `R`. Because it is
/// a function-pointer type, it does not make the wrapper own a `T` or an `R`.
pub struct DecoderEntry<'a, T, R>(
    ErasedDecoderEntry,
    PhantomData<fn(&mut T, &mut R, &'a mut ())>,
);
398
399impl<'a, T, R> DecoderEntry<'a, T, R> {
400 pub(crate) const unsafe fn new_unchecked(entry: ErasedDecoderEntry) -> Self {
404 Self(entry, PhantomData)
405 }
406
407 pub(crate) const fn custom<E: FieldDecode<'a, T, R>>() -> Self {
408 Self(
409 ErasedDecoderEntry(
410 core::ptr::from_ref(
411 const {
412 &StaticDecoderVtable {
413 read_fn: read_field_dyn::<T, R, E>,
414 default_fn: default_field_dyn::<T, R, E>,
415 drop_fn: if core::mem::needs_drop::<T>() {
416 Some(drop_field_dyn::<T>)
417 } else {
418 None
419 },
420 }
421 },
422 )
423 .cast(),
424 ),
425 PhantomData,
426 )
427 }
428
429 const fn table() -> Self
430 where
431 T: StructDecodeMetadata<'a, R>,
432 {
433 Self(
434 ErasedDecoderEntry(
435 core::ptr::from_ref(
436 const {
437 &DecoderTable {
438 count: T::NUMBERS.len(),
439 numbers: T::NUMBERS.as_ptr(),
440 decoders: T::DECODERS.as_ptr(),
441 offsets: T::OFFSETS.as_ptr(),
442 }
443 },
444 )
445 .cast::<()>()
446 .wrapping_byte_add(ENTRY_IS_TABLE),
447 ),
448 PhantomData,
449 )
450 }
451
452 pub const fn erase(&self) -> ErasedDecoderEntry {
454 self.0
455 }
456}
457
/// Type-erased pointer to a field decoder.
///
/// The lowest address bit (`ENTRY_IS_TABLE`) says what the pointer refers to:
/// - bit clear: a `StaticDecoderVtable`, as built by `DecoderEntry::custom`;
/// - bit set: a `DecoderTable` whose address was offset by `ENTRY_IS_TABLE`,
///   as built by `DecoderEntry::table`.
#[derive(Copy, Clone, Debug)]
pub struct ErasedDecoderEntry(*const ());

// SAFETY: the entries built in this file point at constants created in
// inline `const` blocks. Those constants are immutable and `'static`, so
// moving the pointer to another thread cannot introduce a data race.
// NOTE(review): `DecoderEntry::new_unchecked` can wrap arbitrary entries; this
// assumes those also point at immutable `'static` data.
unsafe impl Send for ErasedDecoderEntry {}
// SAFETY: same reasoning as the `Send` impl above; the pointee is only ever
// read.
unsafe impl Sync for ErasedDecoderEntry {}
471
472const ENTRY_IS_TABLE: usize = 1;
473
474const _: () = assert!(align_of::<ErasedDecoderEntry>() > ENTRY_IS_TABLE);
475const _: () = assert!(align_of::<StaticDecoderVtable<'_, ()>>() > ENTRY_IS_TABLE);
476
477impl ErasedDecoderEntry {
478 unsafe fn decode<'de, R>(&self) -> Result<&StaticDecoderVtable<'de, R>, &DecoderTable> {
483 unsafe {
485 if self.0 as usize & ENTRY_IS_TABLE == 0 {
486 Ok(&*self.0.cast::<StaticDecoderVtable<'_, R>>())
487 } else {
488 Err(&*self
489 .0
490 .wrapping_byte_sub(ENTRY_IS_TABLE)
491 .cast::<DecoderTable>())
492 }
493 }
494 }
495
496 pub unsafe fn read_field<R>(
503 &self,
504 ptr: *mut u8,
505 init: &mut bool,
506 reader: FieldReader<'_, '_, R>,
507 ) -> Result<(), Error> {
508 unsafe {
510 match self.decode::<R>() {
511 Ok(vtable) => (vtable.read_fn)(ptr, init, reader),
512 Err(table) => read_message_by_ptr(
513 table.count,
514 table.numbers,
515 table.decoders,
516 table.offsets,
517 ptr,
518 init,
519 reader,
520 ),
521 }
522 }
523 }
524
525 pub unsafe fn default_field(&self, ptr: *mut u8, init: &mut bool) -> Result<(), Error> {
531 unsafe {
533 match self.decode::<()>() {
534 Ok(vtable) => (vtable.default_fn)(ptr, init),
535 Err(table) => {
536 default_fields_by_ptr(table.count, table.decoders, table.offsets, ptr, init)
537 }
538 }
539 }
540 }
541
542 pub unsafe fn drop_field(&self, ptr: *mut u8) {
548 unsafe {
550 match self.decode::<()>() {
551 Ok(vtable) => {
552 if let Some(drop_fn) = vtable.drop_fn {
553 drop_fn(ptr);
554 }
555 }
556 Err(table) => {
557 for i in 0..table.count {
558 let offset = *table.offsets.add(i);
559 let decoder = &*table.decoders.add(i);
560 decoder.drop_field(ptr.add(offset));
561 }
562 }
563 }
564 }
565 }
566}
567
/// Raw-pointer form of a struct's field tables, used by table-driven (nested
/// message) decoder entries.
///
/// `DecoderEntry::table` builds it from `T::NUMBERS`, `T::DECODERS` and
/// `T::OFFSETS`, with `count` set to `T::NUMBERS.len()`. The decoding paths
/// read `count` elements from each of the three arrays, and entry `i` of each
/// array describes the same field.
struct DecoderTable {
    // Number of fields, i.e. how many elements are read from each array.
    count: usize,
    // Field numbers, matched against the numbers yielded by the reader.
    numbers: *const u32,
    // Decoder for each field.
    decoders: *const ErasedDecoderEntry,
    // Byte offset of each field from the start of the struct.
    offsets: *const usize,
}
574
/// Function-pointer vtable for a custom (non-table) field decoder, built by
/// `DecoderEntry::custom`.
///
/// `#[repr(C)]` fixes the field order and layout. `ErasedDecoderEntry` relies
/// on this when it reads `default_fn` / `drop_fn` through a
/// `StaticDecoderVtable<'_, ()>` view, whatever the real `R` is.
#[repr(C)] struct StaticDecoderVtable<'de, R> {
    // Decodes one occurrence of the field into the storage and updates the
    // init flag.
    read_fn: unsafe fn(*mut u8, init: *mut bool, FieldReader<'de, '_, R>) -> Result<(), Error>,
    // Applies the field's default to the storage and updates the init flag.
    default_fn: unsafe fn(*mut u8, init: *mut bool) -> Result<(), Error>,
    // Drops the field in place; `None` when the field type needs no drop.
    drop_fn: Option<unsafe fn(*mut u8)>,
}
582
583unsafe fn read_field_dyn<'a, T, R, E: FieldDecode<'a, T, R>>(
584 field: *mut u8,
585 init: *mut bool,
586 reader: FieldReader<'a, '_, R>,
587) -> Result<(), Error> {
588 let init = unsafe { &mut *init };
590 let field = unsafe { &mut *field.cast::<MaybeUninit<T>>() };
593 let mut field = if *init {
594 unsafe { InplaceOption::new_init_unchecked(field) }
596 } else {
597 InplaceOption::uninit(field)
598 };
599 let r = E::read_field(&mut field, reader);
600 *init = field.forget();
601 r
602}
603
604unsafe fn default_field_dyn<'a, T, R, E: FieldDecode<'a, T, R>>(
605 field: *mut u8,
606 init: *mut bool,
607) -> Result<(), Error> {
608 let init = unsafe { &mut *init };
610 let field = unsafe { &mut *field.cast::<MaybeUninit<T>>() };
613 let mut field = if *init {
614 unsafe { InplaceOption::new_init_unchecked(field) }
616 } else {
617 InplaceOption::uninit(field)
618 };
619 let r = E::default_field(&mut field);
620 *init = field.forget();
621 r
622}
623
/// Monomorphized `drop_fn` used by `DecoderEntry::custom`: drops the `T`
/// stored at `field` in place.
///
/// # Safety
///
/// `field` must point to a valid, initialized `T` that is not used again
/// afterwards.
unsafe fn drop_field_dyn<T>(field: *mut u8) {
    let value: *mut T = field.cast();
    // SAFETY: guaranteed by this function's contract.
    unsafe { core::ptr::drop_in_place(value) }
}
629}