1#![expect(unsafe_code)]
19
20use crate::ChannelError;
21use crate::RecvError;
22use crate::sync_unsafe_cell::SyncUnsafeCell;
23use mesh_node::local_node::HandleMessageError;
24use mesh_node::local_node::HandlePortEvent;
25use mesh_node::local_node::Port;
26use mesh_node::local_node::PortField;
27use mesh_node::local_node::PortWithHandler;
28use mesh_node::message::MeshField;
29use mesh_node::message::Message;
30use mesh_node::message::OwnedMessage;
31use mesh_protobuf::DefaultEncoding;
32use parking_lot::Mutex;
33use std::fmt::Debug;
34use std::future::Future;
35use std::marker::PhantomData;
36use std::mem::ManuallyDrop;
37use std::ptr::NonNull;
38use std::sync::Arc;
39use std::task::Context;
40use std::task::Poll;
41use std::task::Waker;
42use std::task::ready;
43use thiserror::Error;
44
45pub fn oneshot<T>() -> (OneshotSender<T>, OneshotReceiver<T>) {
71 fn oneshot_core() -> (OneshotSenderCore, OneshotReceiverCore) {
72 let slot = Arc::new(Slot {
73 state: Mutex::new(SlotState::Waiting(None)),
74 receiver: Default::default(),
75 });
76 (OneshotSenderCore(slot.clone()), OneshotReceiverCore(slot))
77 }
78
79 let (sender, receiver) = oneshot_core();
80 (
81 OneshotSender(sender, PhantomData),
82 OneshotReceiver(ManuallyDrop::new(receiver), PhantomData),
83 )
84}
85
86pub struct OneshotSender<T>(OneshotSenderCore, PhantomData<Arc<Mutex<T>>>);
92
93impl<T> Debug for OneshotSender<T> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 Debug::fmt(&self.0, f)
96 }
97}
98
99impl<T> OneshotSender<T> {
100 pub fn send(self, value: T) {
102 unsafe { self.0.send(value) }
104 }
105
106 pub fn is_closed(&self) -> bool {
108 self.0.is_closed()
109 }
110}
111
112impl<T: MeshField> DefaultEncoding for OneshotSender<T> {
113 type Encoding = PortField;
114}
115
116impl<T: MeshField> From<OneshotSender<T>> for Port {
117 fn from(sender: OneshotSender<T>) -> Self {
118 unsafe { sender.0.into_port::<T>() }
120 }
121}
122
123impl<T: MeshField> From<Port> for OneshotSender<T> {
124 fn from(port: Port) -> Self {
125 Self(OneshotSenderCore::from_port::<T>(port), PhantomData)
126 }
127}
128
129unsafe fn send_message<T: MeshField>(port: Port, value: BoxedValue) {
132 let value = unsafe { value.cast::<T>() };
134 port.send_protobuf_and_close((value,));
135}
136
137fn decode_message<T: MeshField>(message: Message<'_>) -> Result<BoxedValue, ChannelError> {
138 let (value,) = message.parse_non_static::<(Box<T>,)>()?;
139 Ok(BoxedValue::new(value))
140}
141
/// Shared storage behind one local half (or a local pair) of a oneshot
/// channel.
#[derive(Debug)]
struct Slot {
    /// Channel state, shared by the sender, the receiver and any attached
    /// `SlotHandler`.
    state: Mutex<SlotState>,
    /// State private to the receiving half. In this file it is only accessed
    /// through `OneshotReceiverCore::split`/`split_mut`, which is why an
    /// unsynchronized cell is used. It is wrapped in `ManuallyDrop` because
    /// `split` moves it out while other `Arc<Slot>` references may still exist.
    receiver: SyncUnsafeCell<ManuallyDrop<ReceiverState>>,
}
149
150#[derive(Debug)]
151struct OneshotSenderCore(Arc<Slot>);
152
153impl Drop for OneshotSenderCore {
154 fn drop(&mut self) {
155 self.close();
156 }
157}
158
159impl OneshotSenderCore {
160 fn into_slot(self) -> Arc<Slot> {
161 let Self(ref slot) = *ManuallyDrop::new(self);
162 unsafe { <*const _>::read(slot) }
164 }
165
166 fn close(&self) {
167 let mut state = self.0.state.lock();
168 match std::mem::replace(&mut *state, SlotState::Done) {
169 SlotState::Waiting(waker) => {
170 drop(state);
171 if let Some(waker) = waker {
172 waker.wake();
173 }
174 }
175 SlotState::Sent(v) => {
176 *state = SlotState::Sent(v);
177 }
178 SlotState::Done => {}
179 SlotState::ReceiverRemote(port, _) => {
180 drop(port);
181 }
182 SlotState::SenderRemote { .. } => unreachable!(),
183 }
184 }
185
186 fn is_closed(&self) -> bool {
187 match &*self.0.state.lock() {
188 SlotState::Done => true,
189 SlotState::Sent(_) => true,
190 SlotState::Waiting(_) => false,
191 SlotState::ReceiverRemote(port, _) => port.is_closed().unwrap_or(false),
192 SlotState::SenderRemote { .. } => unreachable!(),
193 }
194 }
195
196 unsafe fn send<T>(self, value: T) {
199 fn send(this: OneshotSenderCore, value: BoxedValue) -> Option<BoxedValue> {
200 let slot = this.into_slot();
201 let mut state = slot.state.lock();
202 match std::mem::replace(&mut *state, SlotState::Done) {
203 SlotState::ReceiverRemote(port, send) => {
204 unsafe { send(port, value) };
207 None
208 }
209 SlotState::Waiting(waker) => {
210 *state = SlotState::Sent(value);
211 drop(state);
212 if let Some(waker) = waker {
213 waker.wake();
214 }
215 None
216 }
217 SlotState::Done => Some(value),
218 SlotState::Sent { .. } | SlotState::SenderRemote { .. } => unreachable!(),
219 }
220 }
221 if let Some(value) = send(self, BoxedValue::new(Box::new(value))) {
222 unsafe { value.drop::<T>() };
224 }
225 }
226
227 unsafe fn into_port<T: MeshField>(self) -> Port {
230 fn into_port(this: OneshotSenderCore, decode: DecodeFn) -> Port {
231 let slot = this.into_slot();
232 let mut state = slot.state.lock();
233 match std::mem::replace(&mut *state, SlotState::Done) {
234 SlotState::Waiting(waker) => {
235 let (send, recv) = Port::new_pair();
236 *state = SlotState::SenderRemote(recv, decode);
237 drop(state);
238 if let Some(waker) = waker {
239 waker.wake();
240 }
241 send
242 }
243 SlotState::ReceiverRemote(port, _) => port,
244 SlotState::Done => Port::new_pair().0,
245 SlotState::Sent(_) | SlotState::SenderRemote { .. } => unreachable!(),
246 }
247 }
248 into_port(self, decode_message::<T>)
249 }
250
251 fn from_port<T: MeshField>(port: Port) -> Self {
252 fn from_port(port: Port, send: SendFn) -> OneshotSenderCore {
253 let slot = Arc::new(Slot {
254 state: Mutex::new(SlotState::ReceiverRemote(port, send)),
255 receiver: Default::default(),
256 });
257 OneshotSenderCore(slot)
258 }
259 from_port(port, send_message::<T>)
260 }
261}
262
263pub struct OneshotReceiver<T>(
271 ManuallyDrop<OneshotReceiverCore>,
272 PhantomData<Arc<Mutex<T>>>,
273);
274
275impl<T> Debug for OneshotReceiver<T> {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 Debug::fmt(&self.0, f)
278 }
279}
280
281impl<T> OneshotReceiver<T> {
282 fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
283 let v = unsafe { ready!(self.0.poll_recv(cx))? };
285 Ok(*v).into()
286 }
287
288 fn into_core(self) -> OneshotReceiverCore {
289 let Self(ref core, _) = *ManuallyDrop::new(self);
290 unsafe { <*const _>::read(&**core) }
292 }
293}
294
295impl<T> Drop for OneshotReceiver<T> {
296 fn drop(&mut self) {
297 let core = unsafe { ManuallyDrop::take(&mut self.0) };
299 unsafe { core.drop::<T>() };
301 }
302}
303
304impl<T> Future for OneshotReceiver<T> {
307 type Output = Result<T, RecvError>;
308
309 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
310 self.get_mut().poll_recv(cx)
311 }
312}
313
314impl<T: MeshField> DefaultEncoding for OneshotReceiver<T> {
315 type Encoding = PortField;
316}
317
318impl<T: MeshField> From<OneshotReceiver<T>> for Port {
319 fn from(receiver: OneshotReceiver<T>) -> Self {
320 unsafe { receiver.into_core().into_port::<T>() }
322 }
323}
324
325impl<T: MeshField> From<Port> for OneshotReceiver<T> {
326 fn from(port: Port) -> Self {
327 Self(
328 ManuallyDrop::new(OneshotReceiverCore::from_port::<T>(port)),
329 PhantomData,
330 )
331 }
332}
333
/// Untyped receiving half: the part of [`OneshotReceiver`] that does not
/// depend on `T`. It owns the slot's `receiver` state.
#[derive(Debug)]
struct OneshotReceiverCore(Arc<Slot>);

