mesh_protobuf/
message.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Type-erased protobuf message support.
5
6use crate::DefaultEncoding;
7use crate::DescribedProtobuf;
8use crate::Error;
9use crate::MessageDecode;
10use crate::MessageEncode;
11use crate::Protobuf;
12use crate::decode;
13use crate::encode;
14use crate::encoding::MessageEncoding;
15use crate::inplace::InplaceOption;
16use crate::protobuf::MessageReader;
17use crate::protobuf::MessageSizer;
18use crate::protobuf::MessageWriter;
19use crate::protofile::DescribeField;
20use crate::protofile::FieldType;
21use crate::protofile::MessageDescription;
22use crate::table::DescribeTable;
23use alloc::string::String;
24use alloc::string::ToString;
25use alloc::vec::Vec;
26use thiserror::Error;
27
28/// An opaque protobuf message.
29//
30// TODO: delay encoding like in mesh::Message. This requires splitting some of
31// the encoding traits up to remove the resource type.
32#[derive(Debug)]
33pub struct ProtobufMessage(Vec<u8>);
34
35impl ProtobufMessage {
36    /// Encodes `data` as a protobuf message.
37    pub fn new(data: impl Protobuf) -> Self {
38        Self(encode(data))
39    }
40
41    /// Decodes the protobuf message into `T`.
42    pub fn parse<T: Protobuf>(&self) -> Result<T, Error> {
43        decode(&self.0)
44    }
45}
46
47impl DefaultEncoding for ProtobufMessage {
48    type Encoding = MessageEncoding<ProtobufMessageEncoding>;
49}
50
51impl DescribeField<ProtobufMessage> for MessageEncoding<ProtobufMessageEncoding> {
52    const FIELD_TYPE: FieldType<'static> = FieldType::builtin("bytes");
53}
54
55/// Encoder for [`ProtobufMessage`].
56#[derive(Debug)]
57pub struct ProtobufMessageEncoding;
58
59impl<R> MessageEncode<ProtobufMessage, R> for ProtobufMessageEncoding {
60    fn write_message(item: ProtobufMessage, mut writer: MessageWriter<'_, '_, R>) {
61        writer.bytes(&item.0);
62    }
63
64    fn compute_message_size(item: &mut ProtobufMessage, mut sizer: MessageSizer<'_>) {
65        sizer.bytes(item.0.len());
66    }
67}
68
69impl<R> MessageDecode<'_, ProtobufMessage, R> for ProtobufMessageEncoding {
70    fn read_message(
71        item: &mut InplaceOption<'_, ProtobufMessage>,
72        reader: MessageReader<'_, '_, R>,
73    ) -> crate::Result<()> {
74        item.get_or_insert_with(|| ProtobufMessage(Vec::new()))
75            .0
76            .extend(reader.bytes());
77        Ok(())
78    }
79}
80
81/// A protobuf message and the associated protobuf type URL.
82///
83/// This has the encoding of `google.protobuf.Any`.
84#[derive(Protobuf)]
85pub struct ProtobufAny {
86    #[mesh(1)]
87    type_url: String, // FUTURE: avoid allocation here
88    #[mesh(2)]
89    value: ProtobufMessage,
90}
91
92impl core::fmt::Debug for ProtobufAny {
93    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
94        if f.alternate() {
95            // Full debug output like derive would produce
96            f.debug_struct("ProtobufAny")
97                .field("type_url", &self.type_url)
98                .field("value", &self.value)
99                .finish()
100        } else {
101            // Compact output with just type_url and data length
102            f.debug_struct("ProtobufAny")
103                .field("type_url", &self.type_url)
104                .field("value_len", &self.value.0.len())
105                .finish()
106        }
107    }
108}
109
110#[derive(Debug, Error)]
111#[error("protobuf type mismatch, expected {expected}, got {actual}")]
112struct TypeMismatch {
113    expected: String,
114    actual: String,
115}
116
117impl DescribeTable for ProtobufAny {
118    const DESCRIPTION: MessageDescription<'static> = MessageDescription::External {
119        name: "google.protobuf.Any",
120        import_path: "google/protobuf/any.proto",
121    };
122}
123
124impl ProtobufAny {
125    /// Encodes `data` as a protobuf message.
126    pub fn new<T: DescribedProtobuf>(data: T) -> Self {
127        Self {
128            type_url: T::TYPE_URL.to_string(),
129            value: ProtobufMessage::new(data),
130        }
131    }
132
133    /// Decodes the protobuf message into `T`.
134    ///
135    /// Fails if this message is an encoding of a different type.
136    pub fn parse<T: DescribedProtobuf>(&self) -> Result<T, Error> {
137        if &T::TYPE_URL != self.type_url.as_str() {
138            return Err(Error::new(TypeMismatch {
139                expected: T::TYPE_URL.to_string(),
140                actual: self.type_url.clone(),
141            }));
142        }
143        self.value.parse()
144    }
145
146    /// Returns `true` if this message is an encoding of `T`.
147    pub fn is_message<T: DescribedProtobuf>(&self) -> bool {
148        &T::TYPE_URL == self.type_url.as_str()
149    }
150}
151
#[cfg(test)]
mod tests {
    extern crate std;

    use crate::Protobuf;
    use crate::encode;
    use crate::message::ProtobufAny;
    use crate::message::ProtobufMessage;
    use crate::tests::as_expect_str;
    use expect_test::expect;
    use std::println;

    /// A `ProtobufMessage` round trips its payload and encodes to the same
    /// bytes as the payload itself.
    #[test]
    fn test_message() {
        let payload = (5u32,);

        // Round trips.
        let round_tripped: (u32,) = ProtobufMessage::new(payload).parse().unwrap();
        assert_eq!(round_tripped, payload);

        let expected = expect!([r#"
            1: varint 5
            raw: 0805"#]);
        let wrapped_bytes = encode(ProtobufMessage::new(payload));
        expected.assert_eq(&as_expect_str(&wrapped_bytes));

        // Is transparent.
        let direct_bytes = encode(payload);
        assert_eq!(wrapped_bytes, direct_bytes);
    }

    /// A `ProtobufAny` records the type URL of its source type and refuses to
    /// parse as any other type.
    #[test]
    fn test_any() {
        #[derive(Protobuf, PartialEq, Eq, Copy, Clone, Debug)]
        #[mesh(package = "test")]
        struct Message {
            #[mesh(1)]
            x: u32,
        }

        #[derive(Protobuf, Debug)]
        #[mesh(package = "test")]
        struct Other {
            #[mesh(1)]
            x: u32,
        }

        let original = Message { x: 5 };
        let any = ProtobufAny::new(original);

        assert_eq!(any.type_url, "type.googleapis.com/test.Message");
        assert!(any.is_message::<Message>());
        assert!(!any.is_message::<Other>());

        let decoded = any.parse::<Message>().unwrap();
        assert_eq!(decoded, original);

        // Parsing as the wrong type must fail; print the error for inspection.
        let mismatch = any.parse::<Other>().unwrap_err();
        println!("{:?}", mismatch);
    }
}