mesh_node/
message.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Implements the `Message` type.
5
// UNSAFETY: Needed to define, implement, and call the unsafe `extract` function,
// and to build and call the type-erased vtable behind stack messages.
7#![expect(unsafe_code)]
8
9use crate::resource::Resource;
10use crate::resource::SerializedMessage;
11use mesh_protobuf;
12use mesh_protobuf::DefaultEncoding;
13use mesh_protobuf::MessageDecode;
14use mesh_protobuf::MessageEncode;
15use mesh_protobuf::encoding::SerializedMessageEncoder;
16use mesh_protobuf::inplace;
17use mesh_protobuf::inplace_none;
18use mesh_protobuf::protobuf::MessageSizer;
19use mesh_protobuf::protobuf::MessageWriter;
20use mesh_protobuf::protobuf::decode_with;
21use std::any::Any;
22use std::any::TypeId;
23use std::borrow::Cow;
24use std::fmt;
25use std::fmt::Debug;
26use std::marker::PhantomData;
27use std::mem::MaybeUninit;
28
/// A message on a port.
///
/// The message has a static lifetime and is `Send`, so it is appropriate for
/// storing and using across threads.
///
/// Internally the message is either still in its original typed form (to be
/// serialized lazily when needed) or already a [`SerializedMessage`].
/// `OwnedMessage::default()` holds a default [`SerializedMessage`].
///
/// See [`Message`] for a version that can reference data with non-static
/// lifetime.
#[derive(Default)]
pub struct OwnedMessage(OwnedMessageInner);