/// Receiver-private state, stored in `Slot::receiver`.
#[derive(Default)]
struct ReceiverState {
    /// The remote sender's port, with a `SlotHandler` attached by `poll_recv`.
    /// It is `None` while the sender is local or before the first poll.
    port: Option<PortWithHandler<SlotHandler>>,
}
341
impl OneshotReceiverCore {
    /// Consumes the core, returning the shared slot and moving the
    /// receiver-private state out of it.
    fn split(self) -> (Arc<Slot>, ReceiverState) {
        // SAFETY: only this receiver accesses `Slot::receiver`, and it is being
        // consumed, so there is no concurrent access. The `ManuallyDrop` value
        // is taken exactly once.
        let receiver = unsafe { ManuallyDrop::take(&mut *self.0.receiver.0.get()) };
        (self.0, receiver)
    }

    /// Borrows the shared slot, plus the receiver-private state mutably.
    fn split_mut(&mut self) -> (&Arc<Slot>, &mut ReceiverState) {
        // SAFETY: only this receiver accesses `Slot::receiver`, and `&mut self`
        // makes this access exclusive for the returned lifetime.
        let receiver = unsafe { &mut *self.0.receiver.0.get() };
        (&self.0, receiver)
    }

    /// Drops the receiver, detaching any port handler and dropping any value
    /// that was sent but never received. The slot is left `Done`.
    ///
    /// # Safety
    ///
    /// `T` must be the value type this slot was created for.
    unsafe fn drop<T>(self) {
        // Non-generic part. It returns the undelivered value, if any, so it
        // can be dropped with the right type.
        fn clear(this: OneshotReceiverCore) -> Option<BoxedValue> {
            let (slot, ReceiverState { port }) = this.split();
            // Drop the remote port (and its handler) before inspecting the
            // state. Presumably this stops the handler from storing a value
            // afterwards; TODO confirm against `PortWithHandler`'s drop.
            drop(port);
            if let SlotState::Sent(value) =
                // `Done` tells a local sender that the receiver is gone.
                std::mem::replace(&mut *slot.state.lock(), SlotState::Done)
            {
                Some(value)
            } else {
                None
            }
        }
        if let Some(v) = clear(self) {
            // SAFETY: the caller guarantees that the slot's values are `T`s.
            unsafe { v.drop::<T>() };
        }
    }

