mesh_channel_core/
mpsc.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Implementation of an async multi-producer, single-consumer (MPSC) channel
5//! that can be used to communicate between mesh nodes.
6//!
7//! The main design requirements of this channel are:
8//! * It roughly follows the semantics of the Rust standard library's
9//!   `std::sync::mpsc` channel, but with async support.
10//! * It is efficient enough for single process use that it can be used as a
11//!   general purpose channel.
12//! * It leverages `mesh_node` ports and `mesh_protobuf` serialization to allow
13//!   communication between mesh nodes, which can be on different processes or
14//!   machines.
15//! * Its contribution to binary size is minimal.
16//!
17//! To achieve the binary size goal, this implementation avoids generics where
18//! practical. This has the tradeoff of requiring a fair amount of unsafe code,
19//! but this makes it practical to use this channel in space-constrained
20//! environments.
21
22// UNSAFETY: Needed to erase types to avoid monomorphization overhead.
23#![expect(unsafe_code)]
24
25use crate::deque::ElementVtable;
26use crate::deque::ErasedVecDeque;
27use crate::error::ChannelError;
28use crate::error::RecvError;
29use crate::error::TryRecvError;
30use crate::sync_unsafe_cell::SyncUnsafeCell;
31use core::fmt::Debug;
32use core::future::Future;
33use core::marker::PhantomData;
34use core::mem::ManuallyDrop;
35use core::mem::MaybeUninit;
36use core::task::Context;
37use core::task::Poll;
38use core::task::Waker;
39use mesh_node::local_node::HandleMessageError;
40use mesh_node::local_node::HandlePortEvent;
41use mesh_node::local_node::Port;
42use mesh_node::local_node::PortField;
43use mesh_node::local_node::PortWithHandler;
44use mesh_node::message::MeshField;
45use mesh_node::message::Message;
46use mesh_node::message::OwnedMessage;
47use mesh_protobuf::DefaultEncoding;
48use mesh_protobuf::Protobuf;
49use parking_lot::Mutex;
50use parking_lot::MutexGuard;
51use std::marker::PhantomPinned;
52use std::sync::Arc;
53use std::sync::OnceLock;
54use std::task::ready;
55
56/// Creates a new channel for sending messages of type `T`, returning the sender
57/// and receiver ends.
58pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
59    fn channel_core(vtable: &'static ElementVtable) -> (SenderCore, ReceiverCore) {
60        let mut receiver = ReceiverCore::new(vtable);
61        let sender = receiver.sender();
62        (sender, receiver)
63    }
64    let (sender, receiver) = channel_core(const { &ElementVtable::new::<T>() });
65    (Sender(sender, PhantomData), Receiver(receiver, PhantomData))
66}
67
68/// The sending half of a channel returned by [`channel`].
69///
70/// The sender can be cloned to send messages from multiple threads or
71/// processes.
72//
73// Note that the `PhantomData` here is necessary to ensure `Send/Sync` traits
74// are only implemented when `T` is `Send`, since the `SenderCore` is always
75// `Send+Sync`. This behavior is verified in the unit tests.
76pub struct Sender<T>(SenderCore, PhantomData<Arc<Mutex<[T]>>>);
77
78impl<T> Debug for Sender<T> {
79    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
80        Debug::fmt(&self.0, f)
81    }
82}
83
84impl<T> Clone for Sender<T> {
85    fn clone(&self) -> Self {
86        Self(self.0.clone(), PhantomData)
87    }
88}
89
90impl<T> Sender<T> {
91    /// Sends a message to the associated [`Receiver<T>`].
92    ///
93    /// Does not return a result, so messages can be silently dropped if the
94    /// receiver has closed or failed. To detect such conditions, include
95    /// another sender in the message you send so that the receiving thread can
96    /// use it to send a response.
97    ///
98    /// ```rust
99    /// # use mesh_channel_core::*;
100    /// # futures::executor::block_on(async {
101    /// let (send, mut recv) = channel();
102    /// let (response_send, mut response_recv) = channel::<bool>();
103    /// send.send((3, response_send));
104    /// let (val, response_send) = recv.recv().await.unwrap();
105    /// response_send.send(val == 3);
106    /// assert_eq!(response_recv.recv().await.unwrap(), true);
107    /// # });
108    /// ```
109    pub fn send(&self, message: T) {
110        // SAFETY: the queue is for `T` and `message` is a valid owned `T`.
111        // Additionally, the sender/receiver is only `Send`/`Sync` if `T` is
112        // `Send`/`Sync`.
113        unsafe { self.0.send(message) }
114    }
115
116    /// Returns whether the receiving side of the channel is known to be closed
117    /// (or failed).
118    ///
119    /// This is useful to determine if there is any point in sending more data
120    /// via this port. Note that even if this returns `false` messages may still
121    /// fail to reach the destination, for example if the receiver is closed
122    /// after this method is called but before the message is consumed.
123    pub fn is_closed(&self) -> bool {
124        self.0.is_closed()
125    }
126}
127
/// A type-erased pointer to a message.
///
/// The pointer neither owns nor tracks the pointee. Callers name the element
/// type again when reading the value back.
struct MessagePtr(*mut ());