/// The representations an [`OwnedMessage`] can hold.
enum OwnedMessageInner {
    /// A boxed typed message that has not been serialized yet.
    Unserialized(Box<dyn DynSerializeMessage>),
    /// A message that is already in serialized form.
    Serialized(SerializedMessage),
}
40    Unserialized(Box<dyn DynSerializeMessage>),
41    Serialized(SerializedMessage),
42}
43
44impl Debug for OwnedMessage {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        f.pad("OwnedMessage")
47    }
48}
49
50impl Default for OwnedMessageInner {
51    fn default() -> Self {
52        Self::Serialized(Default::default())
53    }
54}
55
56impl OwnedMessage {
57    /// Serializes the message and returns it.
58    pub fn serialize(self) -> SerializedMessage {
59        match self.0 {
60            OwnedMessageInner::Unserialized(_) => SerializedMessage::from_message(self),
61            OwnedMessageInner::Serialized(message) => message,
62        }
63    }
64}
65
66/// Trait for types that can be constructed as a [`Message`].
67///
68/// This does not include scalar types such as `u32`, which are encoded as
69/// non-message types.
70pub trait MeshPayload: DefaultEncoding<Encoding = <Self as MeshPayload>::Encoding> + Sized {
71    type Encoding: MessageEncode<Self, Resource>
72        + for<'a> MessageDecode<'a, Self, Resource>
73        + mesh_protobuf::FieldEncode<Self, Resource>
74        + for<'a> mesh_protobuf::FieldDecode<'a, Self, Resource>
75        + Send
76        + Sync;
77}
78
79impl<T> MeshPayload for T
80where
81    T: DefaultEncoding + Any + Send + 'static,
82    T::Encoding: MessageEncode<T, Resource>
83        + for<'a> MessageDecode<'a, T, Resource>
84        + mesh_protobuf::FieldEncode<T, Resource>
85        + for<'a> mesh_protobuf::FieldDecode<'a, T, Resource>
86        + Send
87        + Sync,
88{
89    type Encoding = T::Encoding;
90}
91
92/// Trait for types that can be a field in a mesh message, including both scalar
93/// types and types that implement [`MeshPayload`].
94pub trait MeshField: DefaultEncoding<Encoding = <Self as MeshField>::Encoding> + Sized {
95    type Encoding: mesh_protobuf::FieldEncode<Self, Resource>
96        + for<'a> mesh_protobuf::FieldDecode<'a, Self, Resource>
97        + Send
98        + Sync;
99}
100
101impl<T> MeshField for T
102where
103    T: DefaultEncoding,
104    T::Encoding: mesh_protobuf::FieldEncode<T, Resource>
105        + for<'a> mesh_protobuf::FieldDecode<'a, T, Resource>
106        + Send
107        + Sync,
108{
109    type Encoding = T::Encoding;
110}
111
112/// Trait implemented by concrete messages that can be extracted or serialized
113/// into [`SerializedMessage`].
114pub trait SerializeMessage: 'static {
115    /// The underlying concrete message type.
116    type Concrete: Any;
117
118    /// Computes the message size, as in [`MessageEncode::compute_message_size`].
119    fn compute_message_size(&mut self, sizer: MessageSizer<'_>);
120
121    /// Writes the message, as in [`MessageEncode::write_message`].
122    fn write_message(self, writer: MessageWriter<'_, '_, Resource>);
123
124    /// Extract the concrete message.
125    fn extract(self) -> Self::Concrete;
126}
127
128/// # Safety
129///
130/// The implementor must ensure that `extract_or_serialize` initializes the
131/// pointer if it returns `Ok(())`.
132unsafe trait DynSerializeMessage: Send {
133    fn compute_message_size(&mut self, sizer: MessageSizer<'_>);
134    fn write_message(self: Box<Self>, writer: MessageWriter<'_, '_, Resource>);
135
136    /// # Safety
137    ///
138    /// The caller must ensure that `ptr` points to storage whose type matches
139    /// `type_id`.
140    unsafe fn extract(
141        self: Box<Self>,
142        type_id: TypeId,
143        ptr: *mut (),
144    ) -> Result<(), Box<dyn DynSerializeMessage>>;
145}
146
147// SAFETY: extract_or_serialize satisfies implementation requirements.
148unsafe impl<T: SerializeMessage + Send> DynSerializeMessage for T {
149    fn compute_message_size(&mut self, sizer: MessageSizer<'_>) {
150        self.compute_message_size(sizer)
151    }
152
153    fn write_message(self: Box<Self>, writer: MessageWriter<'_, '_, Resource>) {
154        (*self).write_message(writer)
155    }
156
157    unsafe fn extract(
158        self: Box<Self>,
159        type_id: TypeId,
160        ptr: *mut (),
161    ) -> Result<(), Box<dyn DynSerializeMessage>> {
162        if type_id == TypeId::of::<T::Concrete>() {
163            // SAFETY: ptr is guaranteed to be T::Concrete by caller.
164            unsafe { ptr.cast::<T::Concrete>().write((*self).extract()) };
165            Ok(())
166        } else {
167            Err(self)
168        }
169    }
170}
171
172impl<T: 'static + MeshPayload + Send> SerializeMessage for T {
173    type Concrete = Self;
174
175    fn compute_message_size(&mut self, sizer: MessageSizer<'_>) {
176        <T as MeshPayload>::Encoding::compute_message_size(self, sizer)
177    }
178
179    fn write_message(self, writer: MessageWriter<'_, '_, Resource>) {
180        <T as MeshPayload>::Encoding::write_message(self, writer)
181    }
182
183    fn extract(self) -> Self::Concrete {
184        self
185    }
186}
187
188impl OwnedMessage {
189    /// Creates a new message wrapping `data`, which will be lazily serialized
190    /// when needed.
191    #[inline]
192    pub fn new<T: SerializeMessage + Send>(data: T) -> Self {
193        Self(OwnedMessageInner::Unserialized(Box::new(data)))
194    }
195
196    /// Creates a new message from already-serialized data in `s`.
197    pub fn serialized(s: SerializedMessage) -> Self {
198        Self(OwnedMessageInner::Serialized(s))
199    }
200
201    /// Parses the message into a value of type `T`.
202    ///
203    /// If the message was constructed with `new<T>`, then the round trip
204    /// serialization/deserialization is skipped.
205    pub fn parse<T>(self) -> Result<T, mesh_protobuf::Error>
206    where
207        T: 'static + DefaultEncoding,
208        T::Encoding: for<'a> MessageDecode<'a, T, Resource>,
209    {
210        Message::from(self).parse()
211    }
212
213    /// Tries to unwrap the message into a value of type `T`.
214    ///
215    /// If the message was not created with [`OwnedMessage::new<T>`], then this
216    /// returns `Err(self)`.
217    //
218    // FUTURE: remove this optimization once nothing depends on it for
219    // functionality or performance.
220    pub fn try_unwrap<T: 'static>(self) -> Result<T, Self> {
221        match self.0 {
222            OwnedMessageInner::Unserialized(m) => {
223                let mut message = MaybeUninit::<T>::uninit();
224                // SAFETY: calling with appropriately sized and aligned buffer
225                // for writing T.
226                unsafe {
227                    match m.extract(TypeId::of::<T>(), message.as_mut_ptr().cast()) {
228                        Ok(()) => Ok(message.assume_init()),
229                        Err(message) => Err(Self(OwnedMessageInner::Unserialized(message))),
230                    }
231                }
232            }
233            OwnedMessageInner::Serialized(_) => Err(self),
234        }
235    }
236}
237
238impl DefaultEncoding for OwnedMessage {
239    type Encoding = mesh_protobuf::encoding::MessageEncoding<MessageEncoder>;
240}
241
242pub struct MessageEncoder;
243
244impl MessageEncode<Box<dyn DynSerializeMessage>, Resource> for MessageEncoder {
245    fn write_message(item: Box<dyn DynSerializeMessage>, writer: MessageWriter<'_, '_, Resource>) {
246        item.write_message(writer);
247    }
248
249    fn compute_message_size(item: &mut Box<dyn DynSerializeMessage>, sizer: MessageSizer<'_>) {
250        item.compute_message_size(sizer);
251    }
252}
253
254impl MessageEncode<OwnedMessage, Resource> for MessageEncoder {
255    fn write_message(item: OwnedMessage, writer: MessageWriter<'_, '_, Resource>) {
256        match item.0 {
257            OwnedMessageInner::Unserialized(message) => Self::write_message(message, writer),
258            OwnedMessageInner::Serialized(message) => {
259                SerializedMessageEncoder::write_message(message, writer)
260            }
261        }
262    }
263
264    fn compute_message_size(item: &mut OwnedMessage, sizer: MessageSizer<'_>) {
265        match &mut item.0 {
266            OwnedMessageInner::Unserialized(message) => Self::compute_message_size(message, sizer),
267            OwnedMessageInner::Serialized(message) => {
268                SerializedMessageEncoder::compute_message_size(message, sizer)
269            }
270        }
271    }
272}
273
274impl MessageDecode<'_, OwnedMessage, Resource> for MessageEncoder {
275    fn read_message(
276        item: &mut inplace::InplaceOption<'_, OwnedMessage>,
277        reader: mesh_protobuf::protobuf::MessageReader<'_, '_, Resource>,
278    ) -> mesh_protobuf::Result<()> {
279        let message = item.take().map(OwnedMessage::serialize);
280        inplace!(message);
281        SerializedMessageEncoder::read_message(&mut message, reader)?;
282        item.set(OwnedMessage(OwnedMessageInner::Serialized(
283            message.take().unwrap(),
284        )));
285        Ok(())
286    }
287}
288
289/// A message on a port.
290///
291/// The message may reference data with non-static lifetime, and it may not be
292/// [`Send`]. See [`OwnedMessage`] for a version that is [`Send`].
293pub struct Message<'a>(MessageInner<'a>);
294
295impl Debug for Message<'_> {
296    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297        f.pad("Message")
298    }
299}
300
301enum MessageInner<'a> {
302    Owned(OwnedMessage),
303    Stack(StackMessage<'a>),
304    View(&'a [u8], Vec<Resource>),
305}
306
307impl<'a> Message<'a> {
308    /// Returns a new instance wrapping `message`. The message will be boxed and
309    /// will be lazily serialized when needed.
310    pub fn new<T: SerializeMessage + Send>(message: T) -> Self {
311        OwnedMessage::new(message).into()
312    }
313
314    /// Returns a new instance that logically owns the contents of `message`,
315    /// while keeping the storage for `message` in place.
316    ///
317    /// Note that `message` need not be `Send`.
318    ///
319    /// This should be used via the [`stack_message!`] macro.
320    ///
321    /// # Safety
322    /// The caller must ensure that `message` is initialized. It will be dropped
323    /// in place when the message is dropped, so it must not be used again.
324    pub(crate) unsafe fn new_stack<T: 'a + DefaultEncoding>(message: &'a mut MaybeUninit<T>) -> Self
325    where
326        T::Encoding: MessageEncode<T, Resource>,
327    {
328        Message(MessageInner::Stack(StackMessage(
329            message.as_mut_ptr().cast(),
330            DynMessageVtable::stack::<T, T::Encoding>(),
331            PhantomData,
332        )))
333    }
334
335    /// Returns an instance for a serialized message with `data` and
336    /// `resources`.
337    pub fn serialized(data: &'a [u8], resources: Vec<Resource>) -> Self {
338        Self(MessageInner::View(data, resources))
339    }
340
341    /// Converts the message into an [`OwnedMessage`].
342    ///
343    /// If the message was created with [`Message::new`] or
344    /// [`From<OwnedMessage>`], then this operation is cheap. Otherwise, this
345    /// operation will serialize the message and allocate a new buffer.
346    pub fn into_owned(self) -> OwnedMessage {
347        let m = match self.0 {
348            MessageInner::Owned(m) => return m,
349            MessageInner::Stack(_) => SerializedMessage::from_message(self),
350            MessageInner::View(v, vec) => SerializedMessage {
351                data: v.into(),
352                resources: vec,
353            },
354        };
355        OwnedMessage(OwnedMessageInner::Serialized(m))
356    }
357
358    /// Serializes the message and returns it.
359    ///
360    /// If the message is already serialized, then this is a cheap operation.
361    pub fn serialize(self) -> (Cow<'a, [u8]>, Vec<Resource>) {
362        let m = match self.0 {
363            MessageInner::View(data, resources) => return (Cow::Borrowed(data), resources),
364            MessageInner::Owned(OwnedMessage(OwnedMessageInner::Serialized(m))) => m,
365            m => SerializedMessage::from_message(Self(m)),
366        };
367        (Cow::Owned(m.data), m.resources)
368    }
369
370    fn into_data_and_resources(self) -> (Cow<'a, [u8]>, Vec<Option<Resource>>) {
371        let (d, r) = self.serialize();
372        (d, r.into_iter().map(Some).collect::<Vec<_>>())
373    }
374
375    /// Parses the message into a value of type `T`.
376    ///
377    /// If the message was constructed with `new<T>`, then the round trip
378    /// serialization/deserialization is skipped.
379    pub fn parse<T>(mut self) -> Result<T, mesh_protobuf::Error>
380    where
381        T: 'static + DefaultEncoding,
382        T::Encoding: for<'b> MessageDecode<'b, T, Resource>,
383    {
384        if let MessageInner::Owned(m) = self.0 {
385            match m.try_unwrap() {
386                Ok(m) => return Ok(m),
387                Err(m) => {
388                    self = Self(MessageInner::Owned(m));
389                }
390            }
391        }
392        self.parse_non_static()
393    }
394
395    /// Parses the message into a value of type `T`.
396    ///
397    /// When `T` has static lifetime, prefer [`Message::parse`] instead, since
398    /// it can recover a `T` passed to [`Message::new`] without round-trip
399    /// serialization.
400    pub fn parse_non_static<T>(self) -> Result<T, mesh_protobuf::Error>
401    where
402        T: DefaultEncoding,
403        T::Encoding: for<'b> MessageDecode<'b, T, Resource>,
404    {
405        let (data, mut resources) = self.into_data_and_resources();
406        inplace_none!(message: T);
407        decode_with::<T::Encoding, _, _>(&mut message, &data, &mut resources)?;
408        Ok(message.take().expect("should be constructed"))
409    }
410}
411
412impl From<OwnedMessage> for Message<'_> {
413    fn from(m: OwnedMessage) -> Self {
414        Self(MessageInner::Owned(m))
415    }
416}
417
418impl DefaultEncoding for Message<'_> {
419    type Encoding = mesh_protobuf::encoding::MessageEncoding<MessageEncoder>;
420}
421
422impl MessageEncode<Message<'_>, Resource> for MessageEncoder {
423    fn write_message(item: Message<'_>, mut writer: MessageWriter<'_, '_, Resource>) {
424        match item.0 {
425            MessageInner::Owned(m) => Self::write_message(m, writer),
426            MessageInner::Stack(m) => m.write_message(writer),
427            MessageInner::View(data, resources) => {
428                writer.raw_message(data, resources);
429            }
430        }
431    }
432
433    fn compute_message_size(item: &mut Message<'_>, mut sizer: MessageSizer<'_>) {
434        match &mut item.0 {
435            MessageInner::Owned(m) => Self::compute_message_size(m, sizer),
436            MessageInner::Stack(m) => m.compute_message_size(sizer),
437            MessageInner::View(data, resources) => {
438                sizer.raw_message(data.len(), resources.len() as u32);
439            }
440        }
441    }
442}
443
444/// Returns a [`Message`] that takes ownership of a value but leaves the value
445/// in place on the stack.
446macro_rules! stack_message {
447    ($v:expr) => {
448        (|v| {
449            // UNSAFETY: required to call unsafe function.
450            #[expect(unsafe_code)]
451            // SAFETY: The value is initialized and never used again.
452            unsafe {
453                $crate::message::Message::new_stack(v)
454            }
455        })(&mut ::core::mem::MaybeUninit::new($v))
456    };
457}
458pub(crate) use stack_message;
459
460/// A message whose storage is on the stack.
461struct StackMessage<'a>(*mut (), &'static DynMessageVtable, PhantomData<&'a mut ()>);
462
463impl Drop for StackMessage<'_> {
464    fn drop(&mut self) {
465        // SAFETY: The value is owned.
466        unsafe { (self.1.drop)(self.0) }
467    }
468}
469
470impl StackMessage<'_> {
471    fn compute_message_size(&mut self, sizer: MessageSizer<'_>) {
472        // SAFETY: The value is owned and the vtable type matches.
473        unsafe { (self.1.compute_message_size)(self.0, sizer) }
474    }
475
476    fn write_message(self, writer: MessageWriter<'_, '_, Resource>) {
477        let Self(ptr, vtable, _) = self;
478        std::mem::forget(self);
479        // SAFETY: The value is owned and the vtable type matches.
480        unsafe { (vtable.write_message)(ptr, writer) }
481    }
482}
483
484struct DynMessageVtable {
485    compute_message_size: unsafe fn(*mut (), MessageSizer<'_>),
486    write_message: unsafe fn(*mut (), MessageWriter<'_, '_, Resource>),
487    drop: unsafe fn(*mut ()),
488}
489
490impl DynMessageVtable {
491    const fn stack<T, E: MessageEncode<T, Resource>>() -> &'static Self {
492        /// # Safety
493        ///
494        /// The caller must ensure that `ptr` points to a valid owned `T`.
495        unsafe fn compute_message_size<T, E: MessageEncode<T, Resource>>(
496            ptr: *mut (),
497            sizer: MessageSizer<'_>,
498        ) {
499            // SAFETY: The value is owned and the vtable type matches.
500            let v = unsafe { &mut *ptr.cast::<T>() };
501            E::compute_message_size(v, sizer);
502        }
503
504        /// # Safety
505        ///
506        /// The caller must ensure that `ptr` points to a valid owned `T`.
507        unsafe fn write_message<T, E: MessageEncode<T, Resource>>(
508            ptr: *mut (),
509            writer: MessageWriter<'_, '_, Resource>,
510        ) {
511            // SAFETY: The value is owned and the vtable type matches.
512            let v = unsafe { ptr.cast::<T>().read() };
513            E::write_message(v, writer);
514        }
515
516        /// # Safety
517        ///
518        /// The caller must ensure that `ptr` points to a valid owned `T`.
519        unsafe fn drop<T>(ptr: *mut ()) {
520            // SAFETY: The value is owned and the vtable type matches.
521            unsafe { ptr.cast::<T>().drop_in_place() };
522        }
523
524        const {
525            &Self {
526                compute_message_size: compute_message_size::<T, E>,
527                write_message: write_message::<T, E>,
528                drop: drop::<T>,
529            }
530        }
531    }
532}
533
534#[cfg(test)]
535mod tests {
536    use super::Message;
537    use mesh_protobuf::encoding::ImpossibleField;
538
539    #[test]
540    fn roundtrip_without_serialize() {
541        #[derive(Debug, Default)]
542        struct CantSerialize;
543        impl mesh_protobuf::DefaultEncoding for CantSerialize {
544            type Encoding = ImpossibleField;
545        }
546
547        Message::new(CantSerialize)
548            .parse::<CantSerialize>()
549            .unwrap();
550    }
551}