    /// Polls for the value.
    ///
    /// On the first poll of a receiver whose sender is remote, this attaches a
    /// `SlotHandler` to the sender's port.
    ///
    /// # Safety
    ///
    /// `T` must be the value type this slot was created for.
    unsafe fn poll_recv<T>(&mut self, cx: &mut Context<'_>) -> Poll<Result<Box<T>, RecvError>> {
        fn poll_recv(
            this: &mut OneshotReceiverCore,
            cx: &mut Context<'_>,
        ) -> Poll<Result<BoxedValue, RecvError>> {
            let (slot, recv) = this.split_mut();
            let v = loop {
                let mut state = slot.state.lock();
                break match std::mem::replace(&mut *state, SlotState::Done) {
                    SlotState::SenderRemote(port, decode) => {
                        // The sender is remote. Attach a handler that fills the
                        // slot when the message arrives, then re-examine the
                        // state, which the handler may already have updated.
                        *state = SlotState::Waiting(None);
                        // Unlock before `set_handler`, since the handler locks
                        // `state` itself (presumably it may run immediately).
                        drop(state);
                        assert!(recv.port.is_none());
                        recv.port = Some(port.set_handler(SlotHandler {
                            slot: slot.clone(),
                            decode,
                        }));
                        continue;
                    }
                    SlotState::Waiting(mut waker) => {
                        // Nothing yet. Store or refresh this task's waker,
                        // reusing the existing one via `clone_from` when
                        // present.
                        if let Some(waker) = &mut waker {
                            waker.clone_from(cx.waker());
                        } else {
                            waker = Some(cx.waker().clone());
                        }
                        *state = SlotState::Waiting(waker);
                        return Poll::Pending;
                    }
                    // Take the value; the slot stays `Done`.
                    SlotState::Sent(data) => Ok(data),
                    SlotState::Done => {
                        // No value is coming. If there is a remote port, a
                        // failure (`Err`) is reported as `RecvError::Error`;
                        // otherwise the channel simply closed.
                        let err = recv.port.as_ref().map_or(RecvError::Closed, |port| {
                            port.is_closed()
                                .map(|_| RecvError::Closed)
                                .unwrap_or_else(|err| RecvError::Error(err.into()))
                        });
                        Err(err)
                    }
                    // Only sender-side slots hold `ReceiverRemote`.
                    SlotState::ReceiverRemote { .. } => {
                        unreachable!()
                    }
                };
            };
            Poll::Ready(v)
        }
        ready!(poll_recv(self, cx))
            .map(|v| {
                // SAFETY: the caller guarantees that the slot's values are `T`s.
                unsafe { v.cast::<T>() }
            })
            .into()
    }

