1#![expect(unsafe_code)]
19
20use crate::ChannelError;
21use crate::RecvError;
22use crate::sync_unsafe_cell::SyncUnsafeCell;
23use mesh_node::local_node::HandleMessageError;
24use mesh_node::local_node::HandlePortEvent;
25use mesh_node::local_node::Port;
26use mesh_node::local_node::PortField;
27use mesh_node::local_node::PortWithHandler;
28use mesh_node::message::MeshField;
29use mesh_node::message::Message;
30use mesh_node::message::OwnedMessage;
31use mesh_protobuf::DefaultEncoding;
32use parking_lot::Mutex;
33use std::fmt::Debug;
34use std::future::Future;
35use std::marker::PhantomData;
36use std::mem::ManuallyDrop;
37use std::ptr::NonNull;
38use std::sync::Arc;
39use std::task::Context;
40use std::task::Poll;
41use std::task::Waker;
42use std::task::ready;
43use thiserror::Error;
44
45pub fn oneshot<T>() -> (OneshotSender<T>, OneshotReceiver<T>) {
71 fn oneshot_core() -> (OneshotSenderCore, OneshotReceiverCore) {
72 let slot = Arc::new(Slot {
73 state: Mutex::new(SlotState::Waiting(None)),
74 receiver: Default::default(),
75 });
76 (OneshotSenderCore(slot.clone()), OneshotReceiverCore(slot))
77 }
78
79 let (sender, receiver) = oneshot_core();
80 (
81 OneshotSender(sender, PhantomData),
82 OneshotReceiver(ManuallyDrop::new(receiver), PhantomData),
83 )
84}
85
86pub struct OneshotSender<T>(OneshotSenderCore, PhantomData<Arc<Mutex<T>>>);
92
93impl<T> Debug for OneshotSender<T> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 Debug::fmt(&self.0, f)
96 }
97}
98
99impl<T> OneshotSender<T> {
100 pub fn send(self, value: T) {
102 unsafe { self.0.send(value) }
104 }
105
106 pub fn is_closed(&self) -> bool {
108 self.0.is_closed()
109 }
110}
111
112impl<T: MeshField> DefaultEncoding for OneshotSender<T> {
113 type Encoding = PortField;
114}
115
116impl<T: MeshField> From<OneshotSender<T>> for Port {
117 fn from(sender: OneshotSender<T>) -> Self {
118 unsafe { sender.0.into_port::<T>() }
120 }
121}
122
123impl<T: MeshField> From<Port> for OneshotSender<T> {
124 fn from(port: Port) -> Self {
125 Self(OneshotSenderCore::from_port::<T>(port), PhantomData)
126 }
127}
128
129unsafe fn send_message<T: MeshField>(port: Port, value: BoxedValue) {
132 let value = unsafe { value.cast::<T>() };
134 port.send_protobuf_and_close((value,));
135}
136
137fn decode_message<T: MeshField>(message: Message<'_>) -> Result<BoxedValue, ChannelError> {
138 let (value,) = message.parse_non_static::<(Box<T>,)>()?;
139 Ok(BoxedValue::new(value))
140}
141
/// State shared (via `Arc`) between the two halves of a oneshot channel.
#[derive(Debug)]
struct Slot {
    /// The channel's current state; see `SlotState`.
    state: Mutex<SlotState>,
    /// Receiver-only state, accessed solely through
    /// `OneshotReceiverCore::split`/`split_mut`. It is `ManuallyDrop`
    /// because the receiver takes it out explicitly; `Slot`'s drop glue
    /// never drops it.
    receiver: SyncUnsafeCell<ManuallyDrop<ReceiverState>>,
}

/// The untyped sending half; holds one reference to the shared slot.
#[derive(Debug)]
struct OneshotSenderCore(Arc<Slot>);

