mesh_protobuf/
protobuf.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Tools to encode and decode protobuf messages.
5
6use super::DecodeError;
7use super::InplaceOption;
8use super::MessageDecode;
9use super::MessageEncode;
10use super::RefCell;
11use super::Result;
12use super::buffer;
13use super::buffer::Buf;
14use super::buffer::Buffer;
15use crate::DefaultEncoding;
16use alloc::vec;
17use alloc::vec::Vec;
18use core::marker::PhantomData;
19use core::ops::Range;
20
21/// Writes a variable-length integer, as defined in the protobuf specification.
22fn write_varint(v: &mut Buf<'_>, mut n: u64) {
23    while n > 0x7f {
24        v.push(0x80 | (n & 0x7f) as u8);
25        n >>= 7;
26    }
27    v.push(n as u8);
28}
29
/// Returns how many bytes `write_varint` emits for `n`.
///
/// Each byte carries seven payload bits, so this is the number of significant
/// bits divided by seven, rounded up. Zero still occupies one byte, and the
/// result is always in `1..=10`.
const fn varint_size(n: u64) -> usize {
    match 64 - n.leading_zeros() as usize {
        0 => 1,
        significant_bits => significant_bits.div_ceil(7),
    }
}
39
40/// Reads a variable-length integer, advancing `v`.
41pub(crate) fn read_varint(v: &mut &[u8]) -> Result<u64> {
42    let mut shift = 0;
43    let mut r = 0;
44    loop {
45        let (b, rest) = v.split_first().ok_or(DecodeError::EofVarInt)?;
46        *v = rest;
47        r |= (*b as u64 & 0x7f) << shift;
48        if *b & 0x80 == 0 {
49            break;
50        }
51        shift += 7;
52        if shift > 63 {
53            return Err(DecodeError::VarIntTooBig.into());
54        }
55    }
56    Ok(r)
57}
58
/// Maps a signed integer onto an unsigned one using protobuf's zigzag scheme
/// (0 → 0, -1 → 1, 1 → 2, -2 → 3, ...).
///
/// Small magnitudes of either sign become small unsigned values, which keeps
/// their varint encoding short.
fn zigzag(n: i64) -> u64 {
    // The arithmetic shift smears the sign bit: all ones if `n` is negative, else zero.
    let sign_mask = (n >> 63) as u64;
    ((n as u64) << 1) ^ sign_mask
}
66
/// Reverses the zigzag encoding (0 → 0, 1 → -1, 2 → 1, 3 → -2, ...).
///
/// The halving shift must be logical, i.e. performed on the unsigned value.
/// An arithmetic shift would copy a set top bit back into the result and
/// corrupt every value whose magnitude is at least 2^62, such as `i64::MAX`
/// and `i64::MIN`.
fn unzigzag(n: u64) -> i64 {
    // The low bit is the sign flag; negating it yields an all-ones or all-zeros mask.
    let sign_mask = -((n & 1) as i64);
    ((n >> 1) as i64) ^ sign_mask
}
72
/// The protobuf wire type, stored in the low three bits of every field key.
///
/// Values 0, 1, 2 and 5 are the standard protobuf wire types. Values 6 and 7
/// are mesh extensions for carrying out-of-band resources alongside the data.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum WireType {
    /// Variable-length integer.
    Varint = 0,
    /// Fixed 64-bit value, little-endian.
    Fixed64 = 1,
    /// Variable-length byte buffer: a varint length followed by that many bytes.
    Variable = 2,
    /// Fixed 32-bit value, little-endian.
    Fixed32 = 5,

    /// Mesh extension: just like `Variable`, but the length varint is preceded
    /// by a varint holding the number of resources the embedded message
    /// consumes.
    MeshMessage = 6,

    /// Mesh extension: carries no payload and consumes the next resource.
    Resource = 7,
}
94
95struct DecodeInner<'a, R> {
96    resources: &'a mut [Option<R>],
97}
98
99struct DecodeState<'a, R>(RefCell<DecodeInner<'a, R>>);
100
101impl<'a, R> DecodeState<'a, R> {
102    fn new(resources: &'a mut [Option<R>]) -> Self {
103        Self(RefCell::new(DecodeInner { resources }))
104    }
105
106    /// Takes resource `index`.
107    fn resource(&self, index: u32) -> Result<R> {
108        (|| {
109            self.0
110                .borrow_mut()
111                .resources
112                .get_mut(index as usize)?
113                .take()
114        })()
115        .ok_or_else(|| DecodeError::MissingResource.into())
116    }
117}
118
119struct EncodeState<'a, R> {
120    data: Buf<'a>,
121    message_sizes: core::slice::Iter<'a, MessageSize>,
122    resources: &'a mut Vec<R>,
123    field_number: u32,
124    in_sequence: bool,
125}
126
127impl<'a, R> EncodeState<'a, R> {
128    fn new(data: Buf<'a>, message_sizes: &'a [MessageSize], resources: &'a mut Vec<R>) -> Self {
129        Self {
130            data,
131            resources,
132            message_sizes: message_sizes.iter(),
133            field_number: 0,
134            in_sequence: false,
135        }
136    }
137}
138
139/// Type used to write field values.
140pub struct FieldWriter<'a, 'buf, R> {
141    state: &'a mut EncodeState<'buf, R>,
142}
143
144impl<'a, 'buf, R> FieldWriter<'a, 'buf, R> {
145    /// Writes the field key.
146    fn key(&mut self, ty: WireType) {
147        write_varint(
148            &mut self.state.data,
149            ((self.state.field_number << 3) | ty as u32).into(),
150        );
151    }
152
153    fn cached_variable<F>(mut self, f: F)
154    where
155        F: FnOnce(&mut Self),
156    {
157        if let Some(expected_len) = self.write_next_cached_message_header() {
158            f(&mut self);
159            assert_eq!(expected_len, self.state.data.len(), "wrong size");
160        }
161    }
162
163    /// Returns the expected size of the message, or None if the message is
164    /// empty and `skip_empty` is true, and so the message does not need to be
165    /// encoded.
166    fn write_next_cached_message_header(&mut self) -> Option<usize> {
167        let size = self
168            .state
169            .message_sizes
170            .next()
171            .expect("not enough messages in size calculation");
172        if size.num_resources > 0 {
173            self.key(WireType::MeshMessage);
174            write_varint(&mut self.state.data, size.num_resources.into());
175        } else if size.len > 0 || self.state.in_sequence {
176            self.key(WireType::Variable);
177        } else {
178            return None;
179        }
180        write_varint(&mut self.state.data, size.len as u64);
181        Some(self.state.data.len() + size.len)
182    }
183
184    /// Returns a sequence writer for writing the field multiple times.
185    ///
186    /// Panics if called while already writing a sequence, since this would
187    /// result in an invalid protobuf message.
188    pub fn sequence(self) -> SequenceWriter<'a, 'buf, R> {
189        assert!(!self.state.in_sequence);
190        SequenceWriter {
191            field_number: self.state.field_number,
192            state: self.state,
193        }
194    }
195
196    /// Returns whether this write is occurring within a sequence.
197    pub fn write_empty(&self) -> bool {
198        self.state.in_sequence
199    }
200
201    /// Calls `f` with a writer for a message.
202    pub fn message<F>(self, f: F)
203    where
204        F: FnOnce(MessageWriter<'_, 'buf, R>),
205    {
206        self.cached_variable(|this| {
207            f(MessageWriter { state: this.state });
208        });
209    }
210
211    /// Writes a resource.
212    pub fn resource(mut self, resource: R) {
213        self.key(WireType::Resource);
214        self.state.resources.push(resource);
215    }
216
217    /// Writes an unsigned variable-sized integer.
218    pub fn varint(mut self, n: u64) {
219        if n != 0 || self.state.in_sequence {
220            self.key(WireType::Varint);
221            write_varint(&mut self.state.data, n);
222        }
223    }
224
225    /// Writes a signed variable-sized integer.
226    pub fn svarint(mut self, n: i64) {
227        if n != 0 || self.state.in_sequence {
228            self.key(WireType::Varint);
229            write_varint(&mut self.state.data, zigzag(n));
230        }
231    }
232
233    /// Writes a fixed 64-bit integer.
234    pub fn fixed64(mut self, n: u64) {
235        if n != 0 || self.state.in_sequence {
236            self.key(WireType::Fixed64);
237            self.state.data.append(&n.to_le_bytes());
238        }
239    }
240
241    /// Writes a fixed 32-bit integer.
242    pub fn fixed32(mut self, n: u32) {
243        if n != 0 || self.state.in_sequence {
244            self.key(WireType::Fixed32);
245            self.state.data.append(&n.to_le_bytes());
246        }
247    }
248
249    /// Writes a byte slice.
250    pub fn bytes(mut self, b: &[u8]) {
251        if !b.is_empty() || self.state.in_sequence {
252            self.key(WireType::Variable);
253            write_varint(&mut self.state.data, b.len() as u64);
254            self.state.data.append(b);
255        }
256    }
257
258    /// Calls `f` with a writer for the packed field.
259    pub fn packed<F>(self, f: F)
260    where
261        F: FnOnce(PackedWriter<'_, '_>),
262    {
263        self.cached_variable(|this| {
264            f(PackedWriter {
265                data: &mut this.state.data,
266            })
267        })
268    }
269}
270
271/// A writer for writing a sequence of fields.
272pub struct SequenceWriter<'a, 'buf, R> {
273    state: &'a mut EncodeState<'buf, R>,
274    field_number: u32,
275}
276
277impl<'buf, R> SequenceWriter<'_, 'buf, R> {
278    /// Gets a field writer to write the next field in the sequence.
279    pub fn field(&mut self) -> FieldWriter<'_, 'buf, R> {
280        self.state.field_number = self.field_number;
281        self.state.in_sequence = true;
282        FieldWriter { state: self.state }
283    }
284}
285
286/// A writer for a message.
287pub struct MessageWriter<'a, 'buf, R> {
288    state: &'a mut EncodeState<'buf, R>,
289}
290
291impl<'buf, R> MessageWriter<'_, 'buf, R> {
292    /// Returns a field writer for field number `n`.
293    ///
294    /// It's legal to write fields in any order and to write fields that
295    /// duplicate previous fields. By convention, later fields overwrite
296    /// previous ones (or append, in the case of sequences).
297    pub fn field(&mut self, n: u32) -> FieldWriter<'_, 'buf, R> {
298        self.state.field_number = n;
299        self.state.in_sequence = false;
300        FieldWriter { state: self.state }
301    }
302
303    /// Writes a raw message from bytes.
304    pub fn bytes(&mut self, data: &[u8]) {
305        self.state.data.append(data);
306    }
307
308    /// Writes a raw message.
309    pub fn raw_message(&mut self, data: &[u8], resources: impl IntoIterator<Item = R>) {
310        self.state.data.append(data);
311        self.state.resources.extend(resources);
312    }
313}
314
/// Precomputed size of the top-level message or of one nested
/// length-delimited field.
#[derive(Copy, Clone, Default)]
struct MessageSize {
    /// Body length in bytes, excluding the field key and length prefix.
    len: usize,
    /// Resources consumed by the message, including those of nested messages.
    num_resources: u32,
}

