#![allow(unsafe_code)]
use crate::ChannelError;
use crate::RecvError;
use mesh_node::local_node::HandleMessageError;
use mesh_node::local_node::HandlePortEvent;
use mesh_node::local_node::Port;
use mesh_node::local_node::PortField;
use mesh_node::local_node::PortWithHandler;
use mesh_node::message::MeshField;
use mesh_node::message::Message;
use mesh_node::message::OwnedMessage;
use mesh_protobuf::DefaultEncoding;
use parking_lot::Mutex;
use std::fmt::Debug;
use std::future::Future;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ptr::NonNull;
use std::sync::Arc;
use std::task::ready;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use thiserror::Error;
pub fn oneshot<T>() -> (OneshotSender<T>, OneshotReceiver<T>) {
fn oneshot_core() -> (OneshotSenderCore, OneshotReceiverCore) {
let slot = Arc::new(Slot(Mutex::new(SlotState::Waiting(None))));
(
OneshotSenderCore(slot.clone()),
OneshotReceiverCore { slot, port: None },
)
}
let (sender, receiver) = oneshot_core();
(
OneshotSender(sender, PhantomData),
OneshotReceiver(ManuallyDrop::new(receiver), PhantomData),
)
}
pub struct OneshotSender<T>(OneshotSenderCore, PhantomData<Arc<Mutex<T>>>);
impl<T> Debug for OneshotSender<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.0, f)
}
}
impl<T> OneshotSender<T> {
pub fn send(self, value: T) {
unsafe { self.0.send(value) }
}
}
impl<T: MeshField> DefaultEncoding for OneshotSender<T> {
type Encoding = PortField;
}
impl<T: MeshField> From<OneshotSender<T>> for Port {
fn from(sender: OneshotSender<T>) -> Self {
unsafe { sender.0.into_port::<T>() }
}
}
impl<T: MeshField> From<Port> for OneshotSender<T> {
fn from(port: Port) -> Self {
Self(OneshotSenderCore::from_port::<T>(port), PhantomData)
}
}
unsafe fn send_message<T: MeshField>(port: &Port, value: BoxedValue) {
let value = unsafe { value.cast::<T>() };
port.send_protobuf((value,));
}
fn decode_message<T: MeshField>(message: Message<'_>) -> Result<BoxedValue, ChannelError> {
let (value,) = message.parse_non_static::<(Box<T>,)>()?;
Ok(BoxedValue::new(value))
}
#[derive(Debug)]
struct Slot(Mutex<SlotState>);
#[derive(Debug)]
struct OneshotSenderCore(Arc<Slot>);
impl Drop for OneshotSenderCore {
fn drop(&mut self) {
self.close();
}
}
impl OneshotSenderCore {
fn into_slot(self) -> Arc<Slot> {
let Self(ref slot) = *ManuallyDrop::new(self);
unsafe { <*const _>::read(slot) }
}
fn close(&self) {
let mut state = self.0 .0.lock();
match std::mem::replace(&mut *state, SlotState::Done) {
SlotState::Waiting(waker) => {
drop(state);
if let Some(waker) = waker {
waker.wake();
}
}
SlotState::Sent(v) => {
*state = SlotState::Sent(v);
}
SlotState::Done => {}
SlotState::ReceiverRemote(port, _) => {
drop(port);
}
SlotState::SenderRemote { .. } => unreachable!(),
}
}
unsafe fn send<T>(self, value: T) {
fn send(this: OneshotSenderCore, value: BoxedValue) -> Option<BoxedValue> {
let slot = this.into_slot();
let mut state = slot.0.lock();
match std::mem::replace(&mut *state, SlotState::Done) {
SlotState::ReceiverRemote(port, send) => {
unsafe { send(&port, value) };
None
}
SlotState::Waiting(waker) => {
*state = SlotState::Sent(value);
drop(state);
if let Some(waker) = waker {
waker.wake();
}
None
}
SlotState::Done => Some(value),
SlotState::Sent { .. } | SlotState::SenderRemote { .. } => unreachable!(),
}
}
if let Some(value) = send(self, BoxedValue::new(Box::new(value))) {
unsafe { value.drop::<T>() };
}
}
unsafe fn into_port<T: MeshField>(self) -> Port {
fn into_port(this: OneshotSenderCore, decode: DecodeFn) -> Port {
let slot = this.into_slot();
let mut state = slot.0.lock();
match std::mem::replace(&mut *state, SlotState::Done) {
SlotState::Waiting(waker) => {
let (send, recv) = Port::new_pair();
*state = SlotState::SenderRemote(recv, decode);
drop(state);
if let Some(waker) = waker {
waker.wake();
}
send
}
SlotState::ReceiverRemote(port, _) => port,
SlotState::Done => Port::new_pair().0,
SlotState::Sent(_) | SlotState::SenderRemote { .. } => unreachable!(),
}
}
into_port(self, decode_message::<T>)
}
fn from_port<T: MeshField>(port: Port) -> Self {
fn from_port(port: Port, send: SendFn) -> OneshotSenderCore {
let slot = Arc::new(Slot(Mutex::new(SlotState::ReceiverRemote(port, send))));
OneshotSenderCore(slot)
}
from_port(port, send_message::<T>)
}
}
pub struct OneshotReceiver<T>(
ManuallyDrop<OneshotReceiverCore>,
PhantomData<Arc<Mutex<T>>>,
);
impl<T> Debug for OneshotReceiver<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.0, f)
}
}
impl<T> OneshotReceiver<T> {
fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
let v = unsafe { ready!(self.0.poll_recv(cx))? };
Ok(*v).into()
}
fn into_core(self) -> OneshotReceiverCore {
let Self(ref core, _) = *ManuallyDrop::new(self);
unsafe { <*const _>::read(&**core) }
}
}
impl<T> Drop for OneshotReceiver<T> {
fn drop(&mut self) {
let core = unsafe { ManuallyDrop::take(&mut self.0) };
unsafe { core.drop::<T>() };
}
}
impl<T> Future for OneshotReceiver<T> {
type Output = Result<T, RecvError>;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.get_mut().poll_recv(cx)
}
}
impl<T: MeshField> DefaultEncoding for OneshotReceiver<T> {
type Encoding = PortField;
}
impl<T: MeshField> From<OneshotReceiver<T>> for Port {
fn from(receiver: OneshotReceiver<T>) -> Self {
unsafe { receiver.into_core().into_port::<T>() }
}
}
impl<T: MeshField> From<Port> for OneshotReceiver<T> {
fn from(port: Port) -> Self {
Self(
ManuallyDrop::new(OneshotReceiverCore::from_port::<T>(port)),
PhantomData,
)
}
}
#[derive(Debug)]
struct OneshotReceiverCore {
slot: Arc<Slot>,
port: Option<PortWithHandler<SlotHandler>>,
}
impl OneshotReceiverCore {
unsafe fn drop<T>(self) {
fn clear(this: OneshotReceiverCore) -> Option<BoxedValue> {
let OneshotReceiverCore { slot, port } = this;
drop(port);
let v = if let SlotState::Sent(value) =
std::mem::replace(&mut *slot.0.lock(), SlotState::Done)
{
Some(value)
} else {
None
};
v
}
if let Some(v) = clear(self) {
unsafe { v.drop::<T>() };
}
}
unsafe fn poll_recv<T>(&mut self, cx: &mut Context<'_>) -> Poll<Result<Box<T>, RecvError>> {
fn poll_recv(
this: &mut OneshotReceiverCore,
cx: &mut Context<'_>,
) -> Poll<Result<BoxedValue, RecvError>> {
let v = loop {
let mut state = this.slot.0.lock();
break match std::mem::replace(&mut *state, SlotState::Done) {
SlotState::SenderRemote(port, decode) => {
*state = SlotState::Waiting(None);
drop(state);
assert!(this.port.is_none());
this.port = Some(port.set_handler(SlotHandler {
slot: this.slot.clone(),
decode,
}));
continue;
}
SlotState::Waiting(mut waker) => {
if let Some(waker) = &mut waker {
waker.clone_from(cx.waker());
} else {
waker = Some(cx.waker().clone());
}
*state = SlotState::Waiting(waker);
return Poll::Pending;
}
SlotState::Sent(data) => Ok(data),
SlotState::Done => {
let err = this.port.as_ref().map_or(RecvError::Closed, |port| {
port.is_closed()
.map(|_| RecvError::Closed)
.unwrap_or_else(|err| RecvError::Error(err.into()))
});
Err(err)
}
SlotState::ReceiverRemote { .. } => {
unreachable!()
}
};
};
Poll::Ready(v)
}
ready!(poll_recv(self, cx))
.map(|v| {
unsafe { v.cast::<T>() }
})
.into()
}
unsafe fn into_port<T: MeshField>(self) -> Port {
fn into_port(this: OneshotReceiverCore, send: SendFn) -> Port {
let OneshotReceiverCore { slot, port } = this;
let existing = port.map(|port| port.remove_handler().0);
let mut state = slot.0.lock();
match std::mem::replace(&mut *state, SlotState::Done) {
SlotState::SenderRemote(port, _) => {
assert!(existing.is_none());
port
}
SlotState::Waiting(_) => existing.unwrap_or_else(|| {
let (sender, recv) = Port::new_pair();
*state = SlotState::ReceiverRemote(recv, send);
sender
}),
SlotState::Sent(value) => {
let (sender, recv) = Port::new_pair();
unsafe { send(&sender, value) };
if let Some(existing) = existing {
existing.bridge(sender);
}
recv
}
SlotState::Done => existing.unwrap_or_else(|| Port::new_pair().0),
SlotState::ReceiverRemote { .. } => unreachable!(),
}
}
into_port(self, send_message::<T>)
}
fn from_port<T: MeshField>(port: Port) -> Self {
fn from_port(port: Port, decode: DecodeFn) -> OneshotReceiverCore {
let slot = Arc::new(Slot(Mutex::new(SlotState::SenderRemote(port, decode))));
OneshotReceiverCore { slot, port: None }
}
from_port(port, decode_message::<T>)
}
}
#[derive(Debug)]
enum SlotState {
Done,
Waiting(Option<Waker>),
Sent(BoxedValue),
SenderRemote(Port, DecodeFn),
ReceiverRemote(Port, SendFn),
}
type SendFn = unsafe fn(&Port, BoxedValue);
type DecodeFn = unsafe fn(Message<'_>) -> Result<BoxedValue, ChannelError>;
#[derive(Debug)]
struct BoxedValue(NonNull<()>);
unsafe impl Send for BoxedValue {}
unsafe impl Sync for BoxedValue {}
impl BoxedValue {
fn new<T>(value: Box<T>) -> Self {
Self(NonNull::new(Box::into_raw(value).cast()).unwrap())
}
#[expect(clippy::unnecessary_box_returns)]
unsafe fn cast<T>(self) -> Box<T> {
unsafe { Box::from_raw(self.0.cast::<T>().as_ptr()) }
}
unsafe fn drop<T>(self) {
let _ = unsafe { self.cast::<T>() };
}
}
#[derive(Debug, Error)]
#[error("unexpected oneshot message")]
struct UnexpectedMessage;
struct SlotHandler {
slot: Arc<Slot>,
decode: DecodeFn,
}
impl SlotHandler {
fn close_or_fail(
&mut self,
control: &mut mesh_node::local_node::PortControl<'_, '_>,
fail: bool,
) {
let mut state = self.slot.0.lock();
match std::mem::replace(&mut *state, SlotState::Done) {
SlotState::Waiting(waker) => {
if let Some(waker) = waker {
control.wake(waker);
}
}
SlotState::Sent(v) => {
if !fail {
*state = SlotState::Sent(v);
}
}
SlotState::Done => {}
SlotState::SenderRemote { .. } | SlotState::ReceiverRemote { .. } => unreachable!(),
}
}
}
impl HandlePortEvent for SlotHandler {
fn message(
&mut self,
control: &mut mesh_node::local_node::PortControl<'_, '_>,
message: Message<'_>,
) -> Result<(), HandleMessageError> {
let mut state = self.slot.0.lock();
match std::mem::replace(&mut *state, SlotState::Done) {
SlotState::Waiting(waker) => {
let r = unsafe { (self.decode)(message) };
let value = match r {
Ok(v) => v,
Err(err) => {
*state = SlotState::Waiting(waker);
return Err(HandleMessageError::new(err));
}
};
*state = SlotState::Sent(value);
drop(state);
if let Some(waker) = waker {
control.wake(waker);
}
}
SlotState::Sent(v) => {
*state = SlotState::Sent(v);
return Err(HandleMessageError::new(UnexpectedMessage));
}
SlotState::Done => {
*state = SlotState::Done;
}
SlotState::SenderRemote { .. } | SlotState::ReceiverRemote { .. } => unreachable!(),
}
Ok(())
}
fn close(&mut self, control: &mut mesh_node::local_node::PortControl<'_, '_>) {
self.close_or_fail(control, false);
}
fn fail(
&mut self,
control: &mut mesh_node::local_node::PortControl<'_, '_>,
_err: mesh_node::local_node::NodeError,
) {
self.close_or_fail(control, true);
}
fn drain(&mut self) -> Vec<OwnedMessage> {
Vec::new()
}
}
#[cfg(test)]
mod tests {
use super::oneshot;
use crate::OneshotReceiver;
use crate::OneshotSender;
use crate::RecvError;
use futures::executor::block_on;
use futures::task::SpawnExt;
use futures::FutureExt;
use mesh_node::local_node::Port;
use mesh_node::message::Message;
use std::cell::Cell;
use std::future::poll_fn;
use test_with_tracing::test;
static_assertions::assert_impl_all!(OneshotSender<i32>: Send, Sync);
static_assertions::assert_impl_all!(OneshotReceiver<i32>: Send, Sync);
static_assertions::assert_impl_all!(OneshotSender<Cell<i32>>: Send, Sync);
static_assertions::assert_impl_all!(OneshotReceiver<Cell<i32>>: Send, Sync);
static_assertions::assert_not_impl_any!(OneshotSender<*const ()>: Send, Sync);
static_assertions::assert_not_impl_any!(OneshotReceiver<*const ()>: Send, Sync);
#[test]
fn test_oneshot() {
block_on(async {
let (sender, receiver) = oneshot();
sender.send(String::from("foo"));
assert_eq!(receiver.await.unwrap(), "foo");
})
}
#[test]
fn test_oneshot_convert_sender_port() {
block_on(async {
let (sender, receiver) = oneshot::<String>();
let sender = OneshotSender::<String>::from(Port::from(sender));
sender.send(String::from("foo"));
assert_eq!(receiver.await.unwrap(), "foo");
})
}
#[test]
fn test_oneshot_convert_receiver_port() {
block_on(async {
let (sender, receiver) = oneshot::<String>();
let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
sender.send(String::from("foo"));
assert_eq!(receiver.await.unwrap(), "foo");
})
}
#[test]
fn test_oneshot_convert_receiver_port_after_send() {
block_on(async {
let (sender, receiver) = oneshot::<String>();
sender.send(String::from("foo"));
let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
assert_eq!(receiver.await.unwrap(), "foo");
})
}
#[test]
fn test_oneshot_convert_both() {
block_on(async {
let (sender, receiver) = oneshot::<String>();
let sender = OneshotSender::<String>::from(Port::from(sender));
let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
sender.send(String::from("foo"));
assert_eq!(receiver.await.unwrap(), "foo");
})
}
#[test]
fn test_oneshot_convert_both_poll_first() {
block_on(async {
let (sender, mut receiver) = oneshot::<String>();
let sender = OneshotSender::<String>::from(Port::from(sender));
assert!(poll_fn(|cx| receiver.poll_recv(cx))
.now_or_never()
.is_none());
let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
sender.send(String::from("foo"));
assert_eq!(receiver.await.unwrap(), "foo");
})
}
#[test]
fn test_oneshot_message_corruption() {
let mut pool = futures::executor::LocalPool::new();
let spawner = pool.spawner();
pool.run_until(async {
let (sender, receiver) = oneshot();
let receiver = OneshotReceiver::<i32>::from(Port::from(receiver));
let receiver = spawner.spawn_with_handle(receiver).unwrap();
futures::pending!();
sender.send("text".to_owned());
let RecvError::Error(err) = receiver.await.unwrap_err() else {
panic!()
};
tracing::info!(error = &err as &dyn std::error::Error, "expected error");
})
}
#[test]
fn test_oneshot_extra_messages() {
block_on(async {
let (sender, mut receiver) = oneshot::<()>();
let sender = Port::from(sender);
assert!(futures::poll!(&mut receiver).is_pending());
sender.send(Message::new(()));
sender.send(Message::new(()));
let RecvError::Error(err) = receiver.await.unwrap_err() else {
panic!()
};
tracing::info!(error = &err as &dyn std::error::Error, "expected error");
})
}
#[test]
fn test_oneshot_closed() {
block_on(async {
let (sender, receiver) = oneshot::<()>();
drop(sender);
let RecvError::Closed = receiver.await.unwrap_err() else {
panic!()
};
})
}
}