impl Drop for OneshotSenderCore {
    /// Dropping a sender that has not sent closes the channel. `send` and
    /// `into_port` bypass this via `into_slot`.
    fn drop(&mut self) {
        self.close();
    }
}
158
159impl OneshotSenderCore {
160 fn into_slot(self) -> Arc<Slot> {
161 let Self(ref slot) = *ManuallyDrop::new(self);
162 unsafe { <*const _>::read(slot) }
164 }
165
166 fn close(&self) {
167 let mut state = self.0.state.lock();
168 match std::mem::replace(&mut *state, SlotState::Done) {
169 SlotState::Waiting(waker) => {
170 drop(state);
171 if let Some(waker) = waker {
172 waker.wake();
173 }
174 }
175 SlotState::Sent(v) => {
176 *state = SlotState::Sent(v);
177 }
178 SlotState::Done => {}
179 SlotState::ReceiverRemote(port, _) => {
180 drop(port);
181 }
182 SlotState::SenderRemote { .. } => unreachable!(),
183 }
184 }
185
186 fn is_closed(&self) -> bool {
187 match &*self.0.state.lock() {
188 SlotState::Done => true,
189 SlotState::Sent(_) => true,
190 SlotState::Waiting(_) => false,
191 SlotState::ReceiverRemote(port, _) => port.is_closed().unwrap_or(false),
192 SlotState::SenderRemote { .. } => unreachable!(),
193 }
194 }
195
196 unsafe fn send<T>(self, value: T) {
199 fn send(this: OneshotSenderCore, value: BoxedValue) -> Option<BoxedValue> {
200 let slot = this.into_slot();
201 let mut state = slot.state.lock();
202 match std::mem::replace(&mut *state, SlotState::Done) {
203 SlotState::ReceiverRemote(port, send) => {
204 unsafe { send(port, value) };
207 None
208 }
209 SlotState::Waiting(waker) => {
210 *state = SlotState::Sent(value);
211 drop(state);
212 if let Some(waker) = waker {
213 waker.wake();
214 }
215 None
216 }
217 SlotState::Done => Some(value),
218 SlotState::Sent { .. } | SlotState::SenderRemote { .. } => unreachable!(),
219 }
220 }
221 if let Some(value) = send(self, BoxedValue::new(Box::new(value))) {
222 unsafe { value.drop::<T>() };
224 }
225 }
226
227 unsafe fn into_port<T: MeshField>(self) -> Port {
230 fn into_port(this: OneshotSenderCore, decode: DecodeFn) -> Port {
231 let slot = this.into_slot();
232 let mut state = slot.state.lock();
233 match std::mem::replace(&mut *state, SlotState::Done) {
234 SlotState::Waiting(waker) => {
235 let (send, recv) = Port::new_pair();
236 *state = SlotState::SenderRemote(recv, decode);
237 drop(state);
238 if let Some(waker) = waker {
239 waker.wake();
240 }
241 send
242 }
243 SlotState::ReceiverRemote(port, _) => port,
244 SlotState::Done => Port::new_pair().0,
245 SlotState::Sent(_) | SlotState::SenderRemote { .. } => unreachable!(),
246 }
247 }
248 into_port(self, decode_message::<T>)
249 }
250
251 fn from_port<T: MeshField>(port: Port) -> Self {
252 fn from_port(port: Port, send: SendFn) -> OneshotSenderCore {
253 let slot = Arc::new(Slot {
254 state: Mutex::new(SlotState::ReceiverRemote(port, send)),
255 receiver: Default::default(),
256 });
257 OneshotSenderCore(slot)
258 }
259 from_port(port, send_message::<T>)
260 }
261}
262
263pub struct OneshotReceiver<T>(
271 ManuallyDrop<OneshotReceiverCore>,
272 PhantomData<Arc<Mutex<T>>>,
273);
274
275impl<T> Debug for OneshotReceiver<T> {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 Debug::fmt(&self.0, f)
278 }
279}
280
impl<T> OneshotReceiver<T> {
    /// Polls for the value, registering `cx`'s waker if it has not arrived yet.
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
        // SAFETY: the core belongs to a channel whose value type is `T`.
        let v = unsafe { ready!(self.0.poll_recv(cx))? };
        Ok(*v).into()
    }

    /// Moves the core out without running this type's `Drop` impl.
    fn into_core(self) -> OneshotReceiverCore {
        let Self(ref core, _) = *ManuallyDrop::new(self);
        // SAFETY: `self` is inside `ManuallyDrop`, so the core is read out
        // exactly once and `OneshotReceiver::drop` never runs on it.
        unsafe { <*const _>::read(&**core) }
    }
}

