mesh_channel/
bidir.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A bidirectional channel implemented on top of [`Port`].
5
6use super::RecvError;
7use super::TryRecvError;
8use super::lazy::DeserializeFn;
9use super::lazy::LazyMessage;
10use super::lazy::SerializeFn;
11use super::lazy::deserializer;
12use super::lazy::ensure_serializable;
13use super::lazy::lazy_parse;
14use super::lazy::serializer;
15use mesh_node::local_node::HandleMessageError;
16use mesh_node::local_node::HandlePortEvent;
17use mesh_node::local_node::NodeError;
18use mesh_node::local_node::Port;
19use mesh_node::local_node::PortControl;
20use mesh_node::local_node::PortField;
21use mesh_node::local_node::PortWithHandler;
22use mesh_node::message::MeshPayload;
23use mesh_node::message::Message;
24use mesh_node::message::OwnedMessage;
25use mesh_node::resource::SerializedMessage;
26use std::any::TypeId;
27use std::collections::VecDeque;
28use std::fmt;
29use std::fmt::Debug;
30use std::future::Future;
31use std::future::poll_fn;
32use std::task::Context;
33use std::task::Poll;
34use std::task::Waker;
35
36/// One half of a bidirectional communication channel.
37///
38/// The port can send data of type `T` and receive data of type `U`.
39///
40/// This is a lower-level construct for sending and receiving binary messages.
41/// Most code should use a higher-level channel returned by [`mesh::channel()`],
42/// which uses this type internally.
43pub struct Channel<T = SerializedMessage, U = SerializedMessage> {
44    generic: GenericChannel,
45    // Cached function for serializing T.
46    serialize: Option<SerializeFn<T>>,
47    // Cached function for deserializing U.
48    deserialize: Option<DeserializeFn<U>>,
49}
50
51impl<T: MeshPayload, U: MeshPayload> mesh_protobuf::DefaultEncoding for Channel<T, U> {
52    type Encoding = PortField;
53}
54
55struct GenericChannel {
56    port: PortWithHandler<MessageQueue>,
57    queue_drained: bool,
58}
59
60impl Debug for GenericChannel {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        f.debug_struct("GenericPort")
63            .field("port", &self.port)
64            .field("queue_drained", &self.queue_drained)
65            .finish()
66    }
67}
68
69impl From<GenericChannel> for Port {
70    fn from(port: GenericChannel) -> Self {
71        port.port.remove_handler().0
72    }
73}
74
75impl<T: 'static + MeshPayload, U: 'static + MeshPayload> From<Channel<T, U>> for Port {
76    fn from(channel: Channel<T, U>) -> Self {
77        channel
78            .change_types::<SerializedMessage, SerializedMessage>()
79            .generic
80            .into()
81    }
82}
83
84impl<T: 'static + MeshPayload, U: 'static + MeshPayload> From<Port> for Channel<T, U> {
85    fn from(port: Port) -> Self {
86        <Channel<SerializedMessage, SerializedMessage>>::new(GenericChannel::new(port))
87            .change_types()
88    }
89}
90
91impl<T, U> Debug for Channel<T, U> {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        f.debug_struct("Port")
94            .field("generic", &self.generic)
95            .field("serialize", &self.serialize)
96            .field("deserialize", &self.deserialize)
97            .finish()
98    }
99}
100
101impl<T: 'static + Send, U: 'static + Send> Channel<T, U> {
102    /// Creates a new bidirectional channel, returning a pair of ports.
103    ///
104    /// The left port can send `T` and receive `U`, and the right port can send
105    /// `U` and receive `T`.
106    pub fn new_pair() -> (Self, Channel<U, T>) {
107        let (left, right) = GenericChannel::new_pair();
108        (Self::new(left), Channel::new(right))
109    }
110
111    fn new(port: GenericChannel) -> Self {
112        let serialize = (TypeId::of::<T>() == TypeId::of::<SerializedMessage>())
113            .then(|| serializer::<T>().unwrap());
114        let deserialize = (TypeId::of::<U>() == TypeId::of::<SerializedMessage>())
115            .then(|| deserializer::<U>().unwrap());
116        Self {
117            generic: port,
118            serialize,
119            deserialize,
120        }
121    }
122}
123
124impl GenericChannel {
125    fn new_pair() -> (Self, Self) {
126        let (left, right) = Port::new_pair();
127        let left = Self {
128            port: left.set_handler(MessageQueue::default()),
129            queue_drained: false,
130        };
131        let right = Self {
132            port: right.set_handler(MessageQueue::default()),
133            queue_drained: false,
134        };
135        (left, right)
136    }
137
138    fn new(port: Port) -> Self {
139        Self {
140            port: port.set_handler(MessageQueue::default()),
141            queue_drained: false,
142        }
143    }
144
145    /// Consumes and returns the first message from the incoming message queue
146    /// if there are any messages available.
147    fn try_recv(&self) -> Result<OwnedMessage, TryRecvError> {
148        self.port.with_handler(|queue| match &queue.state {
149            QueueState::Open => queue.messages.pop_front().ok_or(TryRecvError::Empty),
150            QueueState::Closed => queue.messages.pop_front().ok_or(TryRecvError::Closed),
151            QueueState::Failed(err) => Err(TryRecvError::Error(err.clone().into())),
152        })
153    }
154
155    /// Polls the message queue.
156    fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<OwnedMessage, RecvError>> {
157        let mut old_waker = None;
158        self.port.with_handler(|queue| match &queue.state {
159            QueueState::Open => {
160                if let Some(message) = queue.messages.pop_front() {
161                    Poll::Ready(Ok(message))
162                } else {
163                    old_waker = queue.waker.replace(cx.waker().clone());
164                    Poll::Pending
165                }
166            }
167            QueueState::Closed => Poll::Ready(queue.messages.pop_front().ok_or(RecvError::Closed)),
168            QueueState::Failed(err) => Poll::Ready(Err(RecvError::Error(err.clone().into()))),
169        })
170    }
171
172    fn bridge(self, other: Self) {
173        self.port
174            .remove_handler()
175            .0
176            .bridge(other.port.remove_handler().0);
177    }
178
179    fn is_peer_closed(&self) -> bool {
180        self.port.with_handler(|queue| match queue.state {
181            QueueState::Open => false,
182            QueueState::Closed => true,
183            QueueState::Failed(_) => true,
184        })
185    }
186}
187
188impl<T: 'static + Send, U: 'static + Send> Channel<T, U> {
189    /// Sends a message to the opposite endpoint.
190    pub fn send(&self, message: T) {
191        self.generic
192            .port
193            .send(Message::new(LazyMessage::new(message, self.serialize)))
194    }
195
196    /// Sends a message to the opposite endpoint and closes the channel in one
197    /// operation.
198    pub fn send_and_close(self, message: T) {
199        // FUTURE: optimize by sending a single event with both message and close.
200        self.generic
201            .port
202            .send(Message::new(LazyMessage::new(message, self.serialize)));
203    }
204
205    /// Consumes and returns the first message from the incoming message queue
206    /// if there are any messages available.
207    pub fn try_recv(&mut self) -> Result<U, TryRecvError> {
208        self.generic
209            .try_recv()?
210            .try_unwrap()
211            .or_else(|m| lazy_parse(m.serialize(), &mut self.deserialize))
212            .map_err(|err| TryRecvError::Error(err.into()))
213    }
214
215    /// Polls the message queue.
216    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<U, RecvError>> {
217        let r = std::task::ready!(self.generic.poll_recv(cx)).and_then(|message| {
218            message
219                .try_unwrap()
220                .or_else(|m| lazy_parse(m.serialize(), &mut self.deserialize))
221                .map_err(|err| RecvError::Error(err.into()))
222        });
223        if r.is_err() {
224            self.generic.queue_drained = true;
225        }
226        Poll::Ready(r)
227    }
228
229    /// Returns a future to asynchronously receive a message.
230    pub fn recv(&mut self) -> impl Future<Output = Result<U, RecvError>> + Unpin + '_ {
231        poll_fn(move |cx| self.poll_recv(cx))
232    }
233
234    /// Bridges two channels together so that the peer of `self` is connected
235    /// directly to the peer of `other`.
236    pub fn bridge(self, other: Channel<U, T>) {
237        self.generic.bridge(other.generic);
238    }
239
240    /// Returns if the peer port is known to be closed (or failed).
241    ///
242    /// N.B. This will return true even if there is more data in the message
243    ///      queue. This function is mostly useful on a sending port to know
244    ///      whether there is any hope of data reaching the receive side.
245    pub fn is_peer_closed(&self) -> bool {
246        self.generic.is_peer_closed()
247    }
248
249    /// Returns true if the message queue is drained and the port is closed or
250    /// failed.
251    ///
252    /// If the port has failed, this will only return true if the failure
253    /// condition has been consumed.
254    pub fn is_queue_drained(&self) -> bool {
255        self.generic.queue_drained
256    }
257}
258
259impl<T: 'static + MeshPayload, U: 'static + MeshPayload> Channel<T, U> {
260    /// Changes the message types for the port.
261    ///
262    /// The old and new types must be serializable since the port's peer is
263    /// still operating on the old types (which is intended to work fine as long
264    /// as the messages have compatible serialization formats). As a result, it
265    /// may be necessary to round trip messages through their serialized form.
266    ///
267    /// The caller must therefore ensure that the new message type is compatible
268    /// with the message encoding.
269    pub fn change_types<NewT: 'static + MeshPayload, NewU: 'static + MeshPayload>(
270        self,
271    ) -> Channel<NewT, NewU> {
272        // Ensure all the types are serializable so that the peer port can
273        // convert between them as necessary.
274        ensure_serializable::<T>();
275        ensure_serializable::<U>();
276        let (serialize, _) = ensure_serializable::<NewT>();
277        let (_, deserialize) = ensure_serializable::<NewU>();
278        Channel {
279            generic: self.generic,
280            serialize: Some(serialize),
281            deserialize: Some(deserialize),
282        }
283    }
284}
285
286#[derive(Debug, Default)]
287enum QueueState {
288    #[default]
289    Open,
290    Closed,
291    Failed(NodeError),
292}
293
294#[derive(Debug, Default)]
295struct MessageQueue {
296    messages: VecDeque<OwnedMessage>,
297    state: QueueState,
298    waker: Option<Waker>,
299}
300
301impl HandlePortEvent for MessageQueue {
302    fn message(
303        &mut self,
304        control: &mut PortControl<'_, '_>,
305        message: Message<'_>,
306    ) -> Result<(), HandleMessageError> {
307        self.messages.push_back(message.into_owned());
308        if let Some(waker) = self.waker.take() {
309            control.wake(waker);
310        }
311        Ok(())
312    }
313
314    fn fail(&mut self, control: &mut PortControl<'_, '_>, err: NodeError) {
315        self.state = QueueState::Failed(err);
316        if let Some(waker) = self.waker.take() {
317            control.wake(waker);
318        }
319    }
320
321    fn close(&mut self, control: &mut PortControl<'_, '_>) {
322        self.state = QueueState::Closed;
323        if let Some(waker) = self.waker.take() {
324            control.wake(waker);
325        }
326    }
327
328    fn drain(&mut self) -> Vec<OwnedMessage> {
329        std::mem::take(&mut self.messages).into()
330    }
331}