    /// Converts the receiver into a port whose peer leads to the sender (or to
    /// the already-sent value).
    ///
    /// # Safety
    ///
    /// `T` must be the value type this slot was created for.
    unsafe fn into_port<T: MeshField>(self) -> Port {
        fn into_port(this: OneshotReceiverCore, send: SendFn) -> Port {
            let (slot, ReceiverState { port }) = this.split();
            // Detach our handler (before taking the lock) and keep the bare
            // port, if there is one.
            let existing = port.map(|port| port.remove_handler().0);
            let mut state = slot.state.lock();
            match std::mem::replace(&mut *state, SlotState::Done) {
                SlotState::SenderRemote(port, _) => {
                    // Never polled, so no handler was attached. Hand over the
                    // sender's port as-is.
                    assert!(existing.is_none());
                    port
                }
                // Nothing received yet. A polled remote sender's port is
                // forwarded directly. Otherwise the sender is local: give it
                // the far end of a new pair (`ReceiverRemote`) so its `send`
                // goes over the port.
                SlotState::Waiting(_) => existing.unwrap_or_else(|| {
                    let (sender, recv) = Port::new_pair();
                    *state = SlotState::ReceiverRemote(recv, send);
                    sender
                }),
                SlotState::Sent(value) => {
                    // The value already arrived. Re-send it over a fresh pair
                    // and return the receiving end; any `existing` port is
                    // dropped.
                    let (sender, recv) = Port::new_pair();
                    // SAFETY: `send` was created for the slot's `T`, which is
                    // the type `value` holds.
                    unsafe { send(sender, value) };
                    recv
                }
                // The sender closed. Return the existing port, or one whose
                // peer is already dropped, so the new owner sees the closure.
                SlotState::Done => existing.unwrap_or_else(|| Port::new_pair().0),
                SlotState::ReceiverRemote { .. } => unreachable!(),
            }
        }
        into_port(self, send_message::<T>)
    }