impl<T> Drop for OneshotReceiver<T> {
    fn drop(&mut self) {
        // SAFETY: `self.0` is not accessed again after being taken.
        let core = unsafe { ManuallyDrop::take(&mut self.0) };
        // SAFETY: the channel's value type is `T`, so any value still buffered
        // in the slot is a `T`.
        unsafe { core.drop::<T>() };
    }
}
303
304impl<T> Future for OneshotReceiver<T> {
307 type Output = Result<T, RecvError>;
308
309 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
310 self.get_mut().poll_recv(cx)
311 }
312}
313
314impl<T: MeshField> DefaultEncoding for OneshotReceiver<T> {
315 type Encoding = PortField;
316}
317
318impl<T: MeshField> From<OneshotReceiver<T>> for Port {
319 fn from(receiver: OneshotReceiver<T>) -> Self {
320 unsafe { receiver.into_core().into_port::<T>() }
322 }
323}
324
325impl<T: MeshField> From<Port> for OneshotReceiver<T> {
326 fn from(port: Port) -> Self {
327 Self(
328 ManuallyDrop::new(OneshotReceiverCore::from_port::<T>(port)),
329 PhantomData,
330 )
331 }
332}
333
/// The untyped receiving half; holds one reference to the shared slot.
#[derive(Debug)]
struct OneshotReceiverCore(Arc<Slot>);

/// Receiver-only state stored in `Slot::receiver`.
#[derive(Default)]
struct ReceiverState {
    /// The remote sender's port, once `poll_recv` has installed a
    /// `SlotHandler` on it.
    port: Option<PortWithHandler<SlotHandler>>,
}
341
impl OneshotReceiverCore {
    /// Consumes the core, moving the receiver-only state out of the slot.
    ///
    /// Afterwards the slot's `receiver` field is logically moved-from; it is
    /// `ManuallyDrop`, so the slot will not drop it again.
    fn split(self) -> (Arc<Slot>, ReceiverState) {
        // SAFETY: only the receiver core accesses `slot.receiver`, and `self`
        // is consumed here, so nothing else can observe or repeat this take.
        let receiver = unsafe { ManuallyDrop::take(&mut *self.0.receiver.0.get()) };
        (self.0, receiver)
    }

    /// Borrows the slot and, mutably, the receiver-only state.
    fn split_mut(&mut self) -> (&Arc<Slot>, &mut ReceiverState) {
        // SAFETY: only the receiver core accesses `slot.receiver`, and the
        // `&mut self` borrow makes this the only live reference to it.
        let receiver = unsafe { &mut *self.0.receiver.0.get() };
        (&self.0, receiver)
    }

    /// Releases the receiver.
    ///
    /// Drops its port (if any), marks the slot `Done`, and drops any value that
    /// was sent but never received.
    ///
    /// # Safety
    ///
    /// `T` must be the channel's value type.
    unsafe fn drop<T>(self) {
        fn clear(this: OneshotReceiverCore) -> Option<BoxedValue> {
            let (slot, ReceiverState { port }) = this.split();
            // Drop the remote port (if any) before marking the slot done.
            drop(port);
            // NOTE: binding to `v` rather than returning the `if let` directly
            // presumably keeps the lock-guard temporary from outliving `slot`.
            let v = if let SlotState::Sent(value) =
                std::mem::replace(&mut *slot.state.lock(), SlotState::Done)
            {
                Some(value)
            } else {
                None
            };
            v
        }
        if let Some(v) = clear(self) {
            // SAFETY: the caller guarantees the buffered value is a `T`.
            unsafe { v.drop::<T>() };
        }
    }