impl MessagePtr {
    /// Erases the type of the slot `message`.
    ///
    /// The returned pointer is only usable while `message` is alive and has
    /// not been moved.
    fn new<T>(message: &mut MaybeUninit<T>) -> Self {
        let typed: *mut T = message.as_mut_ptr();
        MessagePtr(typed.cast::<()>())
    }

    /// Moves the message out of the pointee.
    ///
    /// # Safety
    /// The caller must ensure that `self` is a valid owned `T`.
    unsafe fn read<T>(self) -> T {
        let typed: *mut T = self.0.cast();
        // SAFETY: The caller guarantees `self` points to a valid owned `T`, so
        // reading it moves that value out to the caller.
        unsafe { typed.read() }
    }
}
142
143/// Sends a `ChannelPayload::Message(message)` to a port.
144///
145/// # Safety
146/// The caller must ensure that `message` is a valid owned `T`.
147unsafe fn send_message<T: MeshField>(port: &Port, message: MessagePtr) {
148    // SAFETY: The caller guarantees `message` is a valid owned `T`.
149    let m = unsafe { ChannelPayload::Message(message.read::<T>()) };
150    port.send_protobuf(m);
151}
152
153#[derive(Debug, Clone)]
154struct SenderCore(ManuallyDrop<Arc<Queue>>);
155
156impl SenderCore {
157    /// Sends `message`.
158    ///
159    /// # Safety
160    /// The caller must ensure that the message is a valid owned `T` for the `T`
161    /// the queue was created with. It also must ensure that the queue is not
162    /// sent/shared across threads unless `T` is `Send`/`Sync`.
163    unsafe fn send<T>(&self, message: T) {
164        fn send(queue: &Queue, message: MessagePtr) -> bool {
165            match queue.access() {
166                QueueAccess::Local(mut local) => {
167                    if local.receiver_gone {
168                        return false;
169                    }
170                    // SAFETY: The caller guarantees `message` is a valid owned `T`,
171                    // and that the queue will not be sent/shared across threads
172                    // unless `T` is `Send`/`Sync`.
173                    unsafe { local.messages.push_back(message.0) };
174                    if let Some(waker) = local.waker.take() {
175                        drop(local);
176                        waker.wake();
177                    }
178                }
179                QueueAccess::Remote(remote) => {
180                    // SAFETY: The caller guarantees `message` is a valid owned `T`.
181                    unsafe { (remote.send)(&remote.port, message) };
182                }
183            }
184            true
185        }
186
187        let mut message = MaybeUninit::new(message);
188        let sent = send(&self.0, MessagePtr::new(&mut message));
189        if !sent {
190            // SAFETY: `message` was not dropped.
191            unsafe { message.assume_init_drop() };
192        }
193    }
194
195    fn is_closed(&self) -> bool {
196        match self.0.access() {
197            QueueAccess::Local(local) => local.receiver_gone,
198            QueueAccess::Remote(remote) => remote.port.is_closed().unwrap_or(true),
199        }
200    }
201
202    fn into_queue(self) -> Arc<Queue> {
203        let Self(ref queue) = *ManuallyDrop::new(self);
204        // SAFETY: copying from a field that won't be dropped.
205        unsafe { <*const _>::read(&**queue) }
206    }
207
208    /// Creates a new queue with element type `T` for sending to `port`.
209    fn from_port<T: MeshField>(port: Port) -> Self {
210        fn from_port(port: Port, vtable: &'static ElementVtable, send: SendFn) -> SenderCore {
211            SenderCore(ManuallyDrop::new(Arc::new(Queue {
212                local: Mutex::new(LocalQueue {
213                    remote: true,
214                    ..LocalQueue::new(vtable)
215                }),
216                remote: OnceLock::from(RemoteQueueState { port, send }),
217                receiver: Default::default(),
218            })))
219        }
220
221        from_port(
222            port,
223            const { &ElementVtable::new::<T>() },
224            send_message::<T>,
225        )
226    }
227
228    /// Converts this sender into a port.
229    ///
230    /// # Safety
231    /// The caller must ensure that the queue has element type `T`.
232    unsafe fn into_port<T: MeshField>(self) -> Port {
233        fn into_port(this: SenderCore, new_handler: NewHandlerFn) -> Port {
234            match Arc::try_unwrap(this.into_queue()) {
235                Ok(mut queue) => {
236                    if let Some(remote) = queue.remote.into_inner() {
237                        // This is the unique owner of the port.
238                        remote.port
239                    } else {
240                        assert!(queue.local.get_mut().receiver_gone);
241                        let (send, _recv) = Port::new_pair();
242                        send
243                    }
244                }
245                Err(queue) => {
246                    // There is a receiver or at least one other sender.
247                    let (send, recv) = Port::new_pair();
248                    match queue.access() {
249                        QueueAccess::Local(mut local) => {
250                            if !local.receiver_gone {
251                                local.new_handler = new_handler;
252                                local.ports.push(recv);
253                                if let Some(waker) = local.waker.take() {
254                                    drop(local);
255                                    waker.wake();
256                                }
257                            }
258                        }
259                        QueueAccess::Remote(remote) => {
260                            remote.port.send_protobuf(ChannelPayload::<()>::Port(recv));
261                        }
262                    }
263                    send
264                }
265            }
266        }
267        into_port(self, RemotePortHandler::new::<T>)
268    }
269}
270
impl Drop for SenderCore {
    /// Releases this sender's reference to the queue. For a local queue it
    /// also wakes the receiver, so the receiver can observe that a sender
    /// went away.
    fn drop(&mut self) {
        // SAFETY: the queue won't be referenced after this.
        let queue = unsafe { ManuallyDrop::take(&mut self.0) };
        // A remote queue has no local waker to take.
        let waker = if queue.remote.get().is_some() {
            None
        } else {
            let mut local = queue.local.lock();
            // TODO: keep a sender count to avoid needing to wake.
            local.waker.take()
        };
        // Drop the queue so that the receiver will see the sender is gone.
        drop(queue);
        // NOTE(review): the waker is taken under the lock, but the `Arc` is
        // released only after the lock is dropped. Suppose the receiver is
        // polled in between with a *different* waker (e.g. after moving to
        // another task). It can still observe `strong_count > 1`, register the
        // new waker and return `Pending`, yet only the old waker is woken
        // below. Confirm whether this lost-wakeup window matters.
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
289
290impl<T> DefaultEncoding for Sender<T> {
291    type Encoding = PortField;
292}
293
294impl<T: MeshField> From<Port> for Sender<T> {
295    fn from(port: Port) -> Self {
296        Self(SenderCore::from_port::<T>(port), PhantomData)
297    }
298}
299
300impl<T: MeshField> From<Sender<T>> for Port {
301    fn from(sender: Sender<T>) -> Self {
302        // SAFETY: the queue has element type `T`.
303        unsafe { sender.0.into_port::<T>() }
304    }
305}
306
307impl<T: MeshField> Sender<T> {
308    /// Bridges this and `recv` together, consuming both `self` and `recv`. This
309    /// makes it so that anything sent to `recv` will be directly sent to this
310    /// channel's peer receiver, without a separate relay step. This includes
311    /// any data that was previously sent but not yet consumed.
312    ///
313    /// ```rust
314    /// # use mesh_channel_core::*;
315    /// let (outer_send, inner_recv) = channel::<u32>();
316    /// let (inner_send, mut outer_recv) = channel::<u32>();
317    ///
318    /// outer_send.send(2);
319    /// inner_send.send(1);
320    /// inner_send.bridge(inner_recv);
321    /// assert_eq!(outer_recv.try_recv().unwrap(), 1);
322    /// assert_eq!(outer_recv.try_recv().unwrap(), 2);
323    /// ```
324    pub fn bridge(self, receiver: Receiver<T>) {
325        let sender = Port::from(self);
326        let receiver = Port::from(receiver);
327        sender.bridge(receiver);
328    }
329}
330
331/// The receiving half of a channel returned by [`channel`].
332//
333// Note that the `PhantomData` here is necessary to ensure `Send/Sync` traits
334// are only implemented when `T` is `Send`, since the `ReceiverCore` is always
335// `Send+Sync`. This behavior is verified in the unit tests.
336pub struct Receiver<T>(ReceiverCore, PhantomData<Arc<Mutex<[T]>>>);
337
338impl<T> Debug for Receiver<T> {
339    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
340        Debug::fmt(&self.0, f)
341    }
342}
343
344impl<T> Default for Receiver<T> {
345    fn default() -> Self {
346        Self::new()
347    }
348}
349
350#[derive(Debug)]
351struct ReceiverCore {
352    queue: ReceiverQueue,
353}
354
355impl<T> Receiver<T> {
356    /// Creates a new receiver with no senders.
357    ///
358    /// Receives will fail with [`RecvError::Closed`] until [`Self::sender`] is
359    /// called.
360    pub fn new() -> Self {
361        Self(
362            ReceiverCore::new(const { &ElementVtable::new::<T>() }),
363            PhantomData,
364        )
365    }
366
367    /// Consumes and returns the next message, waiting until one is available.
368    ///
369    /// Returns immediately when the channel is closed or failed.
370    ///
371    /// ```rust
372    /// # use mesh_channel_core::*;
373    /// # futures::executor::block_on(async {
374    /// let (send, mut recv) = channel();
375    /// send.send(5u32);
376    /// drop(send);
377    /// assert_eq!(recv.recv().await.unwrap(), 5);
378    /// assert!(matches!(recv.recv().await.unwrap_err(), RecvError::Closed));
379    /// # });
380    /// ```
381    pub fn recv(&mut self) -> Recv<'_, T> {
382        Recv(self, PhantomPinned)
383    }
384
385    /// Consumes and returns the next message, if there is one.
386    ///
387    /// Otherwise, returns whether the channel is empty, closed, or failed.
388    ///
389    /// ```rust
390    /// # use mesh_channel_core::*;
391    /// let (send, mut recv) = channel();
392    /// send.send(5u32);
393    /// drop(send);
394    /// assert_eq!(recv.try_recv().unwrap(), 5);
395    /// assert!(matches!(recv.try_recv().unwrap_err(), TryRecvError::Closed));
396    /// ```
397    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
398        // SAFETY: the queue type is `T`.
399        let r = unsafe { self.0.try_poll_recv::<T>(None) };
400        match r {
401            Poll::Ready(Ok(v)) => Ok(v),
402            Poll::Ready(Err(RecvError::Closed)) => Err(TryRecvError::Closed),
403            Poll::Ready(Err(RecvError::Error(e))) => Err(TryRecvError::Error(e)),
404            Poll::Pending => Err(TryRecvError::Empty),
405        }
406    }
407
408    /// Polls for the next message.
409    ///
410    /// If one is available, consumes and returns it. If the
411    /// channel is closed or failed, fails. Otherwise, registers the current task to wake
412    /// when a message is available or the channel is closed or fails.
413    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
414        // SAFETY: the queue type is `T`.
415        unsafe { self.0.try_poll_recv(Some(cx)) }
416    }
417
418    /// Creates a new sender for sending data to this receiver.
419    ///
420    /// Note that this may transition the channel from the closed to open state.
421    pub fn sender(&mut self) -> Sender<T> {
422        Sender(self.0.sender(), PhantomData)
423    }
424}
425
426/// The future returned by [`Receiver::recv`].
427//
428// Force `!Unpin` to allow for future optimizations.
429pub struct Recv<'a, T>(&'a mut Receiver<T>, PhantomPinned);
430
431impl<T> Future for Recv<'_, T> {
432    type Output = Result<T, RecvError>;
433
434    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
435        // SAFETY: there are no actual pinning invariants.
436        let this = unsafe { self.get_unchecked_mut() };
437        this.0.poll_recv(cx)
438    }
439}
440
/// State owned exclusively by the receiver, stored inside the shared `Queue`.
#[derive(Default)]
struct ReceiverState {
    /// Ports whose handlers decode incoming messages into this queue.
    ports: PortHandlerList,
    /// Whether a receive has reported that the channel closed or failed.
    /// Cleared again when a new sender is created; read by
    /// `FusedStream::is_terminated`.
    terminated: bool,
}