/// Mutable state threaded through the sizing pass.
struct SizeState {
    /// One entry per length-delimited field in encode order; entry 0 is the
    /// top-level message.
    message_sizes: Vec<MessageSize>,
    /// Index into `message_sizes` of the message currently being sized.
    index: usize,
    /// Encoded size of the current field's key.
    tag_size: u8,
    /// Whether the current field is a sequence element.
    in_sequence: bool,
}

impl SizeState {
    /// Creates state holding only the empty top-level entry.
    fn new() -> Self {
        Self {
            index: 0,
            tag_size: 0,
            in_sequence: false,
            // Entry 0 accumulates the size of the top-level message.
            message_sizes: vec![MessageSize::default()],
        }
    }
}
338
339/// Type used to compute the size of field values.
340pub struct FieldSizer<'a> {
341    state: &'a mut SizeState,
342}
343
344struct PreviousSizeParams {
345    index: u32,
346    tag_size: u8,
347    in_sequence: bool,
348}
349
350impl<'a> FieldSizer<'a> {
351    fn add(&mut self, size: usize) {
352        // Add room for the field tag.
353        self.state.message_sizes[self.state.index].len += self.state.tag_size as usize + size;
354    }
355
356    /// Makes and returns a writer for a message.
357    fn cached_variable<F>(&mut self, f: F)
358    where
359        F: FnOnce(&mut Self),
360    {
361        // Cache the size for use when writing the message.
362        let prev = self.reserve_cached_message_size_entry();
363        f(self);
364        self.set_cached_message_size(prev);
365    }
366
367    fn reserve_cached_message_size_entry(&mut self) -> PreviousSizeParams {
368        let index = self.state.message_sizes.len();
369        self.state.message_sizes.push(MessageSize::default());
370        PreviousSizeParams {
371            index: core::mem::replace(&mut self.state.index, index) as u32,
372            tag_size: self.state.tag_size,
373            in_sequence: self.state.in_sequence,
374        }
375    }
376
377    fn set_cached_message_size(&mut self, prev: PreviousSizeParams) {
378        let size = self.state.message_sizes[self.state.index];
379        let index = core::mem::replace(&mut self.state.index, prev.index as usize);
380        let parent_size = &mut self.state.message_sizes[self.state.index];
381        let mut len = varint_size(size.len as u64) + size.len;
382        if size.num_resources > 0 {
383            // This will be a MeshMessage field.
384            len += varint_size(size.num_resources as u64);
385            parent_size.num_resources += size.num_resources;
386        } else if !prev.in_sequence && size.len == 0 {
387            // This message is empty, so skip it and any nested messages.
388            self.state.message_sizes[index] = Default::default();
389            self.state.message_sizes.truncate(index + 1);
390            return;
391        }
392        parent_size.len += prev.tag_size as usize + len;
393    }
394
395    /// Returns a sequence sizer for sizing the field multiple times.
396    ///
397    /// Panics if called while already sizing a sequence, since this would
398    /// result in an invalid protobuf message.
399    pub fn sequence(self) -> SequenceSizer<'a> {
400        SequenceSizer {
401            tag_size: self.state.tag_size,
402            state: self.state,
403        }
404    }
405
406    /// If true, encoders must write their fields even if they are empty.
407    pub fn write_empty(&self) -> bool {
408        self.state.in_sequence
409    }
410
411    /// Computes the size for a message. Calls `f` with a [`MessageSizer`] to
412    /// calculate the size of each field.
413    pub fn message<F>(mut self, f: F)
414    where
415        F: FnOnce(MessageSizer<'_>),
416    {
417        self.cached_variable(|this| {
418            f(MessageSizer::new(this.state));
419        })
420    }
421
422    /// Computes the size for a resource.
423    pub fn resource(mut self) {
424        self.state.message_sizes[self.state.index].num_resources += 1;
425        self.add(0);
426    }
427
428    /// Computes the size for an unsigned variable-sized integer.
429    pub fn varint(mut self, n: u64) {
430        if n != 0 || self.state.in_sequence {
431            self.add(varint_size(n));
432        }
433    }
434
435    /// Computes the size for a signed variable-sized integer.
436    pub fn svarint(mut self, n: i64) {
437        if n != 0 || self.state.in_sequence {
438            self.add(varint_size(zigzag(n)));
439        }
440    }
441
442    /// Computes the size for a fixed 64-bit integer.
443    pub fn fixed64(mut self, n: u64) {
444        if n != 0 || self.state.in_sequence {
445            self.add(8);
446        }
447    }
448
449    /// Computes the size for a fixed 32-bit integer.
450    pub fn fixed32(mut self, n: u32) {
451        if n != 0 || self.state.in_sequence {
452            self.add(4);
453        }
454    }
455
456    /// Computes the size for a byte slice.
457    pub fn bytes(mut self, len: usize) {
458        if len != 0 || self.state.in_sequence {
459            self.add(varint_size(len as u64) + len);
460        }
461    }
462
463    /// Computes the size of a packed value. Calls `f` with a [`PackedSizer`] to
464    /// sum the size of each element.
465    pub fn packed<F>(mut self, f: F)
466    where
467        F: FnOnce(PackedSizer<'_>),
468    {
469        self.cached_variable(|this| {
470            f(PackedSizer {
471                size: &mut this.state.message_sizes[this.state.index].len,
472            });
473        })
474    }
475}
476
477/// A sizer for computing the size of a sequence of fields.
478pub struct SequenceSizer<'a> {
479    state: &'a mut SizeState,
480    tag_size: u8,
481}
482
483impl SequenceSizer<'_> {
484    /// Gets a field sizer for the next field in the sequence.
485    pub fn field(&mut self) -> FieldSizer<'_> {
486        self.state.tag_size = self.tag_size;
487        self.state.in_sequence = true;
488        FieldSizer { state: self.state }
489    }
490}
491
492/// A type to compute the size of a message.
493pub struct MessageSizer<'a> {
494    state: &'a mut SizeState,
495}
496
497impl<'a> MessageSizer<'a> {
498    fn new(state: &'a mut SizeState) -> Self {
499        Self { state }
500    }
501
502    /// Returns a field sizer for field number `n`.
503    pub fn field(&mut self, n: u32) -> FieldSizer<'_> {
504        self.state.tag_size = varint_size((n as u64) << 3) as u8;
505        self.state.in_sequence = false;
506        FieldSizer { state: self.state }
507    }
508
509    /// Sizes the message as `n` bytes.
510    pub fn bytes(&mut self, n: usize) {
511        self.state.message_sizes[self.state.index] = MessageSize {
512            len: n,
513            ..Default::default()
514        };
515    }
516
517    /// Sizes the message as `n` bytes plus `num_resources` resources.
518    pub fn raw_message(&mut self, len: usize, num_resources: u32) {
519        self.state.message_sizes[self.state.index] = MessageSize { len, num_resources }
520    }
521}
522
/// A single field value as parsed off the wire, borrowing from the input.
#[derive(Debug, Clone)]
enum Value<'a> {
    /// Wire type 0.
    Varint(u64),
    /// Wire type 1, already decoded from little-endian bytes.
    Fixed64(u64),
    /// Wire type 2: the length-delimited payload.
    Variable(&'a [u8]),
    /// Wire type 5, already decoded from little-endian bytes.
    Fixed32(u32),
    /// Wire type 7: index of the resource this field consumes.
    Resource(u32),
    /// Wire type 6: a length-delimited message plus the range of resource
    /// indices assigned to it.
    MeshMessage {
        data: &'a [u8],
        resources: Range<u32>,
    },
}
536
537/// A reader for a payload field.
538pub struct FieldReader<'a, 'b, R> {
539    field: Value<'a>,
540    state: &'b DecodeState<'b, R>,
541}
542
543impl<'a, 'b, R> FieldReader<'a, 'b, R> {
544    /// Gets the wire type for the field.
545    pub fn wire_type(&self) -> WireType {
546        match &self.field {
547            Value::Varint(_) => WireType::Varint,
548            Value::Fixed64(_) => WireType::Fixed64,
549            Value::Variable(_) => WireType::Variable,
550            Value::Fixed32(_) => WireType::Fixed32,
551            Value::MeshMessage { .. } => WireType::MeshMessage,
552            Value::Resource { .. } => WireType::Resource,
553        }
554    }
555
556    /// Makes and returns an message reader.
557    pub fn message(self) -> Result<MessageReader<'a, 'b, R>> {
558        if let Value::Variable(data) = self.field {
559            Ok(MessageReader {
560                data,
561                state: self.state,
562                resources: 0..0,
563            })
564        } else if let Value::MeshMessage { data, resources } = self.field {
565            Ok(MessageReader {
566                data,
567                state: self.state,
568                resources,
569            })
570        } else {
571            Err(DecodeError::ExpectedMessage.into())
572        }
573    }
574
575    /// Reads a resource.
576    pub fn resource(self) -> Result<R> {
577        if let Value::Resource(index) = self.field {
578            self.state.resource(index)
579        } else {
580            Err(DecodeError::ExpectedResource.into())
581        }
582    }
583
584    /// Reads an unsigned variable-sized integer.
585    pub fn varint(self) -> Result<u64> {
586        if let Value::Varint(n) = self.field {
587            Ok(n)
588        } else {
589            Err(DecodeError::ExpectedVarInt.into())
590        }
591    }
592
593    /// Reads a signed variable-sized integer.
594    pub fn svarint(self) -> Result<i64> {
595        Ok(unzigzag(self.varint()?))
596    }
597
598    /// Reads a fixed 64-bit integer.
599    pub fn fixed64(self) -> Result<u64> {
600        if let Value::Fixed64(n) = self.field {
601            Ok(n)
602        } else {
603            Err(DecodeError::ExpectedFixed64.into())
604        }
605    }
606
607    /// Reads a fixed 32-bit integer.
608    pub fn fixed32(self) -> Result<u32> {
609        if let Value::Fixed32(n) = self.field {
610            Ok(n)
611        } else {
612            Err(DecodeError::ExpectedFixed32.into())
613        }
614    }
615
616    /// Reads a byte slice.
617    pub fn bytes(self) -> Result<&'a [u8]> {
618        if let Value::Variable(data) = self.field {
619            Ok(data)
620        } else {
621            Err(DecodeError::ExpectedByteArray.into())
622        }
623    }
624
625    /// Gets a reader for a packed field.
626    pub fn packed(self) -> Result<PackedReader<'a>> {
627        Ok(PackedReader {
628            data: self.bytes()?,
629        })
630    }
631}
632
633/// Reader for an message.
634///
635/// Implements [`Iterator`] to return (field number, [`FieldReader`]) pairs.
636/// Users must be prepared to handle fields in any order, allowing unknown and
637/// duplicate fields.
638pub struct MessageReader<'a, 'b, R> {
639    data: &'a [u8],
640    resources: Range<u32>,
641    state: &'b DecodeState<'b, R>,
642}
643
644impl<'a, 'b, R> IntoIterator for MessageReader<'a, 'b, R> {
645    type Item = Result<(u32, FieldReader<'a, 'b, R>)>;
646    type IntoIter = FieldIterator<'a, 'b, R>;
647
648    fn into_iter(self) -> Self::IntoIter {
649        FieldIterator(self)
650    }
651}
652
653impl<'a, 'b, R> MessageReader<'a, 'b, R> {
654    fn new(data: &'a [u8], state: &'b DecodeState<'b, R>) -> Self {
655        let num_resources = state.0.borrow().resources.len() as u32;
656        Self {
657            data,
658            state,
659            resources: 0..num_resources,
660        }
661    }
662
663    /// Gets the message data as a byte slice.
664    pub fn bytes(&self) -> &'a [u8] {
665        self.data
666    }
667
668    /// Returns an iterator to consume the resources for this message.
669    pub fn take_resources(&mut self) -> impl ExactSizeIterator<Item = Result<R>> + use<'b, R> {
670        let state = self.state;
671        self.resources.clone().map(move |i| {
672            state
673                .0
674                .borrow_mut()
675                .resources
676                .get_mut(i as usize)
677                .and_then(|x| x.take())
678                .ok_or_else(|| DecodeError::MissingResource.into())
679        })
680    }
681
682    fn parse_field(&mut self) -> Result<(u32, FieldReader<'a, 'b, R>)> {
683        let key = read_varint(&mut self.data)?;
684        let wire_type = (key & 7) as u32;
685        let field_number = (key >> 3) as u32;
686        let field = match wire_type {
687            0 => Value::Varint(read_varint(&mut self.data)?),
688            1 => {
689                if self.data.len() < 8 {
690                    return Err(DecodeError::EofFixed64.into());
691                }
692                let (n, rest) = self.data.split_at(8);
693                self.data = rest;
694                Value::Fixed64(u64::from_le_bytes(n.try_into().unwrap()))
695            }
696            2 => {
697                let len = read_varint(&mut self.data)?;
698                if (self.data.len() as u64) < len {
699                    return Err(DecodeError::EofByteArray.into());
700                }
701                let (data, rest) = self.data.split_at(len as usize);
702                self.data = rest;
703                Value::Variable(data)
704            }
705            5 => {
706                if self.data.len() < 4 {
707                    return Err(DecodeError::EofFixed32.into());
708                }
709                let (n, rest) = self.data.split_at(4);
710                self.data = rest;
711                Value::Fixed32(u32::from_le_bytes(n.try_into().unwrap()))
712            }
713            6 => {
714                let num_resources = read_varint(&mut self.data)? as u32;
715                let len = read_varint(&mut self.data)?;
716
717                if self.resources.len() < num_resources as usize {
718                    return Err(DecodeError::InvalidResourceRange.into());
719                }
720                if (self.data.len() as u64) < len {
721                    return Err(DecodeError::EofByteArray.into());
722                }
723
724                let (data, rest) = self.data.split_at(len as usize);
725                self.data = rest;
726
727                let resources = self.resources.start..self.resources.start + num_resources;
728                self.resources = resources.end..self.resources.end;
729
730                Value::MeshMessage { data, resources }
731            }
732            7 => {
733                let resource = self.resources.next().ok_or(DecodeError::MissingResource)?;
734                Value::Resource(resource)
735            }
736            n => return Err(DecodeError::UnknownWireType(n).into()),
737        };
738        Ok((
739            field_number,
740            FieldReader {
741                field,
742                state: self.state,
743            },
744        ))
745    }
746}
747
748/// An iterator over message fields.
749///
750/// Returned by [`MessageReader::into_iter()`].
751pub struct FieldIterator<'a, 'b, R>(MessageReader<'a, 'b, R>);
752
753impl<'a, 'b, R> Iterator for FieldIterator<'a, 'b, R> {
754    type Item = Result<(u32, FieldReader<'a, 'b, R>)>;
755
756    fn next(&mut self) -> Option<Self::Item> {
757        if self.0.data.is_empty() {
758            return None;
759        }
760        Some(self.0.parse_field())
761    }
762}
763
764/// A writer for a packed field.
765pub struct PackedWriter<'a, 'buf> {
766    data: &'a mut Buf<'buf>,
767}
768
769impl PackedWriter<'_, '_> {
770    /// Appends `bytes`.
771    pub fn bytes(&mut self, bytes: &[u8]) {
772        self.data.append(bytes);
773    }
774
775    /// Appends varint `v`.
776    pub fn varint(&mut self, v: u64) {
777        write_varint(self.data, v);
778    }
779
780    /// Appends signed (zigzag-encoded) varint `v`.
781    pub fn svarint(&mut self, v: i64) {
782        write_varint(self.data, zigzag(v));
783    }
784
785    /// Appends fixed 64-bit value `v`.
786    pub fn fixed64(&mut self, v: u64) {
787        self.bytes(&v.to_le_bytes());
788    }
789
790    /// Appends fixed 32-bit value `v`.
791    pub fn fixed32(&mut self, v: u32) {
792        self.bytes(&v.to_le_bytes());
793    }
794}
795
796/// A type to help compute the size of a packed field.
797pub struct PackedSizer<'a> {
798    size: &'a mut usize,
799}
800
801impl PackedSizer<'_> {
802    /// Adds the size of `len` bytes.
803    pub fn bytes(&mut self, len: usize) {
804        *self.size += len;
805    }
806
807    /// Adds the size of a varint value `v`.
808    pub fn varint(&mut self, v: u64) {
809        *self.size += varint_size(v);
810    }
811
812    /// Adds the size of a signed (zigzag-encoded) varint value `v`.
813    pub fn svarint(&mut self, v: i64) {
814        *self.size += varint_size(zigzag(v));
815    }
816
817    /// Adds the size of a fixed 64-bit value.
818    pub fn fixed64(&mut self) {
819        *self.size += 8;
820    }
821
822    /// Adds the size of a fixed 32-bit value.
823    pub fn fixed32(&mut self) {
824        *self.size += 4;
825    }
826}
827
828/// Reader for packed fields.
829pub struct PackedReader<'a> {
830    data: &'a [u8],
831}
832
833impl<'a> PackedReader<'a> {
834    /// Reads the remaining bytes.
835    pub fn bytes(&mut self) -> &'a [u8] {
836        core::mem::take(&mut self.data)
837    }
838
839    /// Reads a varint.
840    ///
841    /// Returns `Ok(None)` if there are no more values.
842    pub fn varint(&mut self) -> Result<Option<u64>> {
843        if self.data.is_empty() {
844            Ok(None)
845        } else {
846            read_varint(&mut self.data).map(Some)
847        }
848    }
849
850    /// Reads a signed (zigzag-encoded) varint.
851    ///
852    /// Returns `Ok(None)` if there are no more values.
853    pub fn svarint(&mut self) -> Result<Option<i64>> {
854        if self.data.is_empty() {
855            Ok(None)
856        } else {
857            read_varint(&mut self.data).map(|n| Some(unzigzag(n)))
858        }
859    }
860
861    /// Reads a fixed 64-bit value.
862    ///
863    /// Returns `Ok(None)` if there are no more values.
864    pub fn fixed64(&mut self) -> Result<Option<u64>> {
865        if self.data.is_empty() {
866            Ok(None)
867        } else if self.data.len() < 8 {
868            Err(DecodeError::EofFixed64.into())
869        } else {
870            let (b, data) = self.data.split_at(8);
871            self.data = data;
872            Ok(Some(u64::from_le_bytes(b.try_into().unwrap())))
873        }
874    }
875
876    /// Reads a fixed 32-bit value.
877    ///
878    /// Returns `Ok(None)` if there are no more values.
879    pub fn fixed32(&mut self) -> Result<Option<u32>> {
880        if self.data.is_empty() {
881            Ok(None)
882        } else if self.data.len() < 4 {
883            Err(DecodeError::EofFixed32.into())
884        } else {
885            let (b, data) = self.data.split_at(4);
886            self.data = data;
887            Ok(Some(u32::from_le_bytes(b.try_into().unwrap())))
888        }
889    }
890}
891
892/// An encoder for a single message of type `T`, using the messaging encoding
893/// `E`.
894pub struct Encoder<T, E, R> {
895    message: T,
896    message_sizes: Vec<MessageSize>,
897    _phantom: PhantomData<(fn() -> R, E)>,
898}
899
900impl<R, T: DefaultEncoding> Encoder<T, T::Encoding, R>
901where
902    T::Encoding: MessageEncode<T, R>,
903{
904    /// Creates an encoder for `message`.F
905    pub fn new(message: T) -> Self {
906        Encoder::with_encoding(message)
907    }
908}
909
910impl<T, R, E: MessageEncode<T, R>> Encoder<T, E, R> {
911    /// Creates an encoder for `message` with a specific encoder.
912    pub fn with_encoding(mut message: T) -> Self {
913        let mut state = SizeState::new();
914        E::compute_message_size(&mut message, MessageSizer::new(&mut state));
915        Self {
916            message,
917            message_sizes: state.message_sizes,
918            _phantom: PhantomData,
919        }
920    }
921
922    /// Returns the length of the message in bytes.
923    pub fn len(&self) -> usize {
924        self.message_sizes[0].len
925    }
926
927    /// Returns the number of resources in the message.
928    pub fn resource_count(&self) -> usize {
929        self.message_sizes[0].num_resources as usize
930    }
931
932    /// Encodes the message into `buffer`.
933    pub fn encode_into(self, buffer: &mut dyn Buffer, resources: &mut Vec<R>) {
934        buffer::write_with(buffer, |buf| {
935            let capacity = buf.remaining();
936            let init_resources = resources.len();
937            let mut state = EncodeState::new(buf, &self.message_sizes, resources);
938            let size = state.message_sizes.next().unwrap();
939            E::write_message(self.message, MessageWriter { state: &mut state });
940            assert_eq!(capacity - state.data.remaining(), size.len);
941            assert_eq!(
942                state.resources.len() - init_resources,
943                size.num_resources as usize
944            );
945            assert!(state.message_sizes.next().is_none());
946        })
947    }
948
949    /// Encodes the message.
950    pub fn encode(self) -> (Vec<u8>, Vec<R>) {
951        let mut data = Vec::with_capacity(self.len());
952        let mut resources = Vec::with_capacity(self.resource_count());
953        self.encode_into(&mut data, &mut resources);
954        (data, resources)
955    }
956}
957
958/// Decodes a protobuf message into `message` using encoding `T`.
959///
960/// If `message` already exists, then the fields are merged according to
961/// protobuf rules.
962pub fn decode_with<'a, E: MessageDecode<'a, T, R>, T, R>(
963    message: &mut InplaceOption<'_, T>,
964    data: &'a [u8],
965    resources: &mut [Option<R>],
966) -> Result<()> {
967    let state = DecodeState::new(resources);
968    let reader = MessageReader::new(data, &state);
969    E::read_message(message, reader)?;
970    Ok(())
971}
972
#[cfg(test)]
mod tests {
    extern crate std;