    /// Creates a receiver whose sender is reached through `port`. The handler
    /// is attached lazily, on the first `poll_recv`.
    fn from_port<T: MeshField>(port: Port) -> Self {
        fn from_port(port: Port, decode: DecodeFn) -> OneshotReceiverCore {
            let slot = Arc::new(Slot {
                state: Mutex::new(SlotState::SenderRemote(port, decode)),
                receiver: Default::default(),
            });
            OneshotReceiverCore(slot)
        }
        from_port(port, decode_message::<T>)
    }
}
486
/// State of a oneshot slot, stored in `Slot::state`.
#[derive(Debug)]
enum SlotState {
    /// Terminal: nothing more will happen in this slot. Reached when the
    /// value was taken, the sender closed, the receiver was dropped, or the
    /// slot's role moved to a port.
    Done,
    /// No value yet. Holds the receiver's waker once it has polled.
    Waiting(Option<Waker>),
    /// A value has arrived and has not been received yet.
    Sent(BoxedValue),
    /// Receiver-side slot whose sender is at the other end of the port. The
    /// `DecodeFn` decodes that sender's message.
    SenderRemote(Port, DecodeFn),
    /// Sender-side slot whose receiver is at the other end of the port. The
    /// `SendFn` encodes the value onto that port.
    ReceiverRemote(Port, SendFn),
}

