mesh_channel/
pipe.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
//! Implementation of a unidirectional byte stream pipe over mesh.
5
6use crate::ChannelError;
7use futures_io::AsyncRead;
8use futures_io::AsyncWrite;
9use mesh_node::local_node::HandleMessageError;
10use mesh_node::local_node::HandlePortEvent;
11use mesh_node::local_node::NodeError;
12use mesh_node::local_node::Port;
13use mesh_node::local_node::PortControl;
14use mesh_node::local_node::PortField;
15use mesh_node::local_node::PortWithHandler;
16use mesh_node::message::Message;
17use mesh_node::message::OwnedMessage;
18use mesh_node::resource::Resource;
19use mesh_protobuf::Protobuf;
20use mesh_protobuf::encoding::OptionField;
21use std::collections::VecDeque;
22use std::io;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::Context;
26use std::task::Poll;
27use std::task::Waker;
28use thiserror::Error;
29
30/// Creates a new unidirectional pipe, returning a reader and writer.
31///
32/// The resulting pipe has backpressure, so that if the writer tries to write
33/// too many bytes before the reader reads them, then calls to
34/// `futures::AsyncWriteExt::write` will block, and calls to
35/// [`AsyncWrite::poll_write`] will return [`Poll::Pending`].
36pub fn pipe() -> (ReadPipe, WritePipe) {
37    let (read, write) = Port::new_pair();
38    let quota_bytes = 65536;
39    let quota_messages = 64;
40    let read = ReadPipe {
41        port: read.set_handler(ReadPipeState {
42            data: VecDeque::new(),
43            consumed_messages: 0,
44            consumed_bytes: 0,
45            quota_bytes,
46            closed: false,
47            failed: None,
48            waker: None,
49        }),
50        quota_messages,
51        quota_bytes,
52    };
53    let write = WritePipe {
54        port: Some(write.set_handler(WritePipeState {
55            remaining_messages: quota_messages,
56            remaining_bytes: quota_bytes,
57            closed: false,
58            failed: None,
59            waker: None,
60        })),
61    };
62    (read, write)
63}
64
65/// The read side of a pipe.
66///
67/// This is primarily used via [`AsyncRead`] and `futures::AsyncReadExt`.
68pub struct ReadPipe {
69    port: PortWithHandler<ReadPipeState>,
70    quota_bytes: u32,
71    quota_messages: u32,
72}
73
74struct ReadPipeState {
75    data: VecDeque<u8>,
76    consumed_messages: u32,
77    consumed_bytes: u32,
78    quota_bytes: u32,
79    closed: bool,
80    failed: Option<ReadError>,
81    waker: Option<Waker>,
82}
83
84#[derive(Debug, Error, Clone)]
85enum ReadError {
86    #[error("received message beyond quota")]
87    OverQuota,
88    #[error("node failure")]
89    NodeFailure(#[source] NodeError),
90}
91
92impl From<ReadError> for io::Error {
93    fn from(err: ReadError) -> Self {
94        let kind = match err {
95            ReadError::OverQuota => io::ErrorKind::InvalidData,
96            ReadError::NodeFailure(_) => io::ErrorKind::ConnectionReset,
97        };
98        io::Error::new(kind, err)
99    }
100}
101
102impl AsyncRead for ReadPipe {
103    fn poll_read(
104        self: Pin<&mut Self>,
105        cx: &mut Context<'_>,
106        buf: &mut [u8],
107    ) -> Poll<io::Result<usize>> {
108        let mut old_waker = None;
109        self.port.with_port_and_handler(|port, state| {
110            if state.data.is_empty() {
111                if let Some(err) = &state.failed {
112                    return Err(err.clone().into()).into();
113                } else if state.closed {
114                    return Ok(0).into();
115                }
116                old_waker = state.waker.replace(cx.waker().clone());
117                return Poll::Pending;
118            }
119            let n = state.data.len().min(buf.len());
120            let (left, right) = state.data.as_slices();
121            if n > left.len() {
122                buf[..left.len()].copy_from_slice(left);
123                buf[left.len()..n].copy_from_slice(&right[..n - left.len()]);
124            } else {
125                buf[..n].copy_from_slice(&left[..n]);
126            }
127            state.data.drain(..n);
128            state.consumed_bytes += n as u32;
129            if state.consumed_bytes >= self.quota_bytes / 2
130                || state.consumed_messages >= self.quota_messages / 2
131            {
132                port.respond(Message::new(QuotaMessage {
133                    bytes: state.consumed_bytes,
134                    messages: state.consumed_messages,
135                }));
136                state.consumed_bytes = 0;
137                state.consumed_messages = 0;
138            }
139            Ok(n).into()
140        })
141    }
142}
143
144impl HandlePortEvent for ReadPipeState {
145    fn message(
146        &mut self,
147        control: &mut PortControl<'_, '_>,
148        message: Message<'_>,
149    ) -> Result<(), HandleMessageError> {
150        if let Some(err) = &self.failed {
151            return Err(HandleMessageError::new(err.clone()));
152        }
153        let (data, _) = message.serialize();
154        if data.len() + self.data.len() + self.consumed_bytes as usize > self.quota_bytes as usize {
155            self.failed = Some(ReadError::OverQuota);
156            return Err(HandleMessageError::new(ReadError::OverQuota));
157        }
158        self.data.extend(data.as_ref());
159        self.consumed_messages += 1;
160        if let Some(waker) = self.waker.take() {
161            control.wake(waker);
162        }
163        Ok(())
164    }
165
166    fn close(&mut self, control: &mut PortControl<'_, '_>) {
167        self.closed = true;
168        if let Some(waker) = self.waker.take() {
169            control.wake(waker);
170        }
171    }
172
173    fn fail(&mut self, control: &mut PortControl<'_, '_>, err: NodeError) {
174        self.failed = Some(ReadError::NodeFailure(err));
175        if let Some(waker) = self.waker.take() {
176            control.wake(waker);
177        }
178    }
179
180    fn drain(&mut self) -> Vec<OwnedMessage> {
181        let data = std::mem::take(&mut self.data).into();
182        vec![OwnedMessage::serialized(mesh_protobuf::SerializedMessage {
183            data,
184            resources: Vec::new(),
185        })]
186    }
187}
188
189/// The write side of a pipe.
190///
191/// This is primarily used via [`AsyncWrite`] and `futures::AsyncWriteExt`.
192#[derive(Protobuf)]
193#[mesh(resource = "Resource")]
194pub struct WritePipe {
195    #[mesh(encoding = "OptionField<PortField>")]
196    port: Option<PortWithHandler<WritePipeState>>,
197}
198
199#[derive(Default)]
200struct WritePipeState {
201    remaining_messages: u32,
202    remaining_bytes: u32,
203    closed: bool,
204    failed: Option<Arc<ChannelError>>,
205    waker: Option<Waker>,
206}
207
208impl WritePipe {
209    /// Attempts to write `buf` to the pipe without blocking. Returns the number
210    /// of bytes written, an error, or [`io::ErrorKind::WouldBlock`] if the pipe
211    /// is full.
212    pub fn write_nonblocking(&self, buf: &[u8]) -> io::Result<usize> {
213        match self.write_to_port(None, buf) {
214            Poll::Ready(r) => r,
215            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
216        }
217    }
218
219    fn write_to_port(&self, cx: Option<&mut Context<'_>>, buf: &[u8]) -> Poll<io::Result<usize>> {
220        let port = self.port.as_ref().ok_or(io::ErrorKind::BrokenPipe)?;
221        let mut old_waker = None;
222        port.with_port_and_handler(|port, state| {
223            if let Some(err) = &state.failed {
224                Err(io::Error::new(io::ErrorKind::ConnectionReset, err.clone())).into()
225            } else if state.closed {
226                Err(io::ErrorKind::BrokenPipe.into()).into()
227            } else if buf.is_empty() {
228                Ok(0).into()
229            } else if state.remaining_messages > 0 && state.remaining_bytes > 0 {
230                let n = buf.len().min(state.remaining_bytes as usize);
231                state.remaining_bytes -= n as u32;
232                state.remaining_messages -= 1;
233                port.respond(Message::serialized(&buf[..n], Vec::new()));
234                Ok(n).into()
235            } else {
236                if let Some(cx) = cx {
237                    old_waker = state.waker.replace(cx.waker().clone());
238                }
239                Poll::Pending
240            }
241        })
242    }
243}
244
245impl AsyncWrite for WritePipe {
246    fn poll_write(
247        self: Pin<&mut Self>,
248        cx: &mut Context<'_>,
249        buf: &[u8],
250    ) -> Poll<io::Result<usize>> {
251        self.write_to_port(Some(cx), buf)
252    }
253
254    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
255        Ok(()).into()
256    }
257
258    fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
259        self.port = None;
260        Ok(()).into()
261    }
262}
263
264impl HandlePortEvent for WritePipeState {
265    fn message(
266        &mut self,
267        control: &mut PortControl<'_, '_>,
268        message: Message<'_>,
269    ) -> Result<(), HandleMessageError> {
270        if let Some(err) = &self.failed {
271            return Err(HandleMessageError::new(err.clone()));
272        }
273        let message = message.parse::<QuotaMessage>().map_err(|err| {
274            let err = Arc::new(ChannelError::from(err));
275            if self.failed.is_none() {
276                self.failed = Some(err.clone());
277            }
278            HandleMessageError::new(err)
279        })?;
280        if self.remaining_bytes == 0 || self.remaining_messages == 0 {
281            if let Some(waker) = self.waker.take() {
282                control.wake(waker);
283            }
284        }
285        self.remaining_bytes += message.bytes;
286        self.remaining_messages += message.messages;
287        Ok(())
288    }
289
290    fn close(&mut self, control: &mut PortControl<'_, '_>) {
291        self.closed = true;
292        if let Some(waker) = self.waker.take() {
293            control.wake(waker);
294        }
295    }
296
297    fn fail(&mut self, control: &mut PortControl<'_, '_>, err: NodeError) {
298        self.failed = Some(Arc::new(err.into()));
299        if let Some(waker) = self.waker.take() {
300            control.wake(waker);
301        }
302    }
303
304    fn drain(&mut self) -> Vec<OwnedMessage> {
305        // Send remaining quota as a message to avoid having to synchronize
306        // during encoding.
307        vec![OwnedMessage::new(QuotaMessage {
308            bytes: self.remaining_bytes,
309            messages: self.remaining_messages,
310        })]
311    }
312}
313
314#[derive(Protobuf)]
315struct QuotaMessage {
316    bytes: u32,
317    messages: u32,
318}
319
320mod encoding {
321    use super::ReadPipe;
322    use super::ReadPipeState;
323    use mesh_node::local_node::Port;
324    use mesh_node::resource::Resource;
325    use mesh_protobuf::DefaultEncoding;
326    use mesh_protobuf::MessageDecode;
327    use mesh_protobuf::MessageEncode;
328    use mesh_protobuf::Protobuf;
329    use mesh_protobuf::encoding::MessageEncoding;
330    use mesh_protobuf::inplace_none;
331    use std::collections::VecDeque;
332
333    pub struct ReadPipeEncoder;
334
335    impl DefaultEncoding for ReadPipe {
336        type Encoding = MessageEncoding<ReadPipeEncoder>;
337    }
338
339    #[derive(Protobuf)]
340    #[mesh(resource = "Resource")]
341    struct SerializedReadPipe {
342        port: Port,
343        quota_bytes: u32,
344        quota_messages: u32,
345    }
346
347    impl From<SerializedReadPipe> for ReadPipe {
348        fn from(value: SerializedReadPipe) -> Self {
349            let SerializedReadPipe {
350                port,
351                quota_bytes,
352                quota_messages,
353            } = value;
354            Self {
355                port: port.set_handler(ReadPipeState {
356                    data: VecDeque::new(),
357                    consumed_messages: 0,
358                    consumed_bytes: 0,
359                    quota_bytes,
360                    closed: false,
361                    failed: None,
362                    waker: None,
363                }),
364                quota_bytes,
365                quota_messages,
366            }
367        }
368    }
369
370    impl From<ReadPipe> for SerializedReadPipe {
371        fn from(value: ReadPipe) -> Self {
372            Self {
373                port: value.port.remove_handler().0,
374                quota_bytes: value.quota_bytes,
375                quota_messages: value.quota_messages,
376            }
377        }
378    }
379
380    impl MessageEncode<ReadPipe, Resource> for ReadPipeEncoder {
381        fn write_message(
382            item: ReadPipe,
383            writer: mesh_protobuf::protobuf::MessageWriter<'_, '_, Resource>,
384        ) {
385            <SerializedReadPipe as DefaultEncoding>::Encoding::write_message(
386                SerializedReadPipe::from(item),
387                writer,
388            )
389        }
390
391        fn compute_message_size(
392            item: &mut ReadPipe,
393            mut sizer: mesh_protobuf::protobuf::MessageSizer<'_>,
394        ) {
395            sizer.field(1).resource();
396            sizer.field(2).varint(item.quota_bytes.into());
397            sizer.field(3).varint(item.quota_messages.into());
398        }
399    }
400
401    impl MessageDecode<'_, ReadPipe, Resource> for ReadPipeEncoder {
402        fn read_message(
403            item: &mut mesh_protobuf::inplace::InplaceOption<'_, ReadPipe>,
404            reader: mesh_protobuf::protobuf::MessageReader<'_, '_, Resource>,
405        ) -> mesh_protobuf::Result<()> {
406            inplace_none!(inner: SerializedReadPipe);
407            <SerializedReadPipe as DefaultEncoding>::Encoding::read_message(&mut inner, reader)?;
408            item.set(inner.take().unwrap().into());
409            Ok(())
410        }
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::pipe;
    use crate::pipe::ReadPipe;
    use crate::pipe::WritePipe;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use futures::FutureExt;
    use futures_concurrency::future::TryJoin;
    use mesh_node::resource::SerializedMessage;
    use pal_async::async_test;

    /// Streams a payload much larger than the quota through the pipe and
    /// checks that it arrives intact.
    #[async_test]
    async fn test_pipe() {
        let (mut read, mut write) = pipe();
        let payload: Vec<_> = (0..1000000).map(|x| x as u8).collect();
        let writer = async {
            write.write_all(&payload).await?;
            // Dropping the writer lets `read_to_end` observe end-of-stream.
            drop(write);
            Ok(())
        };
        let mut received = Vec::new();
        let reader = read.read_to_end(&mut received);
        (reader, writer).try_join().await.unwrap();
        assert_eq!(received, payload);
    }

    /// Checks that the message quota limits one-byte writes and that a single
    /// read makes the writer able to proceed again.
    #[async_test]
    async fn test_message_backpressure() {
        let (mut read, mut write) = pipe();
        let mut accepted = 0;
        while write.write(&[0]).now_or_never().is_some() {
            accepted += 1;
        }
        assert_eq!(accepted, 64);
        let mut byte = [0];
        read.read(&mut byte).now_or_never().unwrap().unwrap();
        write.write(&[0]).now_or_never().unwrap().unwrap();
    }

    /// Round-trips both ends through serialization and checks that data
    /// written before and after the round trip is delivered in order.
    #[async_test]
    async fn test_encoding() {
        let (read, mut write) = pipe();
        write.write_all(b"hello world").await.unwrap();
        let mut read: ReadPipe = SerializedMessage::from_message(read)
            .into_message()
            .unwrap();
        let mut write: WritePipe = SerializedMessage::from_message(write)
            .into_message()
            .unwrap();
        write.write_all(b"!").await.unwrap();
        write.close().await.unwrap();
        let mut received = Vec::new();
        read.read_to_end(&mut received).await.unwrap();
        assert_eq!(received.as_slice(), b"hello world!");
    }
}
470}