#![allow(unsafe_code)]
use crate::resource::Resource;
use crate::resource::SerializedMessage;
use mesh_protobuf;
use mesh_protobuf::encoding::SerializedMessageEncoder;
use mesh_protobuf::inplace;
use mesh_protobuf::protobuf::Encoder;
use mesh_protobuf::protobuf::MessageSizer;
use mesh_protobuf::protobuf::MessageWriter;
use mesh_protobuf::MessageEncode;
use std::any::Any;
use std::any::TypeId;
use std::fmt;
use std::fmt::Debug;
use std::mem::MaybeUninit;
/// An opaque message value.
///
/// A `Message` holds either the original, not-yet-encoded value (boxed behind
/// the type-erased `DynSerializeMessage` trait) or an already-encoded
/// [`SerializedMessage`]. Keeping the original value around lets
/// [`Message::try_parse`] hand it back without any encoding when the requested
/// type matches exactly.
///
/// The `Default` value wraps `SerializedMessage::default()`.
#[derive(Default)]
pub struct Message(MessageInner);
/// The two representations a [`Message`] can have.
enum MessageInner {
/// A value that has not been encoded yet.
Unserialized(Box<dyn DynSerializeMessage>),
/// An already-encoded message (encoded data plus resources).
Serialized(SerializedMessage),
}
impl Debug for Message {
    /// Prints the fixed text `Message`.
    ///
    /// The payload is never shown: the boxed unserialized value has no
    /// `Debug` bound. `pad` is used so width/alignment flags are honored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Formatter::pad(f, "Message")
    }
}
impl Default for MessageInner {
fn default() -> Self {
Self::Serialized(Default::default())
}
}
impl Message {
    /// Converts this message into its serialized form.
    ///
    /// An already-serialized message is returned unchanged. An unserialized
    /// payload is encoded via `SerializedMessage::from_message`.
    pub fn serialize(self) -> SerializedMessage {
        match self.0 {
            MessageInner::Serialized(already_encoded) => already_encoded,
            // The `_` pattern binds nothing, so `self` is still whole here.
            MessageInner::Unserialized(_) => SerializedMessage::from_message(self),
        }
    }
}
/// A type that can be sent as a complete message payload.
///
/// This trait has no methods of its own. It bundles the bounds needed to
/// encode and decode `Self` both as a message and as a field, with
/// [`Resource`] as the resource type. It is implemented automatically (by a
/// blanket impl) for every `Any + Send` type whose `DefaultEncoding` meets
/// those bounds, so it is not meant to be implemented manually.
pub trait MeshPayload:
mesh_protobuf::DefaultEncoding<Encoding = <Self as MeshPayload>::Encoding> + Send + 'static + Sized
{
/// The type's default encoding, restated together with the encode/decode
/// bounds that generic code relies on.
///
/// The supertrait bound forces this to be the same type as
/// `DefaultEncoding::Encoding`.
type Encoding: MessageEncode<Self, Resource>
+ for<'a> mesh_protobuf::MessageDecode<'a, Self, Resource>
+ mesh_protobuf::FieldEncode<Self, Resource>
+ for<'a> mesh_protobuf::FieldDecode<'a, Self, Resource>
+ Send
+ Sync;
}
// Blanket impl: any `Any + Send` type whose default encoding can encode and
// decode it as both a message and a field (using `Resource` resources) is a
// `MeshPayload`. The associated `Encoding` is just the default encoding.
impl<T> MeshPayload for T
where
T: mesh_protobuf::DefaultEncoding + Any + Send + 'static,
T::Encoding: MessageEncode<T, Resource>
+ for<'a> mesh_protobuf::MessageDecode<'a, T, Resource>
+ mesh_protobuf::FieldEncode<T, Resource>
+ for<'a> mesh_protobuf::FieldDecode<'a, T, Resource>
+ Send
+ Sync,
{
type Encoding = T::Encoding;
}
/// A type that can be encoded and decoded as a field, with [`Resource`] as the
/// resource type.
///
/// Unlike `MeshPayload`, this trait does not require message-level
/// encode/decode support. It is implemented automatically (by a blanket impl)
/// for every `Any + Send` type whose `DefaultEncoding` meets the field bounds.
pub trait MeshField:
mesh_protobuf::DefaultEncoding<Encoding = <Self as MeshField>::Encoding> + Send + 'static + Sized
{
/// The type's default encoding, restated with its field encode/decode
/// bounds. The supertrait bound forces it to equal
/// `DefaultEncoding::Encoding`.
type Encoding: mesh_protobuf::FieldEncode<Self, Resource>
+ for<'a> mesh_protobuf::FieldDecode<'a, Self, Resource>
+ Send
+ Sync;
}
// Blanket impl: any `Any + Send` type whose default encoding can encode and
// decode it as a field (using `Resource` resources) is a `MeshField`.
impl<T> MeshField for T
where
T: mesh_protobuf::DefaultEncoding + Any + Send + 'static,
T::Encoding: mesh_protobuf::FieldEncode<T, Resource>
+ for<'a> mesh_protobuf::FieldDecode<'a, T, Resource>
+ Send
+ Sync,
{
type Encoding = T::Encoding;
}
/// A value that a [`Message`] can carry without encoding it up front.
///
/// `compute_message_size` and `write_message` are only used if the message
/// ends up being encoded. If the receiver instead asks for exactly the
/// `Concrete` type, the value is recovered with `extract` and never encoded.
pub trait SerializeMessage: 'static + Send {
/// The type a receiver can recover via `extract`; matched against the
/// requested type by `TypeId`.
type Concrete: Any;
/// Computes the encoded size of this value using `sizer`.
fn compute_message_size(&mut self, sizer: MessageSizer<'_>);
/// Encodes this value into `writer`, consuming it.
fn write_message(self, writer: MessageWriter<'_, '_, Resource>);
/// Converts this value into its `Concrete` form without encoding it.
fn extract(self) -> Self::Concrete;
}
/// Object-safe, type-erased counterpart of [`SerializeMessage`], so a payload
/// can be stored as `Box<dyn DynSerializeMessage>`.
///
/// # Safety
///
/// Implementors must guarantee that `extract` writes through `ptr` only when
/// `type_id` is the `TypeId` of exactly the type being written, and that it
/// returns `Ok(())` only after fully initializing that value at `ptr`.
/// Callers rely on this to `assume_init` the destination.
unsafe trait DynSerializeMessage: Send {
/// Type-erased version of `SerializeMessage::compute_message_size`.
fn compute_message_size(&mut self, sizer: MessageSizer<'_>);
/// Type-erased version of `SerializeMessage::write_message`; takes the box
/// so the payload can be consumed.
fn write_message(self: Box<Self>, writer: MessageWriter<'_, '_, Resource>);
/// Attempts to move the concrete value out into `ptr`.
///
/// If `type_id` matches the payload's concrete type, the value is written
/// to `ptr` and `Ok(())` is returned. Otherwise the box is returned
/// unchanged in `Err` and `ptr` is not touched.
///
/// # Safety
///
/// If `type_id` identifies type `U`, `ptr` must be non-null, properly
/// aligned, and valid for writes of a `U`. The value is written without
/// dropping whatever was previously at `ptr`, so it should point to
/// uninitialized (or otherwise not-to-be-dropped) memory.
unsafe fn extract(
self: Box<Self>,
type_id: TypeId,
ptr: *mut (),
) -> Result<(), Box<dyn DynSerializeMessage>>;
}
// SAFETY: `extract` writes through `ptr` only after checking that `type_id`
// equals `TypeId::of::<T::Concrete>()`, writes exactly one `T::Concrete`, and
// returns `Ok(())` only in that case; otherwise it returns the box untouched.
unsafe impl<T: SerializeMessage> DynSerializeMessage for T {
    /// Forwards to `SerializeMessage::compute_message_size`.
    fn compute_message_size(&mut self, sizer: MessageSizer<'_>) {
        // Fully qualified: `T` implements both this trait and
        // `SerializeMessage` with identically named methods, so make the
        // forwarding target explicit rather than relying on method-resolution
        // priority to avoid calling ourselves.
        <T as SerializeMessage>::compute_message_size(self, sizer)
    }