/// Type-erased function that sends a `BoxedValue` of the slot's `T` over a
/// port. Unsafe because the value must actually hold that `T`.
type SendFn = unsafe fn(Port, BoxedValue);
/// Type-erased function that decodes a message into a `BoxedValue` of the
/// slot's `T`.
type DecodeFn = unsafe fn(Message<'_>) -> Result<BoxedValue, ChannelError>;
498
/// A heap-allocated value whose type has been erased: a `Box<T>` with `T`
/// forgotten.
///
/// The owner must remember `T` and release the value with `cast` or `drop`.
/// Simply dropping a `BoxedValue` leaks the allocation.
#[derive(Debug)]
struct BoxedValue(NonNull<()>);

// SAFETY: the erased pointer is never dereferenced without the original `T`.
// The typed channel halves only allow cross-thread use when `T: Send`, via
// their `PhantomData<Arc<Mutex<T>>>` markers.
unsafe impl Send for BoxedValue {}
// SAFETY: a shared `&BoxedValue` exposes only the pointer value. The pointee
// is never accessed through a shared reference.
unsafe impl Sync for BoxedValue {}

impl BoxedValue {
    /// Erases the type of `value`, taking ownership of its allocation.
    fn new<T>(value: Box<T>) -> Self {
        // A leaked box is never null (even for zero-sized `T`), so no check is
        // needed.
        Self(NonNull::from(Box::leak(value)).cast())
    }

    /// Reconstitutes the original box.
    ///
    /// A `Box` is returned (rather than `T`) so that callers such as
    /// `send_message` can forward it as-is, hence the clippy expectation.
    ///
    /// # Safety
    ///
    /// `T` must be exactly the type that was passed to `BoxedValue::new`.
    #[expect(clippy::unnecessary_box_returns)]
    unsafe fn cast<T>(self) -> Box<T> {
        let typed: NonNull<T> = self.0.cast();
        // SAFETY: the caller guarantees `T` matches. The pointer came from a
        // leaked `Box<T>`, and consuming `self` makes this the only owner.
        unsafe { Box::from_raw(typed.as_ptr()) }
    }

    /// Drops the erased value.
    ///
    /// # Safety
    ///
    /// Same requirement as `cast`.
    unsafe fn drop<T>(self) {
        // SAFETY: forwarded to the caller.
        std::mem::drop(unsafe { self.cast::<T>() });
    }
}
531
/// Error returned by the port handler when another message arrives while a
/// previously delivered value is still waiting in the slot.
#[derive(Debug, Error)]
#[error("unexpected oneshot message")]
struct UnexpectedMessage;
535
/// Port handler that a receiver attaches to its remote sender's port. It
/// decodes the incoming message into the slot and wakes the receiver.
struct SlotHandler {
    slot: Arc<Slot>,
    /// Decodes the sender's message into a type-erased value of the slot's `T`.
    decode: DecodeFn,
}

impl SlotHandler {
    /// Handles the port closing (`fail == false`) or failing (`fail == true`).
    ///
    /// A waiting receiver is woken, and the slot becomes `Done`. A value that
    /// was already delivered is kept on an orderly close but discarded on
    /// failure, so the receiver then observes the port error instead.
    fn close_or_fail(
        &mut self,
        control: &mut mesh_node::local_node::PortControl<'_, '_>,
        fail: bool,
    ) {
        let mut state = self.slot.state.lock();
        match std::mem::replace(&mut *state, SlotState::Done) {
            SlotState::Waiting(waker) => {
                if let Some(waker) = waker {
                    control.wake(waker);
                }
            }
            SlotState::Sent(v) => {
                if !fail {
                    *state = SlotState::Sent(v);
                }
                // NOTE(review): when `fail` is true, `v` is dropped as a plain
                // `BoxedValue`, which has no `Drop` impl, so the boxed `T`
                // leaks. Freeing it would need a type-erased drop function
                // carried alongside `decode`.
            }
            SlotState::Done => {}
            // A handler is only attached to receiver-side slots after they
            // leave `SenderRemote`.
            SlotState::SenderRemote { .. } | SlotState::ReceiverRemote { .. } => unreachable!(),
        }
    }
}
564
impl HandlePortEvent for SlotHandler {
    /// Stores the sender's single message in the slot and wakes the receiver.
    ///
    /// Returns an error if the message fails to decode, or if a second
    /// message arrives while the first value is still in the slot.
    fn message(
        &mut self,
        control: &mut mesh_node::local_node::PortControl<'_, '_>,
        message: Message<'_>,
    ) -> Result<(), HandleMessageError> {
        let mut state = self.slot.state.lock();
        match std::mem::replace(&mut *state, SlotState::Done) {
            SlotState::Waiting(waker) => {
                // SAFETY: `decode` came from the receiver's `from_port::<T>`,
                // so it produces the value type this slot was created for.
                let r = unsafe { (self.decode)(message) };
                let value = match r {
                    Ok(v) => v,
                    Err(err) => {
                        // Keep waiting. The corruption test in this file
                        // expects the receiver to end up with
                        // `RecvError::Error` after this.
                        *state = SlotState::Waiting(waker);
                        return Err(HandleMessageError::new(err));
                    }
                };
                *state = SlotState::Sent(value);
                // Unlock before waking the receiver.
                drop(state);
                if let Some(waker) = waker {
                    control.wake(waker);
                }
            }
            SlotState::Sent(v) => {
                // A oneshot carries exactly one message: keep the first value
                // and reject this one.
                *state = SlotState::Sent(v);
                return Err(HandleMessageError::new(UnexpectedMessage));
            }
            SlotState::Done => {
                // The value was already taken (or the slot closed). Ignore the
                // message without decoding it.
                *state = SlotState::Done;
            }
            SlotState::SenderRemote { .. } | SlotState::ReceiverRemote { .. } => unreachable!(),
        }
        Ok(())
    }

    /// The peer closed the port in an orderly way.
    fn close(&mut self, control: &mut mesh_node::local_node::PortControl<'_, '_>) {
        self.close_or_fail(control, false);
    }

    /// The port failed. Any pending value is discarded so that the receiver
    /// sees the error.
    fn fail(
        &mut self,
        control: &mut mesh_node::local_node::PortControl<'_, '_>,
        _err: mesh_node::local_node::NodeError,
    ) {
        self.close_or_fail(control, true);
    }

    /// This handler never queues messages, so there is nothing to return.
    fn drain(&mut self) -> Vec<OwnedMessage> {
        Vec::new()
    }
}
620
#[cfg(test)]
mod tests {
    use super::oneshot;
    use crate::OneshotReceiver;
    use crate::OneshotSender;
    use crate::RecvError;
    use futures::FutureExt;
    use futures::executor::block_on;
    use futures::task::SpawnExt;
    use mesh_node::local_node::Port;
    use mesh_node::message::Message;
    use std::cell::Cell;
    use std::future::poll_fn;
    use test_with_tracing::test;

    static_assertions::assert_impl_all!(OneshotSender<i32>: Send, Sync);
    // Values are only moved through the channel, never shared, so the halves
    // are `Sync` even when `T` is not (e.g. `Cell`). They are not `Send` if
    // `T` is not.
    static_assertions::assert_impl_all!(OneshotReceiver<i32>: Send, Sync);
    static_assertions::assert_impl_all!(OneshotSender<Cell<i32>>: Send, Sync);
    static_assertions::assert_impl_all!(OneshotReceiver<Cell<i32>>: Send, Sync);
    static_assertions::assert_not_impl_any!(OneshotSender<*const ()>: Send, Sync);
    static_assertions::assert_not_impl_any!(OneshotReceiver<*const ()>: Send, Sync);

    /// Both halves local: send, then receive.
    #[test]
    fn test_oneshot() {
        block_on(async {
            let (sender, receiver) = oneshot();
            sender.send(String::from("foo"));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// The sender round-trips through a `Port` before sending.
    #[test]
    fn test_oneshot_convert_sender_port() {
        block_on(async {
            let (sender, receiver) = oneshot::<String>();
            let sender = OneshotSender::<String>::from(Port::from(sender));
            sender.send(String::from("foo"));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// The receiver round-trips through a `Port` before the send.
    #[test]
    fn test_oneshot_convert_receiver_port() {
        block_on(async {
            let (sender, receiver) = oneshot::<String>();
            let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
            sender.send(String::from("foo"));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// The receiver is converted after the value is already in the slot.
    #[test]
    fn test_oneshot_convert_receiver_port_after_send() {
        block_on(async {
            let (sender, receiver) = oneshot::<String>();
            sender.send(String::from("foo"));
            let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// Both halves round-trip through ports.
    #[test]
    fn test_oneshot_convert_both() {
        block_on(async {
            let (sender, receiver) = oneshot::<String>();
            let sender = OneshotSender::<String>::from(Port::from(sender));
            let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
            sender.send(String::from("foo"));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// The receiver attaches a handler to the sender's port (by polling)
    /// before it is itself converted into a port.
    #[test]
    fn test_oneshot_convert_both_poll_first() {
        block_on(async {
            let (sender, mut receiver) = oneshot::<String>();
            let sender = OneshotSender::<String>::from(Port::from(sender));
            assert!(
                // Poll once: this attaches the port handler and leaves the
                // receiver pending.
                poll_fn(|cx| receiver.poll_recv(cx))
                    .now_or_never()
                    .is_none()
            );
            let receiver = OneshotReceiver::<String>::from(Port::from(receiver));
            sender.send(String::from("foo"));
            assert_eq!(receiver.await.unwrap(), "foo");
        })
    }

    /// A value of the wrong type (`String` sent to an `i32` receiver) surfaces
    /// as `RecvError::Error`.
    #[test]
    fn test_oneshot_message_corruption() {
        let mut pool = futures::executor::LocalPool::new();
        let spawner = pool.spawner();
        pool.run_until(async {
            let (sender, receiver) = oneshot();
            let receiver = OneshotReceiver::<i32>::from(Port::from(receiver));
            let receiver = spawner.spawn_with_handle(receiver).unwrap();
            // Yield once, presumably so the pool polls the spawned receiver
            // (attaching its handler) before the mismatched value is sent.
            futures::pending!();
            sender.send("text".to_owned());
            let RecvError::Error(err) = receiver.await.unwrap_err() else {
                panic!()
            };
            tracing::info!(error = &err as &dyn std::error::Error, "expected error");
        })
    }

    /// A second message on the port fails the channel, even though the first
    /// one was valid.
    #[test]
    fn test_oneshot_extra_messages() {
        block_on(async {
            let (sender, mut receiver) = oneshot::<()>();
            let sender = Port::from(sender);
            assert!(futures::poll!(&mut receiver).is_pending());
            sender.send(Message::new(()));
            sender.send(Message::new(()));
            let RecvError::Error(err) = receiver.await.unwrap_err() else {
                panic!()
            };
            tracing::info!(error = &err as &dyn std::error::Error, "expected error");
        })
    }

    /// Dropping the sender without sending yields `RecvError::Closed`.
    #[test]
    fn test_oneshot_closed() {
        block_on(async {
            let (sender, receiver) = oneshot::<()>();
            drop(sender);
            let RecvError::Closed = receiver.await.unwrap_err() else {
                panic!()
            };
        })
    }
}