#![expect(unsafe_code)]
use crate::resource::Resource;
use crate::resource::SerializedMessage;
use mesh_protobuf;
use mesh_protobuf::encoding::SerializedMessageEncoder;
use mesh_protobuf::inplace;
use mesh_protobuf::inplace_none;
use mesh_protobuf::protobuf::decode_with;
use mesh_protobuf::protobuf::MessageSizer;
use mesh_protobuf::protobuf::MessageWriter;
use mesh_protobuf::DefaultEncoding;
use mesh_protobuf::MessageDecode;
use mesh_protobuf::MessageEncode;
use std::any::Any;
use std::any::TypeId;
use std::borrow::Cow;
use std::fmt;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
#[derive(Default)]
pub struct OwnedMessage(OwnedMessageInner);
enum OwnedMessageInner {
Unserialized(Box<dyn DynSerializeMessage>),
Serialized(SerializedMessage),
}
impl Debug for OwnedMessage {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.pad("OwnedMessage")
}
}
impl Default for OwnedMessageInner {
fn default() -> Self {
Self::Serialized(Default::default())
}
}
impl OwnedMessage {
    /// Converts this message into a [`SerializedMessage`].
    ///
    /// An already-serialized message is returned as-is. An unserialized one
    /// is encoded via [`SerializedMessage::from_message`].
    pub fn serialize(self) -> SerializedMessage {
        if let OwnedMessageInner::Serialized(encoded) = self.0 {
            return encoded;
        }
        SerializedMessage::from_message(self)
    }
}
/// A type that can be encoded and decoded both as a complete message and as
/// a field, using [`Resource`] as the resource type.
///
/// The blanket impl that follows implements this automatically for every
/// qualifying type. As a result, the long list of encoding bounds can be
/// written as a single `T: MeshPayload` bound.
pub trait MeshPayload: DefaultEncoding<Encoding = <Self as MeshPayload>::Encoding> + Sized {
    /// The encoding for `Self`.
    ///
    /// The supertrait bound forces this to be the same type as
    /// `<Self as DefaultEncoding>::Encoding`.
    type Encoding: MessageEncode<Self, Resource>
        + for<'a> MessageDecode<'a, Self, Resource>
        + mesh_protobuf::FieldEncode<Self, Resource>
        + for<'a> mesh_protobuf::FieldDecode<'a, Self, Resource>
        + Send
        + Sync;
}
// Blanket impl: a `'static + Send` type is a `MeshPayload` whenever its
// default encoding satisfies every bound on `MeshPayload::Encoding`.
impl<T> MeshPayload for T
where
    T: DefaultEncoding + Any + Send + 'static,
    T::Encoding: MessageEncode<T, Resource>
        + for<'a> MessageDecode<'a, T, Resource>
        + mesh_protobuf::FieldEncode<T, Resource>
        + for<'a> mesh_protobuf::FieldDecode<'a, T, Resource>
        + Send
        + Sync,
{
    type Encoding = T::Encoding;
}
/// A type that can be encoded and decoded as a message field, using
/// [`Resource`] as the resource type.
///
/// Unlike `MeshPayload`, this requires neither `'static` nor `Send` on the
/// type itself, nor whole-message encode/decode support.
pub trait MeshField: DefaultEncoding<Encoding = <Self as MeshField>::Encoding> + Sized {
    /// The field encoding for `Self`.
    ///
    /// The supertrait bound forces this to equal
    /// `<Self as DefaultEncoding>::Encoding`.
    type Encoding: mesh_protobuf::FieldEncode<Self, Resource>
        + for<'a> mesh_protobuf::FieldDecode<'a, Self, Resource>
        + Send
        + Sync;
}
// Blanket impl: any type whose default encoding supports field
// encode/decode (and is `Send + Sync`) is a `MeshField`.
impl<T> MeshField for T
where
    T: DefaultEncoding,
    T::Encoding: mesh_protobuf::FieldEncode<T, Resource>
        + for<'a> mesh_protobuf::FieldDecode<'a, T, Resource>
        + Send
        + Sync,
{
    type Encoding = T::Encoding;
}
/// A value that can be stored in an [`OwnedMessage`] without being encoded
/// up front.
///
/// [`OwnedMessage::new`] boxes the value as-is. If the message later has to
/// be encoded, `compute_message_size` and `write_message` are used.
/// Alternatively, [`OwnedMessage::try_unwrap`] can recover the value through
/// [`extract`](Self::extract) without any encoding.
pub trait SerializeMessage: 'static {
    /// The type produced by [`extract`](Self::extract).
    ///
    /// `OwnedMessage::try_unwrap::<U>` only succeeds when `U`'s `TypeId`
    /// equals this type's `TypeId`.
    type Concrete: Any;
    /// Computes the encoded size of the message.
    fn compute_message_size(&mut self, sizer: MessageSizer<'_>);
    /// Consumes the value and encodes it into `writer`.
    fn write_message(self, writer: MessageWriter<'_, '_, Resource>);
    /// Consumes the value and returns it as [`Self::Concrete`].
    fn extract(self) -> Self::Concrete;
}
/// Object-safe counterpart of [`SerializeMessage`].
///
/// It lets unserialized payloads be stored as `Box<dyn DynSerializeMessage>`.
///
/// # Safety
///
/// An implementation of [`extract`](Self::extract) may return `Ok(())` only
/// after writing a fully initialized value of the type identified by
/// `type_id` through `ptr`. Otherwise it must write nothing and hand the box
/// back via `Err`. `OwnedMessage::try_unwrap` calls `assume_init` on `Ok`,
/// so violating this is undefined behavior.
unsafe trait DynSerializeMessage: Send {
    /// Computes the encoded size of the boxed message.
    fn compute_message_size(&mut self, sizer: MessageSizer<'_>);
    /// Consumes the box and encodes the message into `writer`.
    fn write_message(self: Box<Self>, writer: MessageWriter<'_, '_, Resource>);
    /// Tries to move the underlying value out as the type named by `type_id`.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for writes of, and properly aligned for, the type
    /// whose `TypeId` is `type_id`.
    unsafe fn extract(
        self: Box<Self>,
        type_id: TypeId,
        ptr: *mut (),
    ) -> Result<(), Box<dyn DynSerializeMessage>>;
}
unsafe impl<T: SerializeMessage + Send> DynSerializeMessage for T {
fn compute_message_size(&mut self, sizer: MessageSizer<'_>) {
self.compute_message_size(sizer)
}
fn write_message(self: Box<Self>, writer: MessageWriter<'_, '_, Resource>) {
(*self).write_message(writer)
}
unsafe fn extract(
self: Box<Self>,
type_id: TypeId,
ptr: *mut (),
) -> Result<(), Box<dyn DynSerializeMessage>> {
if type_id == TypeId::of::<T::Concrete>() {
unsafe { ptr.cast::<T::Concrete>().write((*self).extract()) };
Ok(())
} else {
Err(self)
}
}
}
impl<T: 'static + MeshPayload + Send> SerializeMessage for T {
type Concrete = Self;
fn compute_message_size(&mut self, sizer: MessageSizer<'_>) {
<T as MeshPayload>::Encoding::compute_message_size(self, sizer)
}
fn write_message(self, writer: MessageWriter<'_, '_, Resource>) {
<T as MeshPayload>::Encoding::write_message(self, writer)
}
fn extract(self) -> Self::Concrete {
self
}
}
impl OwnedMessage {
#[inline]
pub fn new<T: SerializeMessage + Send>(data: T) -> Self {
Self(OwnedMessageInner::Unserialized(Box::new(data)))
}
pub fn serialized(s: SerializedMessage) -> Self {
Self(OwnedMessageInner::Serialized(s))
}
pub fn parse<T>(self) -> Result<T, mesh_protobuf::Error>
where
T: 'static + DefaultEncoding,
T::Encoding: for<'a> MessageDecode<'a, T, Resource>,
{
Message::from(self).parse()
}
pub fn try_unwrap<T: 'static>(self) -> Result<T, Self> {
match self.0 {
OwnedMessageInner::Unserialized(m) => {
let mut message = MaybeUninit::<T>::uninit();
unsafe {
match m.extract(TypeId::of::<T>(), message.as_mut_ptr().cast()) {
Ok(()) => Ok(message.assume_init()),
Err(message) => Err(Self(OwnedMessageInner::Unserialized(message))),
}
}
}
OwnedMessageInner::Serialized(_) => Err(self),
}
}
}
/// `OwnedMessage` uses `MessageEncoding<MessageEncoder>`.
///
/// That is, it is encoded and decoded through the `MessageEncoder` impls in
/// this module.
impl DefaultEncoding for OwnedMessage {
    type Encoding = mesh_protobuf::encoding::MessageEncoding<MessageEncoder>;
}
/// Message encoder shared by this module's message types.
///
/// It implements `MessageEncode` for boxed unserialized payloads,
/// [`OwnedMessage`], and [`Message`], and `MessageDecode` for
/// [`OwnedMessage`].
pub struct MessageEncoder;
impl MessageEncode<Box<dyn DynSerializeMessage>, Resource> for MessageEncoder {
fn write_message(item: Box<dyn DynSerializeMessage>, writer: MessageWriter<'_, '_, Resource>) {
item.write_message(writer);
}
fn compute_message_size(item: &mut Box<dyn DynSerializeMessage>, sizer: MessageSizer<'_>) {
item.compute_message_size(sizer);
}
}
impl MessageEncode<OwnedMessage, Resource> for MessageEncoder {
fn write_message(item: OwnedMessage, writer: MessageWriter<'_, '_, Resource>) {
match item.0 {
OwnedMessageInner::Unserialized(message) => Self::write_message(message, writer),
OwnedMessageInner::Serialized(message) => {
SerializedMessageEncoder::write_message(message, writer)
}
}
}
fn compute_message_size(item: &mut OwnedMessage, sizer: MessageSizer<'_>) {
match &mut item.0 {
OwnedMessageInner::Unserialized(message) => Self::compute_message_size(message, sizer),
OwnedMessageInner::Serialized(message) => {
SerializedMessageEncoder::compute_message_size(message, sizer)
}
}
}
}
/// Decodes into an [`OwnedMessage`].
///
/// The result is always stored in the serialized form.
impl MessageDecode<'_, OwnedMessage, Resource> for MessageEncoder {
    fn read_message(
        item: &mut inplace::InplaceOption<'_, OwnedMessage>,
        reader: mesh_protobuf::protobuf::MessageReader<'_, '_, Resource>,
    ) -> mesh_protobuf::Result<()> {
        // Any value already present is first converted to a
        // `SerializedMessage`, so the serialized-message decoder reads into it.
        // NOTE(review): presumably this lets repeated occurrences merge into
        // the existing data — confirm against `SerializedMessageEncoder`.
        let message = item.take().map(OwnedMessage::serialize);
        inplace!(message);
        SerializedMessageEncoder::read_message(&mut message, reader)?;
        // NOTE(review): if the read above fails, `item` has already been
        // emptied by `take()`; confirm callers don't rely on its prior value.
        item.set(OwnedMessage(OwnedMessageInner::Serialized(
            message.take().unwrap(),
        )));
        Ok(())
    }
}
/// Backing storage for [`Message`].
enum MessageInner<'a> {
    /// An owned message, serialized or not.
    Owned(OwnedMessage),
    /// A value living in a caller-provided `MaybeUninit` slot (see `new_stack`).
    Stack(StackMessage<'a>),
    /// Borrowed, already-encoded bytes together with their resources.
    View(&'a [u8], Vec<Resource>),
}

/// A message that may borrow its data (or a stack value) for the lifetime `'a`.
pub struct Message<'a>(MessageInner<'a>);