/// The receiver's handle to the shared queue.
///
/// Dropping it tears down the receiver-owned state and marks the queue as
/// having no receiver.
#[derive(Debug)]
struct ReceiverQueue(Arc<Queue>);
449
impl Drop for ReceiverQueue {
    /// Tears down the receiver side of the queue:
    /// - drops the receiver-owned state;
    /// - marks the receiver as gone;
    /// - discards queued messages, pending ports, and any registered waker.
    fn drop(&mut self) {
        // Drop the receiver state now to propagate the close signal and to
        // eliminate circular references.
        //
        // SAFETY: `receiver` is exclusively owned by this receiver and will never
        // be accessed again.
        unsafe { ManuallyDrop::drop(&mut *self.0.receiver.0.get()) };
        let mut local = self.0.local.lock();
        // Local senders check this flag and drop messages instead of queueing.
        local.receiver_gone = true;
        let _waker = std::mem::take(&mut local.waker);
        // NOTE(review): `clear_and_shrink` presumably drops the queued messages
        // while `local` is still locked. If a message owns a `Sender` for this
        // same queue, `Drop for SenderCore` locks `local` again. That would
        // deadlock on the non-reentrant `parking_lot::Mutex`. Consider moving
        // the messages out before dropping them; verify what `ErasedVecDeque`
        // supports.
        local.messages.clear_and_shrink();
        let _ports = std::mem::take(&mut local.ports);
        // Locals drop in reverse declaration order, so `_ports` and `_waker`
        // are dropped before the `local` guard, i.e. still under the lock.
    }
}
465
466impl ReceiverQueue {
467    fn state(&self) -> &ReceiverState {
468        // SAFETY: `receiver` is exclusively owned by this receiver.
469        unsafe { &*self.0.receiver.0.get() }
470    }
471
472    fn state_mut(&mut self) -> &mut ReceiverState {
473        self.split_mut().1
474    }
475
476    fn split_mut(&mut self) -> (&Arc<Queue>, &mut ReceiverState) {
477        // SAFETY: `receiver` is exclusively owned by this receiver.
478        let state = unsafe { &mut *self.0.receiver.0.get() };
479        (&self.0, state)
480    }
481}
482
483impl ReceiverCore {
484    fn new(vtable: &'static ElementVtable) -> Self {
485        Self {
486            queue: ReceiverQueue(Arc::new(Queue {
487                local: Mutex::new(LocalQueue::new(vtable)),
488                remote: OnceLock::new(),
489                receiver: SyncUnsafeCell::new(ManuallyDrop::new(ReceiverState {
490                    ports: PortHandlerList::default(),
491                    terminated: true,
492                })),
493            })),
494        }
495    }
496
497    /// Polls for a message.
498    ///
499    /// # Safety
500    ///
501    /// The queue must have element type `T`.
502    unsafe fn try_poll_recv<T>(
503        &mut self,
504        cx: Option<&mut Context<'_>>,
505    ) -> Poll<Result<T, RecvError>> {
506        fn try_poll_recv<'a>(
507            this: &'a mut ReceiverCore,
508            cx: Option<&mut Context<'_>>,
509        ) -> Poll<Result<MutexGuard<'a, LocalQueue>, RecvError>> {
510            let (queue, state) = this.queue.split_mut();
511            loop {
512                debug_assert!(queue.remote.get().is_none());
513                let mut local = queue.local.lock();
514                if local.remove_closed {
515                    local.remove_closed = false;
516                    drop(local);
517                    if let Err(err) = state.ports.remove_closed() {
518                        // Propagate the error to the caller only if there
519                        // are no more senders. Otherwise, the caller might
520                        // stop receiving messages from the remaining
521                        // senders.
522                        let local = queue.local.lock();
523                        if local.messages.is_empty()
524                            && local.ports.is_empty()
525                            && ReceiverCore::is_closed(queue)
526                        {
527                            state.terminated = true;
528                            return Poll::Ready(Err(RecvError::Error(err)));
529                        } else {
530                            trace_channel_error(&err);
531                        }
532                    }
533                } else if !local.ports.is_empty() {
534                    let new_handler = local.new_handler;
535                    let ports = std::mem::take(&mut local.ports);
536                    drop(local);
537                    state.ports.0.extend(ports.into_iter().map(|port| {
538                        // SAFETY: `new_handler` has been set to a function whose
539                        // element type matches the queue's element type.
540                        let handler = unsafe { new_handler(queue.clone()) };
541                        port.set_handler(handler)
542                    }));
543                    continue;
544                } else if local.messages.is_empty() {
545                    if let Some(cx) = cx {
546                        if !local
547                            .waker
548                            .as_ref()
549                            .is_some_and(|waker| waker.will_wake(cx.waker()))
550                            && !ReceiverCore::is_closed(queue)
551                        {
552                            local.waker = Some(cx.waker().clone());
553                        }
554                    }
555                    if ReceiverCore::is_closed(queue) {
556                        state.terminated = true;
557                        return Poll::Ready(Err(RecvError::Closed));
558                    } else {
559                        return Poll::Pending;
560                    }
561                } else {
562                    return Poll::Ready(Ok(local));
563                }
564            }
565        }
566
567        ready!(try_poll_recv(self, cx))
568            .map(|mut local| {
569                let message = local.messages.pop_front_in_place().unwrap();
570                // SAFETY: `message` is a valid owned `T`.
571                unsafe { message.as_ptr().cast::<T>().read() }
572            })
573            .into()
574    }
575
576    fn is_closed(queue: &Arc<Queue>) -> bool {
577        Arc::strong_count(queue) == 1
578    }
579
580    fn sender(&mut self) -> SenderCore {
581        self.queue.state_mut().terminated = false;
582        SenderCore(ManuallyDrop::new(self.queue.0.clone()))
583    }
584
585    /// Converts this receiver into a port.
586    ///
587    /// # Safety
588    /// The caller must ensure that the queue has element type `T`.
589    unsafe fn into_port<T: MeshField>(self) -> Port {
590        fn into_port(mut this: ReceiverCore, send: SendFn) -> Port {
591            let ports = std::mem::take(&mut this.queue.state_mut().ports).into_ports();
592            if ports.len() == 1 {
593                if let Some(queue) = Arc::get_mut(&mut this.queue.0) {
594                    let local = queue.local.get_mut();
595                    if local.messages.is_empty() && local.ports.is_empty() {
596                        return ports.into_iter().next().unwrap();
597                    }
598                }
599            }
600            let (sender, recv) = Port::new_pair();
601            for port in ports {
602                sender.send_protobuf(ChannelPayload::<()>::Port(port));
603            }
604            let mut local = this.queue.0.local.lock();
605            for port in local.ports.drain(..) {
606                sender.send_protobuf(ChannelPayload::<()>::Port(port));
607            }
608            while let Some(message) = local.messages.pop_front_in_place() {
609                // SAFETY: `message` is a valid owned `T`.
610                unsafe { send(&sender, MessagePtr(message.as_ptr())) };
611            }
612            local.remote = true;
613            this.queue
614                .0
615                .remote
616                .set(RemoteQueueState { port: sender, send })
617                .ok()
618                .unwrap();
619
620            recv
621        }
622        into_port(self, send_message::<T>)
623    }
624
625    /// Creates a new queue with element type `T` for receiving from `port`.
626    fn from_port<T: MeshField>(port: Port) -> Self {
627        fn from_port(
628            port: Port,
629            vtable: &'static ElementVtable,
630            new_handler: NewHandlerFn,
631        ) -> ReceiverCore {
632            let queue = Arc::new(Queue {
633                local: Mutex::new(LocalQueue {
634                    ports: vec![port],
635                    new_handler,
636                    ..LocalQueue::new(vtable)
637                }),
638                remote: OnceLock::new(),
639                receiver: Default::default(),
640            });
641            ReceiverCore {
642                queue: ReceiverQueue(queue),
643            }
644        }
645        from_port(
646            port,
647            const { &ElementVtable::new::<T>() },
648            RemotePortHandler::new::<T>,
649        )
650    }
651}
652
/// Logs `err` at error level via `tracing`, recording it in the `error` field.
///
/// Besides terminal failures, this is also called for per-port errors that
/// are not propagated to the caller.
fn trace_channel_error(err: &ChannelError) {
    tracing::error!(
        error = err as &dyn std::error::Error,
        "channel closed due to error"
    );
}
659
660impl<T> futures_core::Stream for Receiver<T> {
661    type Item = T;
662
663    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
664        Poll::Ready(match std::task::ready!(self.get_mut().poll_recv(cx)) {
665            Ok(t) => Some(t),
666            Err(RecvError::Closed) => None,
667            Err(RecvError::Error(err)) => {
668                trace_channel_error(&err);
669                None
670            }
671        })
672    }
673}
674
675impl<T> futures_core::FusedStream for Receiver<T> {
676    fn is_terminated(&self) -> bool {
677        self.0.queue.state().terminated
678    }
679}
680
681#[derive(Debug, Default)]
682struct PortHandlerList(Vec<PortWithHandler<RemotePortHandler>>);
683
684impl PortHandlerList {
685    fn remove_closed(&mut self) -> Result<(), ChannelError> {
686        let mut r = Ok(());
687        self.0.retain(|port| match port.is_closed() {
688            Ok(true) => false,
689            Ok(false) => true,
690            Err(err) => {
691                let err = ChannelError::from(err);
692                if r.is_ok() {
693                    r = Err(err);
694                } else {
695                    trace_channel_error(&err);
696                }
697                false
698            }
699        });
700        r
701    }
702
703    fn into_ports(self) -> Vec<Port> {
704        self.0
705            .into_iter()
706            .map(|port| port.remove_handler().0)
707            .collect()
708    }
709}
710
711impl<T: MeshField> DefaultEncoding for Receiver<T> {
712    type Encoding = PortField;
713}
714
715impl<T: MeshField> From<Port> for Receiver<T> {
716    fn from(port: Port) -> Self {
717        Self(ReceiverCore::from_port::<T>(port), PhantomData)
718    }
719}
720
721impl<T: MeshField> From<Receiver<T>> for Port {
722    fn from(receiver: Receiver<T>) -> Self {
723        // SAFETY: the queue has element type `T`.
724        unsafe { receiver.0.into_port::<T>() }
725    }
726}
727
728impl<T: MeshField> Receiver<T> {
729    /// Bridges this and `sender` together, consuming both `self` and `sender`.
730    ///
731    /// See [`Sender::bridge`] for more details.
732    pub fn bridge(self, sender: Sender<T>) {
733        sender.bridge(self)
734    }
735}
736
737#[derive(Debug)]
738struct Queue {
739    remote: OnceLock<RemoteQueueState>,
740    local: Mutex<LocalQueue>,
741    // Stored in this shared state but owned exclusively by the `ReceiverCore`.
742    // This minimizes the size of the `ReceiverCore`.
743    receiver: SyncUnsafeCell<ManuallyDrop<ReceiverState>>,
744}
745
746enum QueueAccess<'a> {
747    Local(MutexGuard<'a, LocalQueue>),
748    Remote(&'a RemoteQueueState),
749}
750
751impl Queue {
752    fn access(&self) -> QueueAccess<'_> {
753        loop {
754            // Check if the queue is remote first to avoid taking the lock.
755            if let Some(remote) = self.remote.get() {
756                break QueueAccess::Remote(remote);
757            } else {
758                let local = self.local.lock();
759                if local.remote {
760                    // The queue was made remote between our check above and
761                    // taking the lock.
762                    continue;
763                }
764                break QueueAccess::Local(local);
765            }
766        }
767    }
768}
769
770#[derive(Debug)]
771struct LocalQueue {
772    messages: ErasedVecDeque,
773    ports: Vec<Port>,
774    waker: Option<Waker>,
775    remote: bool,
776    receiver_gone: bool,
777    remove_closed: bool,
778    new_handler: NewHandlerFn,
779}
780
781type NewHandlerFn = unsafe fn(Arc<Queue>) -> RemotePortHandler;
782
783impl LocalQueue {
784    fn new(vtable: &'static ElementVtable) -> Self {
785        Self {
786            messages: ErasedVecDeque::new(vtable),
787            ports: Vec::new(),
788            waker: None,
789            remote: false,
790            receiver_gone: false,
791            remove_closed: false,
792            new_handler: missing_handler,
793        }
794    }
795}
796
797fn missing_handler(_: Arc<Queue>) -> RemotePortHandler {
798    unreachable!("handler function not set")
799}
800
801#[derive(Debug)]
802struct RemoteQueueState {
803    port: Port,
804    send: SendFn,
805}
806
807type SendFn = unsafe fn(&Port, MessagePtr);
808
809#[derive(Protobuf)]
810#[mesh(bound = "T: MeshField", resource = "mesh_node::resource::Resource")]
811enum ChannelPayload<T> {
812    #[mesh(transparent)]
813    Message(T),
814    #[mesh(transparent)]
815    Port(Port),
816}
817
818struct RemotePortHandler {
819    queue: Arc<Queue>,
820    parse: unsafe fn(Message<'_>, *mut ()) -> Result<Option<Port>, ChannelError>,
821}
822
823impl RemotePortHandler {
824    /// Creates a new handler for a queue with element type `T`.
825    ///
826    /// # Safety
827    /// The caller must ensure that `queue` has element type `T`.
828    unsafe fn new<T: MeshField>(queue: Arc<Queue>) -> Self {
829        Self {
830            queue,
831            parse: Self::parse::<T>,
832        }
833    }
834
835    /// Parses a message into a `T` or a `Port`.
836    ///
837    /// # Safety
838    /// The caller must ensure that `p` is valid for writing a `T`.
839    unsafe fn parse<T: MeshField>(
840        message: Message<'_>,
841        p: *mut (),
842    ) -> Result<Option<Port>, ChannelError> {
843        match message.parse_non_static::<ChannelPayload<T>>() {
844            Ok(ChannelPayload::Message(message)) => {
845                // SAFETY: The caller guarantees `p` is valid for writing a `T`.
846                unsafe { p.cast::<T>().write(message) };
847                Ok(None)
848            }
849            Ok(ChannelPayload::Port(port)) => Ok(Some(port)),
850            Err(err) => Err(err.into()),
851        }
852    }
853}
854
impl HandlePortEvent for RemotePortHandler {
    /// Decodes an incoming message straight into the local queue, then wakes
    /// the receiver.
    ///
    /// - A message payload is committed to `messages`.
    /// - A port payload is pushed to `ports` for the receiver to attach.
    fn message(
        &mut self,
        control: &mut mesh_node::local_node::PortControl<'_, '_>,
        message: Message<'_>,
    ) -> Result<(), HandleMessageError> {
        let mut local = self.queue.local.lock();
        // The receiver removes its handlers before it marks itself gone or
        // switches the queue to remote, so neither should be seen here.
        assert!(!local.receiver_gone);
        assert!(!local.remote);
        // Decode directly into the queue.
        let p = local.messages.reserve_one();
        // SAFETY: `p` is valid for writing a `T`, the element type of the
        // queue.
        let r = unsafe { (self.parse)(message, p.as_ptr()) };
        // On a decode error, return before the reserved slot is committed.
        let port = r.map_err(HandleMessageError::new)?;
        match port {
            None => {
                // SAFETY: `p` has been written to.
                unsafe { p.commit() };
            }
            Some(port) => {
                local.ports.push(port);
            }
        }
        let waker = local.waker.take();
        drop(local);
        // NOTE(review): the waker is handed to `control` rather than woken
        // inline; presumably this defers the wake until the port machinery
        // is done with this event.
        if let Some(waker) = waker {
            control.wake(waker);
        }
        Ok(())
    }

    /// Flags the receiver to prune closed ports on its next poll, and wakes
    /// it.
    fn close(&mut self, control: &mut mesh_node::local_node::PortControl<'_, '_>) {
        let waker = {
            let mut local = self.queue.local.lock();
            local.remove_closed = true;
            local.waker.take()
        };
        if let Some(waker) = waker {
            control.wake(waker);
        }
    }

    /// Handled like a close. The error itself is not recorded here.
    /// NOTE(review): presumably it resurfaces when the receiver calls
    /// `is_closed` on the failed port.
    fn fail(
        &mut self,
        control: &mut mesh_node::local_node::PortControl<'_, '_>,
        _err: mesh_node::local_node::NodeError,
    ) {
        self.close(control);
    }

    /// Returns no messages: this handler does not keep any of its own.
    fn drain(&mut self) -> Vec<OwnedMessage> {
        Vec::new()
    }
}
910
#[cfg(test)]
mod tests {
    use super::Receiver;
    use super::Sender;
    use super::channel;
    use crate::RecvError;
    use futures::StreamExt;
    use futures::executor::block_on;
    use futures_core::FusedStream;
    use mesh_node::local_node::Port;
    use mesh_protobuf::Protobuf;
    use std::cell::Cell;
    use std::marker::PhantomData;
    use test_with_tracing::test;

    // Ensure `Send` and `Sync` are implemented correctly: both endpoints are
    // thread-safe even for `!Sync` element types, but not for `!Send` ones.
    static_assertions::assert_impl_all!(Sender<i32>: Send, Sync);
    static_assertions::assert_impl_all!(Receiver<i32>: Send, Sync);
    static_assertions::assert_impl_all!(Sender<Cell<i32>>: Send, Sync);
    static_assertions::assert_impl_all!(Receiver<Cell<i32>>: Send, Sync);
    static_assertions::assert_not_impl_any!(Sender<*const ()>: Send, Sync);
    static_assertions::assert_not_impl_any!(Receiver<*const ()>: Send, Sync);

    // A purely local channel delivers a message and then reports
    // end-of-stream once the only sender is dropped.
    #[test]
    fn test_basic() {
        block_on(async {
            let (sender, mut receiver) = channel();
            sender.send(String::from("test"));
            assert_eq!(receiver.next().await.as_deref(), Some("test"));
            drop(sender);
            assert_eq!(receiver.next().await, None);
        })
    }

    // A sender round-tripped through `Port` still reaches the original
    // receiver.
    #[test]
    fn test_convert_sender_port() {
        block_on(async {
            let (sender, mut receiver) = channel::<String>();
            let sender = Sender::<String>::from(Port::from(sender));
            sender.send(String::from("test"));
            assert_eq!(receiver.next().await.as_deref(), Some("test"));
            drop(sender);
            assert_eq!(receiver.next().await, None);
        })
    }

    // A receiver round-tripped through `Port` still gets messages from the
    // original sender.
    #[test]
    fn test_convert_receiver_port() {
        block_on(async {
            let (sender, receiver) = channel();
            let mut receiver = Receiver::<String>::from(Port::from(receiver));
            sender.send(String::from("test"));
            assert_eq!(receiver.next().await.as_deref(), Some("test"));
            drop(sender);
            assert_eq!(receiver.next().await, None);
        })
    }

    // A local sender and a port-backed sender can feed the same receiver,
    // and the stream ends only after both are dropped.
    #[test]
    fn test_non_port_and_port_sender() {
        block_on(async {
            let (sender, mut receiver) = channel();
            let sender2 = Sender::<String>::from(Port::from(sender.clone()));
            sender.send(String::from("test"));
            sender2.send(String::from("tset"));
            assert_eq!(receiver.next().await.as_deref(), Some("test"));
            assert_eq!(receiver.next().await.as_deref(), Some("tset"));
            drop(sender);
            drop(sender2);
            assert_eq!(receiver.next().await, None);
        })
    }

    // Messages queued before the receiver is converted to a port are kept,
    // in order, across the conversion.
    #[test]
    fn test_port_receiver_with_senders_and_messages() {
        block_on(async {
            let (sender, receiver) = channel();
            let sender2 = Sender::<String>::from(Port::from(sender.clone()));
            sender.send(String::from("test"));
            sender2.send(String::from("tset"));
            let mut receiver = Receiver::<String>::from(Port::from(receiver));
            assert_eq!(receiver.next().await.as_deref(), Some("test"));
            assert_eq!(receiver.next().await.as_deref(), Some("tset"));
            drop(sender);
            drop(sender2);
            assert_eq!(receiver.next().await, None);
        })
    }

    // A message that fails to decode as the receiver's element type yields
    // `RecvError::Error` and terminates the receiver.
    #[test]
    fn test_message_corruption() {
        block_on(async {
            let (sender, receiver) = channel();
            let mut receiver = Receiver::<i32>::from(Port::from(receiver));
            sender.send("text".to_owned());
            let RecvError::Error(err) = receiver.recv().await.unwrap_err() else {
                panic!("expected RecvError::Error when decoding a String as an i32")
            };
            tracing::info!(error = &err as &dyn std::error::Error, "expected error");
            assert!(receiver.is_terminated());
        })
    }

    // `!Send` element types can still travel through a port-backed receiver,
    // and the stream ends once the sender is dropped.
    #[test]
    fn test_no_send() {
        block_on(async {
            #[derive(Protobuf)]
            struct NoSend(String, PhantomData<*mut ()>);

            let (sender, receiver) = channel::<NoSend>();
            let mut receiver = Receiver::<NoSend>::from(Port::from(receiver));
            sender.send(NoSend(String::from("test"), PhantomData));
            assert_eq!(
                receiver.next().await.as_ref().map(|v| v.0.as_str()),
                Some("test")
            );
            drop(sender);
            assert!(receiver.next().await.is_none());
        })
    }
}