    /// Polls for the sent value.
    ///
    /// When the sender is remote and no handler has been installed yet, this
    /// installs a `SlotHandler` on the port so incoming messages update the
    /// slot, then re-examines the state.
    ///
    /// # Safety
    ///
    /// `T` must be the channel's value type.
    unsafe fn poll_recv<T>(&mut self, cx: &mut Context<'_>) -> Poll<Result<Box<T>, RecvError>> {
        // Non-generic inner function: compiled once rather than per `T`.
        fn poll_recv(
            this: &mut OneshotReceiverCore,
            cx: &mut Context<'_>,
        ) -> Poll<Result<BoxedValue, RecvError>> {
            let (slot, recv) = this.split_mut();
            let v = loop {
                let mut state = slot.state.lock();
                break match std::mem::replace(&mut *state, SlotState::Done) {
                    SlotState::SenderRemote(port, decode) => {
                        // Switch to waiting, and release the lock before
                        // installing the handler (handler callbacks lock the
                        // state themselves). Then loop to re-check the state.
                        *state = SlotState::Waiting(None);
                        drop(state);
                        assert!(recv.port.is_none());
                        recv.port = Some(port.set_handler(SlotHandler {
                            slot: slot.clone(),
                            decode,
                        }));
                        continue;
                    }
                    SlotState::Waiting(mut waker) => {
                        // Store or refresh the waker, then keep waiting.
                        if let Some(waker) = &mut waker {
                            waker.clone_from(cx.waker());
                        } else {
                            waker = Some(cx.waker().clone());
                        }
                        *state = SlotState::Waiting(waker);
                        return Poll::Pending;
                    }
                    SlotState::Sent(data) => Ok(data),
                    SlotState::Done => {
                        // No value will arrive. With a remote port, surface its
                        // failure (if any) instead of a plain close.
                        let err = recv.port.as_ref().map_or(RecvError::Closed, |port| {
                            port.is_closed()
                                .map(|_| RecvError::Closed)
                                .unwrap_or_else(|err| RecvError::Error(err.into()))
                        });
                        Err(err)
                    }
                    // `ReceiverRemote` is only set once the receiver has been
                    // consumed by `into_port`, or in sender-only slots.
                    SlotState::ReceiverRemote { .. } => {
                        unreachable!()
                    }
                };
            };
            Poll::Ready(v)
        }
        ready!(poll_recv(self, cx))
            .map(|v| {
                // SAFETY: the caller guarantees the value is a `T`.
                unsafe { v.cast::<T>() }
            })
            .into()
    }

    /// Converts the receiver into a port from which the value can be read
    /// remotely.
    ///
    /// # Safety
    ///
    /// `T` must be the channel's value type.
    unsafe fn into_port<T: MeshField>(self) -> Port {
        // Non-generic inner function: compiled once rather than per `T`.
        fn into_port(this: OneshotReceiverCore, send: SendFn) -> Port {
            let (slot, ReceiverState { port }) = this.split();
            // If a handler was installed, detach it and keep the bare port.
            let existing = port.map(|port| port.remove_handler().0);
            let mut state = slot.state.lock();
            match std::mem::replace(&mut *state, SlotState::Done) {
                // Remote sender, never polled: forward its port unchanged.
                SlotState::SenderRemote(port, _) => {
                    assert!(existing.is_none());
                    port
                }
                // Still waiting. Reuse the remote sender's port if there is
                // one. Otherwise the sender is local: it keeps `recv` (via
                // `ReceiverRemote`) and will send the value on it.
                SlotState::Waiting(_) => existing.unwrap_or_else(|| {
                    let (sender, recv) = Port::new_pair();
                    *state = SlotState::ReceiverRemote(recv, send);
                    sender
                }),
                // The value is already here: send it on a fresh pair and hand
                // out the other end.
                SlotState::Sent(value) => {
                    let (sender, recv) = Port::new_pair();
                    // SAFETY: the caller guarantees `value` has the type that
                    // `send` was created for.
                    unsafe { send(sender, value) };
                    recv
                }
                // No value will arrive: return the existing port, or one whose
                // peer has been dropped.
                SlotState::Done => existing.unwrap_or_else(|| Port::new_pair().0),
                SlotState::ReceiverRemote { .. } => unreachable!(),
            }
        }
        into_port(self, send_message::<T>)
    }