    use super::*;
    use crate::buffer;
    use std::eprintln;

    // Checks zigzag/unzigzag against known pairs, in both directions.
    #[test]
    fn test_zigzag() {
        let cases: &[(i64, u64)] = &[
            (0, 0),
            (-1, 1),
            (1, 2),
            (-2, 3),
            (2147483647, 4294967294),
            (-2147483648, 4294967295),
        ];
        for (a, b) in cases.iter().copied() {
            assert_eq!(zigzag(a), b);
            assert_eq!(a, unzigzag(b));
        }
    }

    // For each value, checks the predicted size, the written bytes, and that
    // reading the bytes back yields the value and consumes all of them.
    #[test]
    fn test_varint() {
        let cases: &[(u64, &[u8])] = &[
            (0, &[0]),
            (1, &[1]),
            (0x80, &[0x80, 1]),
            (
                -1i64 as u64,
                &[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x1],
            ),
        ];
        for (a, mut b) in cases.iter().copied() {
            eprintln!("{:#x}, {:#x?}", a, b);
            assert_eq!(varint_size(a), b.len());
            let mut v = Vec::with_capacity(10);
            buffer::write_with(&mut v, |mut buf| write_varint(&mut buf, a));
            assert_eq!(&v, b);
            assert_eq!(a, read_varint(&mut b).unwrap());
            assert!(b.is_empty());
        }
    }