    /// Unboxes the payload and forwards to `SerializeMessage::write_message`.
    fn write_message(self: Box<Self>, writer: MessageWriter<'_, '_, Resource>) {
        <T as SerializeMessage>::write_message(*self, writer)
    }

    /// Writes `T::Concrete` into `ptr` if `type_id` names that type;
    /// otherwise hands the box back unchanged.
    unsafe fn extract(
        self: Box<Self>,
        type_id: TypeId,
        ptr: *mut (),
    ) -> Result<(), Box<dyn DynSerializeMessage>> {
        if type_id != TypeId::of::<T::Concrete>() {
            return Err(self);
        }
        let value = <T as SerializeMessage>::extract(*self);
        // SAFETY: `type_id` identifies `T::Concrete`, so by this function's
        // contract `ptr` is aligned and valid for writing a `T::Concrete`.
        unsafe { ptr.cast::<T::Concrete>().write(value) };
        Ok(())
    }
}
/// Encodes a type-erased payload into a [`SerializedMessage`] using
/// `MessageEncoder`.
///
/// `Message::try_parse` uses this when the payload's concrete type does not
/// match the type the caller asked for.
fn serialize_dyn_message(message: Box<dyn DynSerializeMessage>) -> SerializedMessage {
    let encoded = Encoder::<_, MessageEncoder, _>::with_encoding(message).encode();
    SerializedMessage {
        data: encoded.0,
        resources: encoded.1,
    }
}
impl<T: MeshPayload> SerializeMessage for T {
type Concrete = Self;
fn compute_message_size(&mut self, sizer: MessageSizer<'_>) {
<T as MeshPayload>::Encoding::compute_message_size(self, sizer)
}
fn write_message(self, writer: MessageWriter<'_, '_, Resource>) {
<T as MeshPayload>::Encoding::write_message(self, writer)
}
fn extract(self) -> Self::Concrete {
self
}
}
impl Message {
    /// Creates a message holding `data` without encoding it.
    ///
    /// Encoding is deferred until the message is actually serialized. If the
    /// receiver parses it as `T::Concrete`, the value is handed back without
    /// ever being encoded.
    #[inline]
    pub fn new<T: SerializeMessage>(data: T) -> Self {
        Self(MessageInner::Unserialized(Box::new(data)))
    }