    /// Creates a receiver core for a port whose peer is a remote sender of `T`.
    fn from_port<T: MeshField>(port: Port) -> Self {
        fn from_port(port: Port, decode: DecodeFn) -> OneshotReceiverCore {
            let slot = Arc::new(Slot {
                state: Mutex::new(SlotState::SenderRemote(port, decode)),
                receiver: Default::default(),
            });
            OneshotReceiverCore(slot)
        }
        from_port(port, decode_message::<T>)
    }
}
487
/// The state machine shared by both halves of a oneshot channel.
#[derive(Debug)]
enum SlotState {
    /// Finished: the value was taken, or one side went away without it.
    Done,
    /// No value yet; holds the local receiver's waker once it has polled.
    Waiting(Option<Waker>),
    /// A value has been sent and not yet received.
    Sent(BoxedValue),
    /// The sender is remote: the value arrives on this port and is decoded
    /// with this function.
    SenderRemote(Port, DecodeFn),
    /// The receiver is remote: the value is sent on this port with this
    /// function.
    ReceiverRemote(Port, SendFn),
}

/// Type-erased `send_message::<T>`.
type SendFn = unsafe fn(Port, BoxedValue);
/// Type-erased `decode_message::<T>`.
type DecodeFn = unsafe fn(Message<'_>) -> Result<BoxedValue, ChannelError>;
499
/// A type-erased heap value: a `Box<T>` whose `T` has been forgotten.
///
/// There is no `Drop` impl. The owner must recover the type with `cast` or
/// release the value with `drop`; otherwise the allocation leaks.
#[derive(Debug)]
struct BoxedValue(NonNull<()>);

// SAFETY: a `BoxedValue` is only created and consumed by the channel halves.
// Their `PhantomData<Arc<Mutex<T>>>` markers restrict cross-thread use to
// channels whose `T` may move between threads.
unsafe impl Send for BoxedValue {}
// SAFETY: a shared `&BoxedValue` gives no access to the pointee, only to the
// pointer value itself.
unsafe impl Sync for BoxedValue {}

impl BoxedValue {
    /// Erases the type of `value`, taking ownership of its allocation.
    fn new<T>(value: Box<T>) -> Self {
        let typed: NonNull<T> = NonNull::from(Box::leak(value));
        Self(typed.cast())
    }

    /// Reclaims the value as a `Box<T>`.
    ///
    /// # Safety
    ///
    /// `self` must have been created by `BoxedValue::new` from a `Box<T>` with
    /// this same `T`.
    #[expect(clippy::unnecessary_box_returns)]
    unsafe fn cast<T>(self) -> Box<T> {
        let typed = self.0.cast::<T>();
        // SAFETY: per the contract, `typed` is the pointer of a leaked `Box<T>`.
        unsafe { Box::from_raw(typed.as_ptr()) }
    }

    /// Drops the value as a `T`.
    ///
    /// # Safety
    ///
    /// Same requirement as [`BoxedValue::cast`].
    unsafe fn drop<T>(self) {
        // SAFETY: forwarded from the caller.
        std::mem::drop(unsafe { self.cast::<T>() });
    }
}
532
/// Error that `SlotHandler::message` returns when a message arrives after a
/// value has already been stored in the slot.
#[derive(Debug, Error)]
#[error("unexpected oneshot message")]
struct UnexpectedMessage;
536
/// Port handler that the receiver installs on a remote sender's port. It
/// moves the decoded value into the shared slot.
struct SlotHandler {
    /// The slot shared with the local receiver.
    slot: Arc<Slot>,
    /// Decodes the sender's message into a type-erased value.
    decode: DecodeFn,
}

impl SlotHandler {
    /// Handles the port closing (`fail == false`) or failing (`fail == true`).
    ///
    /// Marks the slot `Done` and wakes a waiting receiver. A value that was
    /// already received is kept on close but discarded on failure, so the
    /// receiver reports the port's error instead.
    fn close_or_fail(
        &mut self,
        control: &mut mesh_node::local_node::PortControl<'_, '_>,
        fail: bool,
    ) {
        let mut state = self.slot.state.lock();
        match std::mem::replace(&mut *state, SlotState::Done) {
            SlotState::Waiting(waker) => {
                if let Some(waker) = waker {
                    control.wake(waker);
                }
            }
            SlotState::Sent(v) => {
                if !fail {
                    *state = SlotState::Sent(v);
                }
                // NOTE(review): when `fail` is true, `v` is dropped here, but
                // `BoxedValue` has no `Drop` impl, so the allocation (and the
                // value's destructor) leaks. The handler does not know `T`;
                // fixing this needs a type-erased drop fn stored next to
                // `decode`.
            }
            SlotState::Done => {}
            SlotState::SenderRemote { .. } | SlotState::ReceiverRemote { .. } => unreachable!(),
        }
    }
}
565
impl HandlePortEvent for SlotHandler {
    /// Decodes the sender's message into the slot and wakes the receiver.
    ///
    /// Returns the decode error if the message is malformed, or
    /// `UnexpectedMessage` if a value has already been stored.
    fn message(
        &mut self,
        control: &mut mesh_node::local_node::PortControl<'_, '_>,
        message: Message<'_>,
    ) -> Result<(), HandleMessageError> {
        let mut state = self.slot.state.lock();
        match std::mem::replace(&mut *state, SlotState::Done) {
            SlotState::Waiting(waker) => {
                // SAFETY: `decode` is always a safe `decode_message::<T>`
                // coerced to an `unsafe fn` pointer, so it has no extra
                // requirements.
                let r = unsafe { (self.decode)(message) };
                let value = match r {
                    Ok(v) => v,
                    Err(err) => {
                        // Leave the receiver waiting. The returned error
                        // presumably fails the port, which then reaches `fail`.
                        *state = SlotState::Waiting(waker);
                        return Err(HandleMessageError::new(err));
                    }
                };
                *state = SlotState::Sent(value);
                drop(state);
                if let Some(waker) = waker {
                    control.wake(waker);
                }
            }
            SlotState::Sent(v) => {
                // Only one message is allowed: keep the state as is and report
                // the extra message as an error.
                *state = SlotState::Sent(v);
                return Err(HandleMessageError::new(UnexpectedMessage));
            }
            SlotState::Done => {
                // No value is expected any more; the message is ignored.
                *state = SlotState::Done;
            }
            SlotState::SenderRemote { .. } | SlotState::ReceiverRemote { .. } => unreachable!(),
        }
        Ok(())
    }

    /// The port was closed: keep any value already received.
    fn close(&mut self, control: &mut mesh_node::local_node::PortControl<'_, '_>) {
        self.close_or_fail(control, false);
    }

    /// The port failed: discard any received value so the failure is reported.
    fn fail(
        &mut self,
        control: &mut mesh_node::local_node::PortControl<'_, '_>,
        _err: mesh_node::local_node::NodeError,
    ) {
        self.close_or_fail(control, true);
    }

    /// This handler never queues messages, so there is nothing to return.
    fn drain(&mut self) -> Vec<OwnedMessage> {
        Vec::new()
    }
}
621
#[cfg(test)]
mod tests {
    use super::oneshot;
    use crate::OneshotReceiver;
    use crate::OneshotSender;
    use crate::RecvError;
    use futures::FutureExt;
    use futures::executor::block_on;
    use futures::task::SpawnExt;
    use mesh_node::local_node::Port;
    use mesh_node::message::Message;
    use std::cell::Cell;
    use std::future::poll_fn;
    use test_with_tracing::test;

    // Both halves are Send + Sync when `T: Send` (even if `T` is not Sync, as
    // with `Cell`), and neither is when `T` is not Send.
    static_assertions::assert_impl_all!(OneshotSender<i32>: Send, Sync);
    static_assertions::assert_impl_all!(OneshotReceiver<i32>: Send, Sync);
    static_assertions::assert_impl_all!(OneshotSender<Cell<i32>>: Send, Sync);
    static_assertions::assert_impl_all!(OneshotReceiver<Cell<i32>>: Send, Sync);
    static_assertions::assert_not_impl_any!(OneshotSender<*const ()>: Send, Sync);
    static_assertions::assert_not_impl_any!(OneshotReceiver<*const ()>: Send, Sync);

    /// Local send followed by a local receive.
    #[test]
    fn test_oneshot() {
        block_on(async {
            let (sender, receiver) = oneshot();
            sender.send(String::from("foo"));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// The sender round-trips through a port before sending.
    #[test]
    fn test_oneshot_convert_sender_port() {
        block_on(async {
            let (sender, receiver) = oneshot::<String>();
            let sender = OneshotSender::<String>::from(Port::from(sender));
            sender.send(String::from("foo"));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// The receiver round-trips through a port before the value is sent.
    #[test]
    fn test_oneshot_convert_receiver_port() {
        block_on(async {
            let (sender, receiver) = oneshot::<String>();
            let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
            sender.send(String::from("foo"));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// The receiver is converted after the value is already in the slot.
    #[test]
    fn test_oneshot_convert_receiver_port_after_send() {
        block_on(async {
            let (sender, receiver) = oneshot::<String>();
            sender.send(String::from("foo"));
            let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// Both halves round-trip through ports.
    #[test]
    fn test_oneshot_convert_both() {
        block_on(async {
            let (sender, receiver) = oneshot::<String>();
            let sender = OneshotSender::<String>::from(Port::from(sender));
            let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
            sender.send(String::from("foo"));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// The receiver is polled while the sender is remote, which installs its
    /// port handler, and is then converted to a port as well.
    #[test]
    fn test_oneshot_convert_both_poll_first() {
        block_on(async {
            let (sender, mut receiver) = oneshot::<String>();
            let sender = OneshotSender::<String>::from(Port::from(sender));
            assert!(
                poll_fn(|cx| receiver.poll_recv(cx))
                    .now_or_never()
                    .is_none()
            );
            let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
            sender.send(String::from("foo"));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// The channel carries `String` (inferred from the `send` below), but the
    /// far end of the port is read as `i32`. Decoding therefore fails and the
    /// receiver reports `RecvError::Error`.
    #[test]
    fn test_oneshot_message_corruption() {
        let mut pool = futures::executor::LocalPool::new();
        let spawner = pool.spawner();
        pool.run_until(async {
            let (sender, receiver) = oneshot();
            let receiver = OneshotReceiver::<i32>::from(Port::from(receiver));
            let receiver = spawner.spawn_with_handle(receiver).unwrap();
            // Yield once, presumably so the spawned receiver is polled (and
            // waiting on its port) before the value is sent.
            futures::pending!();
            sender.send("text".to_owned());
            let RecvError::Error(err) = receiver.await.unwrap_err() else {
                panic!()
            };
            tracing::info!(error = &err as &dyn std::error::Error, "expected error");
        })
    }

    /// A second message after the value is rejected (`UnexpectedMessage`),
    /// and the receiver reports an error instead of the value.
    #[test]
    fn test_oneshot_extra_messages() {
        block_on(async {
            let (sender, mut receiver) = oneshot::<()>();
            let sender = Port::from(sender);
            assert!(futures::poll!(&mut receiver).is_pending());
            sender.send(Message::new(()));
            sender.send(Message::new(()));
            let RecvError::Error(err) = receiver.await.unwrap_err() else {
                panic!()
            };
            tracing::info!(error = &err as &dyn std::error::Error, "expected error");
        })
    }

    /// Dropping the sender without sending yields `RecvError::Closed`.
    #[test]
    fn test_oneshot_closed() {
        block_on(async {
            let (sender, receiver) = oneshot::<()>();
            drop(sender);
            let RecvError::Closed = receiver.await.unwrap_err() else {
                panic!()
            };
        })
    }
}