impl Debug for Message<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The payload is opaque here; only the type name is printed.
        f.pad("Message")
    }
}
impl<'a> Message<'a> {
pub fn new<T: SerializeMessage + Send>(message: T) -> Self {
OwnedMessage::new(message).into()
}
pub(crate) unsafe fn new_stack<T: 'a + DefaultEncoding>(message: &'a mut MaybeUninit<T>) -> Self
where
T::Encoding: MessageEncode<T, Resource>,
{
Message(MessageInner::Stack(StackMessage(
message.as_mut_ptr().cast(),
DynMessageVtable::stack::<T, T::Encoding>(),
PhantomData,
)))
}
pub fn serialized(data: &'a [u8], resources: Vec<Resource>) -> Self {
Self(MessageInner::View(data, resources))
}
pub fn into_owned(self) -> OwnedMessage {
let m = match self.0 {
MessageInner::Owned(m) => return m,
MessageInner::Stack(_) => SerializedMessage::from_message(self),
MessageInner::View(v, vec) => SerializedMessage {
data: v.into(),
resources: vec,
},
};
OwnedMessage(OwnedMessageInner::Serialized(m))
}
pub fn serialize(self) -> (Cow<'a, [u8]>, Vec<Resource>) {
let m = match self.0 {
MessageInner::View(data, resources) => return (Cow::Borrowed(data), resources),
MessageInner::Owned(OwnedMessage(OwnedMessageInner::Serialized(m))) => m,
m => SerializedMessage::from_message(Self(m)),
};
(Cow::Owned(m.data), m.resources)
}
fn into_data_and_resources(self) -> (Cow<'a, [u8]>, Vec<Option<Resource>>) {
let (d, r) = self.serialize();
(d, r.into_iter().map(Some).collect::<Vec<_>>())
}
pub fn parse<T>(mut self) -> Result<T, mesh_protobuf::Error>
where
T: 'static + DefaultEncoding,
T::Encoding: for<'b> MessageDecode<'b, T, Resource>,
{
if let MessageInner::Owned(m) = self.0 {
match m.try_unwrap() {
Ok(m) => return Ok(m),
Err(m) => {
self = Self(MessageInner::Owned(m));
}
}
}
self.parse_non_static()
}
pub fn parse_non_static<T>(self) -> Result<T, mesh_protobuf::Error>
where
T: DefaultEncoding,
T::Encoding: for<'b> MessageDecode<'b, T, Resource>,
{
let (data, mut resources) = self.into_data_and_resources();
inplace_none!(message: T);
decode_with::<T::Encoding, _, _>(&mut message, &data, &mut resources)?;
Ok(message.take().expect("should be constructed"))
}
}
impl From<OwnedMessage> for Message<'_> {
fn from(m: OwnedMessage) -> Self {
Self(MessageInner::Owned(m))
}
}
impl DefaultEncoding for Message<'_> {
type Encoding = mesh_protobuf::encoding::MessageEncoding<MessageEncoder>;
}
/// Encodes a [`Message`] in whichever representation it currently holds.
impl MessageEncode<Message<'_>, Resource> for MessageEncoder {
    fn write_message(item: Message<'_>, mut writer: MessageWriter<'_, '_, Resource>) {
        match item.0 {
            // Delegate to the `OwnedMessage` encoding.
            MessageInner::Owned(m) => Self::write_message(m, writer),
            // Moves the value out of the caller's slot and encodes it.
            MessageInner::Stack(m) => m.write_message(writer),
            // Already-encoded bytes are handed to the writer as a raw message.
            MessageInner::View(data, resources) => {
                writer.raw_message(data, resources);
            }
        }
    }

    fn compute_message_size(item: &mut Message<'_>, mut sizer: MessageSizer<'_>) {
        match &mut item.0 {
            MessageInner::Owned(m) => Self::compute_message_size(m, sizer),
            MessageInner::Stack(m) => m.compute_message_size(sizer),
            MessageInner::View(data, resources) => {
                // Don't let `as` silently wrap the count. A truncated value
                // would disagree with the full resource list that
                // `write_message` passes to `raw_message`.
                let resource_count = u32::try_from(resources.len())
                    .expect("message has more than u32::MAX resources");
                sizer.raw_message(data.len(), resource_count);
            }
        }
    }
}
/// Builds a [`Message`] that refers to `$v` in place rather than boxing it
/// (as `Message::new` does).
///
/// `$v` is placed in a temporary `MaybeUninit` created with
/// `MaybeUninit::new`, so the slot is always initialized. That satisfies
/// `Message::new_stack`'s safety requirement. The returned message borrows
/// that temporary, so the borrow checker confines its use to the temporary's
/// lifetime.
/// NOTE(review): in practice that is the invoking expression/statement;
/// confirm against the crate edition's temporary-scope rules.
macro_rules! stack_message {
    ($v:expr) => {
        // The expansion contains `unsafe`; this lets it compile in modules
        // that lint against `unsafe_code`.
        #[expect(unsafe_code)]
        {
            // SAFETY: the slot is fully initialized by `MaybeUninit::new`, and
            // nothing but the returned `Message` can reach the temporary.
            unsafe { $crate::message::Message::new_stack(&mut ::core::mem::MaybeUninit::new($v)) }
        }
    };
}
pub(crate) use stack_message;
/// A type-erased pointer to a value living in a caller-provided slot.
///
/// Field 0 is expected to point at an initialized value (see
/// `Message::new_stack`), and field 1 is the vtable built for that value's
/// type. The `PhantomData` ties the borrow of the slot to `'a`. The value is
/// either moved out by `write_message` or dropped in place by `Drop`, never
/// both.
struct StackMessage<'a>(*mut (), &'static DynMessageVtable, PhantomData<&'a mut ()>);
impl Drop for StackMessage<'_> {
    fn drop(&mut self) {
        // SAFETY: `self.0` still points to an initialized value of the type
        // `self.1` was built for. `write_message` forgets `self` before moving
        // the value out, so this only runs while the value is still present.
        unsafe { (self.1.drop)(self.0) }
    }
}
impl StackMessage<'_> {
    /// Computes the encoded size of the pointee without consuming it.
    fn compute_message_size(&mut self, sizer: MessageSizer<'_>) {
        // SAFETY: the pointee is initialized and matches the vtable's type.
        unsafe { (self.1.compute_message_size)(self.0, sizer) }
    }
    /// Moves the pointee out and encodes it into `writer`.
    fn write_message(self, writer: MessageWriter<'_, '_, Resource>) {
        let Self(ptr, vtable, _) = self;
        // `self` must be forgotten BEFORE the vtable call. The call moves the
        // value out, so `Drop` must not run on the slot afterwards, even if
        // encoding panics.
        std::mem::forget(self);
        // SAFETY: `ptr` points to an initialized value of the vtable's type,
        // and ownership of it is handed to this call exactly once.
        unsafe { (vtable.write_message)(ptr, writer) }
    }
}
/// Table of type-specific operations that `StackMessage` uses on a value
/// whose type has been erased to `*mut ()`.
struct DynMessageVtable {
    /// Computes the encoded size of the pointee, borrowing it mutably.
    compute_message_size: unsafe fn(*mut (), MessageSizer<'_>),
    /// Moves the pointee out (via a pointer read) and encodes it.
    write_message: unsafe fn(*mut (), MessageWriter<'_, '_, Resource>),
    /// Drops the pointee in place.
    drop: unsafe fn(*mut ()),
}
impl DynMessageVtable {
    /// Returns the vtable for a value of type `T` encoded with `E`.
    ///
    /// Every function in the table must be called with a pointer to a valid,
    /// initialized `T`. After `write_message` or `drop`, the pointee must be
    /// treated as uninitialized.
    const fn stack<T, E: MessageEncode<T, Resource>>() -> &'static Self {
        unsafe fn compute_message_size<T, E: MessageEncode<T, Resource>>(
            ptr: *mut (),
            sizer: MessageSizer<'_>,
        ) {
            // SAFETY: the caller guarantees `ptr` points to an initialized `T`
            // that is not otherwise borrowed for the duration of this call.
            let v = unsafe { &mut *ptr.cast::<T>() };
            E::compute_message_size(v, sizer);
        }
        unsafe fn write_message<T, E: MessageEncode<T, Resource>>(
            ptr: *mut (),
            writer: MessageWriter<'_, '_, Resource>,
        ) {
            // SAFETY: the caller guarantees `ptr` points to an initialized `T`
            // and will not use or drop it again after this call.
            let v = unsafe { ptr.cast::<T>().read() };
            E::write_message(v, writer);
        }
        unsafe fn drop<T>(ptr: *mut ()) {
            // SAFETY: the caller guarantees `ptr` points to an initialized `T`
            // that is not used again afterwards.
            unsafe { ptr.cast::<T>().drop_in_place() };
        }
        // The table is built in an inline `const` block so that a `&'static`
        // reference to it can be returned for each `T`/`E` pair.
        const {
            &Self {
                compute_message_size: compute_message_size::<T, E>,
                write_message: write_message::<T, E>,
                drop: drop::<T>,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::Message;
    use mesh_protobuf::encoding::ImpossibleField;

    /// `CantSerialize` uses `ImpossibleField` as its encoding (presumably it
    /// cannot be encoded). This test checks that `Message::parse` returns the
    /// value through the `try_unwrap` fast path without serializing it.
    #[test]
    fn roundtrip_without_serialize() {
        #[derive(Debug, Default)]
        struct CantSerialize;

        impl mesh_protobuf::DefaultEncoding for CantSerialize {
            type Encoding = ImpossibleField;
        }

        let message = Message::new(CantSerialize);
        let _parsed: CantSerialize = message.parse().unwrap();
    }
}