    /// Creates a message from an already-serialized message.
    pub fn serialized(s: SerializedMessage) -> Self {
        Self(MessageInner::Serialized(s))
    }

    /// Parses the message as a `T`.
    ///
    /// If the message still holds an unserialized value whose concrete type
    /// is exactly `T`, that value is returned directly. Otherwise the message
    /// is serialized (if it is not already) and decoded as `T`.
    ///
    /// # Errors
    ///
    /// Returns an error if the serialized form cannot be decoded as `T`.
    pub fn parse<T: MeshPayload>(self) -> Result<T, mesh_protobuf::Error> {
        self.try_parse().or_else(|m| m.into_message())
    }

    /// Attempts to take the message back as a `T` without decoding anything.
    ///
    /// Succeeds only if the message holds an unserialized value whose
    /// `SerializeMessage::Concrete` type is exactly `T` (compared by
    /// `TypeId`).
    ///
    /// # Errors
    ///
    /// Otherwise returns the message in serialized form, encoding the
    /// unserialized payload first if necessary, so the caller can decode it.
    pub fn try_parse<T: 'static + Send>(self) -> Result<T, SerializedMessage> {
        match self.0 {
            MessageInner::Unserialized(m) => {
                let mut slot = MaybeUninit::<T>::uninit();
                // SAFETY: `slot.as_mut_ptr()` is aligned and valid for writes
                // of a `T`, and the type ID passed is exactly `TypeId::of::<T>()`,
                // as `DynSerializeMessage::extract` requires.
                let extracted =
                    unsafe { m.extract(TypeId::of::<T>(), slot.as_mut_ptr().cast()) };
                match extracted {
                    // SAFETY: `extract` returned `Ok`, which by the
                    // `DynSerializeMessage` contract means it fully
                    // initialized a `T` in `slot`.
                    Ok(()) => Ok(unsafe { slot.assume_init() }),
                    // Type mismatch: the box comes back untouched, so encode it.
                    Err(payload) => Err(serialize_dyn_message(payload)),
                }
            }
            MessageInner::Serialized(m) => Err(m),
        }
    }
}
// A `Message`'s default encoding is `MessageEncoder` wrapped in
// `MessageEncoding`.
impl mesh_protobuf::DefaultEncoding for Message {
type Encoding = mesh_protobuf::encoding::MessageEncoding<MessageEncoder>;
}
/// Encoder/decoder for [`Message`].
///
/// Encoding writes whichever representation the message currently holds
/// (invoking the unserialized payload's own encoding, or delegating to
/// `SerializedMessageEncoder`). Decoding always produces the serialized
/// representation.
pub struct MessageEncoder;
impl MessageEncode<Box<dyn DynSerializeMessage>, Resource> for MessageEncoder {
fn write_message(item: Box<dyn DynSerializeMessage>, writer: MessageWriter<'_, '_, Resource>) {
item.write_message(writer);
}
fn compute_message_size(item: &mut Box<dyn DynSerializeMessage>, sizer: MessageSizer<'_>) {
item.compute_message_size(sizer);
}
}
/// Encodes a [`Message`] in whichever representation it currently holds.
///
/// An unserialized payload goes through the boxed-payload impl of
/// `MessageEncoder` (and from there its own encoding). Serialized data is
/// handed to `SerializedMessageEncoder`.
impl MessageEncode<Message, Resource> for MessageEncoder {
    fn compute_message_size(item: &mut Message, sizer: MessageSizer<'_>) {
        match &mut item.0 {
            MessageInner::Serialized(encoded) => {
                SerializedMessageEncoder::compute_message_size(encoded, sizer)
            }
            // Selects the boxed-payload impl by argument type.
            MessageInner::Unserialized(payload) => Self::compute_message_size(payload, sizer),
        }
    }