    // Sizes, encodes and decodes a message with two top-level resource fields
    // and a nested message (field 3) holding three more resources.
    #[test]
    fn test_resource() {
        let mut state = SizeState::new();
        let mut sizer = MessageSizer::new(&mut state);
        sizer.field(1).resource();
        sizer.field(2).resource();
        sizer.field(3).message(|mut sizer| {
            sizer.field(1).resource();
            sizer.field(1).resource();
            sizer.field(1).resource();
        });
        // Remove the top-level entry: the encoding below starts from a
        // MessageWriter directly, so it consumes only the nested entries.
        let size = state.message_sizes.remove(0);
        assert_eq!(size.num_resources, 5);

        // Encoding must issue the same fields in the same order as sizing.
        let mut data = Vec::with_capacity(size.len);
        let mut resources = Vec::with_capacity(size.num_resources as usize);
        buffer::write_with(&mut data, |buf| {
            let mut state = EncodeState::new(buf, &state.message_sizes, &mut resources);
            let mut writer = MessageWriter { state: &mut state };
            writer.field(1).resource(());
            writer.field(2).resource(());
            writer.field(3).message(|mut writer| {
                writer.field(1).resource(());
                writer.field(1).resource(());
                writer.field(1).resource(());
            });
        });

        // Decode with all five resources available to the top-level message.
        let mut resources: Vec<_> = resources.into_iter().map(Some).collect();
        let state = DecodeState(RefCell::new(DecodeInner {
            resources: &mut resources,
        }));
        let reader = MessageReader {
            data: &data,
            state: &state,
            resources: 0..5,
        };

        let mut it = reader.into_iter();
        let (n, r) = it.next().unwrap().unwrap();
        assert_eq!(n, 1);
        r.resource().unwrap();
        let (n, r) = it.next().unwrap().unwrap();
        assert_eq!(n, 2);
        r.resource().unwrap();
        let (n, r) = it.next().unwrap().unwrap();
        assert_eq!(n, 3);
        let message = r.message().unwrap();
        assert!(it.next().is_none());

        // The nested message must yield its three resources.
        let mut it = message.into_iter();
        let (n, r) = it.next().unwrap().unwrap();
        assert_eq!(n, 1);
        r.resource().unwrap();
        let (n, r) = it.next().unwrap().unwrap();
        assert_eq!(n, 1);
        r.resource().unwrap();
        let (n, r) = it.next().unwrap().unwrap();
        assert_eq!(n, 1);
        r.resource().unwrap();
        assert!(it.next().is_none());
    }
}