    fn write_message(item: Message, writer: MessageWriter<'_, '_, Resource>) {
        match item.0 {
            MessageInner::Serialized(encoded) => {
                SerializedMessageEncoder::write_message(encoded, writer)
            }
            // Selects the boxed-payload impl by argument type.
            MessageInner::Unserialized(payload) => Self::write_message(payload, writer),
        }
    }
}
/// Decodes a [`Message`].
///
/// The decoded message is always stored in serialized form. Turning it into a
/// concrete type is left to callers (e.g. `Message::parse`).
impl mesh_protobuf::MessageDecode<'_, Message, Resource> for MessageEncoder {
fn read_message(
item: &mut inplace::InplaceOption<'_, Message>,
reader: mesh_protobuf::protobuf::MessageReader<'_, '_, Resource>,
) -> mesh_protobuf::Result<()> {
// Take any value already in the slot and convert it to serialized form so
// `SerializedMessageEncoder` can decode on top of it.
// NOTE(review): presumably this preserves protobuf merge semantics when the
// field is read more than once — confirm.
let message = item.take().map(Message::serialize);
// Rebind `message` as an in-place option the decoder can fill.
inplace!(message);
SerializedMessageEncoder::read_message(&mut message, reader)?;
// `read_message` succeeded, so the slot is expected to be populated here.
item.set(Message::serialized(message.take().unwrap()));
Ok(())
}
}
// Marker impl allowing `Message` to be treated as downcastable to any
// `MeshPayload` type `T`.
// NOTE(review): see `mesh_protobuf::Downcast` for the exact semantics.
impl<T: MeshPayload> mesh_protobuf::Downcast<T> for Message {}
#[cfg(test)]
mod tests {
    use super::Message;
    use mesh_protobuf::encoding::ImpossibleField;

    /// Parsing a message as the same concrete type it was created with must
    /// return the value through `try_parse`'s `TypeId` fast path, without
    /// encoding it.
    ///
    /// `CantSerialize` uses `ImpossibleField` as its encoding. Judging by the
    /// name, that encoding cannot actually encode a value, so the test relies
    /// on no serialization happening.
    #[test]
    fn roundtrip_without_serialize() {
        #[derive(Debug, Default)]
        struct CantSerialize;

        impl mesh_protobuf::DefaultEncoding for CantSerialize {
            type Encoding = ImpossibleField;
        }

        let message = Message::new(CantSerialize);
        let parsed = message.parse::<CantSerialize>();
        parsed.unwrap();
    }
}