1use crate::ChannelError;
7use futures_io::AsyncRead;
8use futures_io::AsyncWrite;
9use mesh_node::local_node::HandleMessageError;
10use mesh_node::local_node::HandlePortEvent;
11use mesh_node::local_node::NodeError;
12use mesh_node::local_node::Port;
13use mesh_node::local_node::PortControl;
14use mesh_node::local_node::PortField;
15use mesh_node::local_node::PortWithHandler;
16use mesh_node::message::Message;
17use mesh_node::message::OwnedMessage;
18use mesh_node::resource::Resource;
19use mesh_protobuf::Protobuf;
20use mesh_protobuf::encoding::OptionField;
21use std::collections::VecDeque;
22use std::io;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::Context;
26use std::task::Poll;
27use std::task::Waker;
28use thiserror::Error;
29
30pub fn pipe() -> (ReadPipe, WritePipe) {
37 let (read, write) = Port::new_pair();
38 let quota_bytes = 65536;
39 let quota_messages = 64;
40 let read = ReadPipe {
41 port: read.set_handler(ReadPipeState {
42 data: VecDeque::new(),
43 consumed_messages: 0,
44 consumed_bytes: 0,
45 quota_bytes,
46 closed: false,
47 failed: None,
48 waker: None,
49 }),
50 quota_messages,
51 quota_bytes,
52 };
53 let write = WritePipe {
54 port: Some(write.set_handler(WritePipeState {
55 remaining_messages: quota_messages,
56 remaining_bytes: quota_bytes,
57 closed: false,
58 failed: None,
59 waker: None,
60 })),
61 };
62 (read, write)
63}
64
/// The read half of a pipe.
///
/// Implements [`AsyncRead`]; incoming data messages are buffered by the
/// port handler and copied out by `poll_read`.
pub struct ReadPipe {
    /// Port carrying the pipe's messages, with the buffering handler attached.
    port: PortWithHandler<ReadPipeState>,
    /// Byte quota of the pipe; `poll_read` returns quota to the writer once
    /// half of it has been consumed.
    quota_bytes: u32,
    /// Message quota of the pipe; `poll_read` returns quota to the writer
    /// once half of it has been counted.
    quota_messages: u32,
}
73
/// Port handler state for the read half of a pipe.
struct ReadPipeState {
    /// Received bytes that have not been read yet.
    data: VecDeque<u8>,
    /// Messages received since the last quota message was sent to the writer.
    consumed_messages: u32,
    /// Bytes read out of `data` since the last quota message was sent to the
    /// writer.
    consumed_bytes: u32,
    /// Upper bound on buffered plus not-yet-returned bytes; a message that
    /// would exceed it fails the pipe.
    quota_bytes: u32,
    /// Set once the port has been closed.
    closed: bool,
    /// The error that failed the pipe, if any.
    failed: Option<ReadError>,
    /// Waker of a reader that found the buffer empty.
    waker: Option<Waker>,
}
83
/// Errors that can fail the read half of a pipe.
#[derive(Debug, Error, Clone)]
enum ReadError {
    /// A message arrived that would push the buffered data past the byte
    /// quota.
    #[error("received message beyond quota")]
    OverQuota,
    /// The port reported a node failure; the underlying error is the source.
    #[error("node failure")]
    NodeFailure(#[source] NodeError),
}
91
92impl From<ReadError> for io::Error {
93 fn from(err: ReadError) -> Self {
94 let kind = match err {
95 ReadError::OverQuota => io::ErrorKind::InvalidData,
96 ReadError::NodeFailure(_) => io::ErrorKind::ConnectionReset,
97 };
98 io::Error::new(kind, err)
99 }
100}
101
102impl AsyncRead for ReadPipe {
103 fn poll_read(
104 self: Pin<&mut Self>,
105 cx: &mut Context<'_>,
106 buf: &mut [u8],
107 ) -> Poll<io::Result<usize>> {
108 let mut old_waker = None;
109 self.port.with_port_and_handler(|port, state| {
110 if state.data.is_empty() {
111 if let Some(err) = &state.failed {
112 return Err(err.clone().into()).into();
113 } else if state.closed {
114 return Ok(0).into();
115 }
116 old_waker = state.waker.replace(cx.waker().clone());
117 return Poll::Pending;
118 }
119 let n = state.data.len().min(buf.len());
120 let (left, right) = state.data.as_slices();
121 if n > left.len() {
122 buf[..left.len()].copy_from_slice(left);
123 buf[left.len()..n].copy_from_slice(&right[..n - left.len()]);
124 } else {
125 buf[..n].copy_from_slice(&left[..n]);
126 }
127 state.data.drain(..n);
128 state.consumed_bytes += n as u32;
129 if state.consumed_bytes >= self.quota_bytes / 2
130 || state.consumed_messages >= self.quota_messages / 2
131 {
132 port.respond(Message::new(QuotaMessage {
133 bytes: state.consumed_bytes,
134 messages: state.consumed_messages,
135 }));
136 state.consumed_bytes = 0;
137 state.consumed_messages = 0;
138 }
139 Ok(n).into()
140 })
141 }
142}
143
144impl HandlePortEvent for ReadPipeState {
145 fn message(
146 &mut self,
147 control: &mut PortControl<'_, '_>,
148 message: Message<'_>,
149 ) -> Result<(), HandleMessageError> {
150 if let Some(err) = &self.failed {
151 return Err(HandleMessageError::new(err.clone()));
152 }
153 let (data, _) = message.serialize();
154 if data.len() + self.data.len() + self.consumed_bytes as usize > self.quota_bytes as usize {
155 self.failed = Some(ReadError::OverQuota);
156 return Err(HandleMessageError::new(ReadError::OverQuota));
157 }
158 self.data.extend(data.as_ref());
159 self.consumed_messages += 1;
160 if let Some(waker) = self.waker.take() {
161 control.wake(waker);
162 }
163 Ok(())
164 }
165
166 fn close(&mut self, control: &mut PortControl<'_, '_>) {
167 self.closed = true;
168 if let Some(waker) = self.waker.take() {
169 control.wake(waker);
170 }
171 }
172
173 fn fail(&mut self, control: &mut PortControl<'_, '_>, err: NodeError) {
174 self.failed = Some(ReadError::NodeFailure(err));
175 if let Some(waker) = self.waker.take() {
176 control.wake(waker);
177 }
178 }
179
180 fn drain(&mut self) -> Vec<OwnedMessage> {
181 let data = std::mem::take(&mut self.data).into();
182 vec![OwnedMessage::serialized(mesh_protobuf::SerializedMessage {
183 data,
184 resources: Vec::new(),
185 })]
186 }
187}
188
/// The write half of a pipe.
///
/// Implements [`AsyncWrite`] and can also be written to without blocking via
/// `write_nonblocking`.
#[derive(Protobuf)]
#[mesh(resource = "Resource")]
pub struct WritePipe {
    /// Port carrying the pipe's messages; `None` once the pipe has been
    /// closed with `poll_close`.
    #[mesh(encoding = "OptionField<PortField>")]
    port: Option<PortWithHandler<WritePipeState>>,
}
198
/// Port handler state for the write half of a pipe.
///
/// The default value has no quota at all; quota is added by incoming
/// `QuotaMessage`s.
#[derive(Default)]
struct WritePipeState {
    /// Number of messages that may still be sent before waiting for quota.
    remaining_messages: u32,
    /// Number of bytes that may still be sent before waiting for quota.
    remaining_bytes: u32,
    /// Set once the port has been closed.
    closed: bool,
    /// The error that failed the pipe, if any.
    failed: Option<Arc<ChannelError>>,
    /// Waker of a writer that ran out of quota.
    waker: Option<Waker>,
}
207
208impl WritePipe {
209 pub fn write_nonblocking(&self, buf: &[u8]) -> io::Result<usize> {
213 match self.write_to_port(None, buf) {
214 Poll::Ready(r) => r,
215 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
216 }
217 }
218
219 fn write_to_port(&self, cx: Option<&mut Context<'_>>, buf: &[u8]) -> Poll<io::Result<usize>> {
220 let port = self.port.as_ref().ok_or(io::ErrorKind::BrokenPipe)?;
221 let mut old_waker = None;
222 port.with_port_and_handler(|port, state| {
223 if let Some(err) = &state.failed {
224 Err(io::Error::new(io::ErrorKind::ConnectionReset, err.clone())).into()
225 } else if state.closed {
226 Err(io::ErrorKind::BrokenPipe.into()).into()
227 } else if buf.is_empty() {
228 Ok(0).into()
229 } else if state.remaining_messages > 0 && state.remaining_bytes > 0 {
230 let n = buf.len().min(state.remaining_bytes as usize);
231 state.remaining_bytes -= n as u32;
232 state.remaining_messages -= 1;
233 port.respond(Message::serialized(&buf[..n], Vec::new()));
234 Ok(n).into()
235 } else {
236 if let Some(cx) = cx {
237 old_waker = state.waker.replace(cx.waker().clone());
238 }
239 Poll::Pending
240 }
241 })
242 }
243}
244
245impl AsyncWrite for WritePipe {
246 fn poll_write(
247 self: Pin<&mut Self>,
248 cx: &mut Context<'_>,
249 buf: &[u8],
250 ) -> Poll<io::Result<usize>> {
251 self.write_to_port(Some(cx), buf)
252 }
253
254 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
255 Ok(()).into()
256 }
257
258 fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
259 self.port = None;
260 Ok(()).into()
261 }
262}
263
264impl HandlePortEvent for WritePipeState {
265 fn message(
266 &mut self,
267 control: &mut PortControl<'_, '_>,
268 message: Message<'_>,
269 ) -> Result<(), HandleMessageError> {
270 if let Some(err) = &self.failed {
271 return Err(HandleMessageError::new(err.clone()));
272 }
273 let message = message.parse::<QuotaMessage>().map_err(|err| {
274 let err = Arc::new(ChannelError::from(err));
275 if self.failed.is_none() {
276 self.failed = Some(err.clone());
277 }
278 HandleMessageError::new(err)
279 })?;
280 if self.remaining_bytes == 0 || self.remaining_messages == 0 {
281 if let Some(waker) = self.waker.take() {
282 control.wake(waker);
283 }
284 }
285 self.remaining_bytes += message.bytes;
286 self.remaining_messages += message.messages;
287 Ok(())
288 }
289
290 fn close(&mut self, control: &mut PortControl<'_, '_>) {
291 self.closed = true;
292 if let Some(waker) = self.waker.take() {
293 control.wake(waker);
294 }
295 }
296
297 fn fail(&mut self, control: &mut PortControl<'_, '_>, err: NodeError) {
298 self.failed = Some(Arc::new(err.into()));
299 if let Some(waker) = self.waker.take() {
300 control.wake(waker);
301 }
302 }
303
304 fn drain(&mut self) -> Vec<OwnedMessage> {
305 vec![OwnedMessage::new(QuotaMessage {
308 bytes: self.remaining_bytes,
309 messages: self.remaining_messages,
310 })]
311 }
312}
313
/// Quota returned from the read half of a pipe to the write half.
///
/// NOTE(review): the field order presumably determines the encoded field
/// numbers — confirm before reordering or inserting fields.
#[derive(Protobuf)]
struct QuotaMessage {
    /// Number of bytes the writer may additionally send.
    bytes: u32,
    /// Number of messages the writer may additionally send.
    messages: u32,
}
319
320mod encoding {
321 use super::ReadPipe;
322 use super::ReadPipeState;
323 use mesh_node::local_node::Port;
324 use mesh_node::resource::Resource;
325 use mesh_protobuf::DefaultEncoding;
326 use mesh_protobuf::MessageDecode;
327 use mesh_protobuf::MessageEncode;
328 use mesh_protobuf::Protobuf;
329 use mesh_protobuf::encoding::MessageEncoding;
330 use mesh_protobuf::inplace_none;
331 use std::collections::VecDeque;
332
333 pub struct ReadPipeEncoder;
334
335 impl DefaultEncoding for ReadPipe {
336 type Encoding = MessageEncoding<ReadPipeEncoder>;
337 }
338
339 #[derive(Protobuf)]
340 #[mesh(resource = "Resource")]
341 struct SerializedReadPipe {
342 port: Port,
343 quota_bytes: u32,
344 quota_messages: u32,
345 }
346
347 impl From<SerializedReadPipe> for ReadPipe {
348 fn from(value: SerializedReadPipe) -> Self {
349 let SerializedReadPipe {
350 port,
351 quota_bytes,
352 quota_messages,
353 } = value;
354 Self {
355 port: port.set_handler(ReadPipeState {
356 data: VecDeque::new(),
357 consumed_messages: 0,
358 consumed_bytes: 0,
359 quota_bytes,
360 closed: false,
361 failed: None,
362 waker: None,
363 }),
364 quota_bytes,
365 quota_messages,
366 }
367 }
368 }
369
370 impl From<ReadPipe> for SerializedReadPipe {
371 fn from(value: ReadPipe) -> Self {
372 Self {
373 port: value.port.remove_handler().0,
374 quota_bytes: value.quota_bytes,
375 quota_messages: value.quota_messages,
376 }
377 }
378 }
379
380 impl MessageEncode<ReadPipe, Resource> for ReadPipeEncoder {
381 fn write_message(
382 item: ReadPipe,
383 writer: mesh_protobuf::protobuf::MessageWriter<'_, '_, Resource>,
384 ) {
385 <SerializedReadPipe as DefaultEncoding>::Encoding::write_message(
386 SerializedReadPipe::from(item),
387 writer,
388 )
389 }
390
391 fn compute_message_size(
392 item: &mut ReadPipe,
393 mut sizer: mesh_protobuf::protobuf::MessageSizer<'_>,
394 ) {
395 sizer.field(1).resource();
396 sizer.field(2).varint(item.quota_bytes.into());
397 sizer.field(3).varint(item.quota_messages.into());
398 }
399 }
400
401 impl MessageDecode<'_, ReadPipe, Resource> for ReadPipeEncoder {
402 fn read_message(
403 item: &mut mesh_protobuf::inplace::InplaceOption<'_, ReadPipe>,
404 reader: mesh_protobuf::protobuf::MessageReader<'_, '_, Resource>,
405 ) -> mesh_protobuf::Result<()> {
406 inplace_none!(inner: SerializedReadPipe);
407 <SerializedReadPipe as DefaultEncoding>::Encoding::read_message(&mut inner, reader)?;
408 item.set(inner.take().unwrap().into());
409 Ok(())
410 }
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::pipe;
    use crate::pipe::ReadPipe;
    use crate::pipe::WritePipe;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use futures::FutureExt;
    use futures_concurrency::future::TryJoin;
    use mesh_node::resource::SerializedMessage;
    use pal_async::async_test;

    /// Streams a payload far larger than the quota through the pipe and
    /// checks that it arrives intact.
    #[async_test]
    async fn test_pipe() {
        let (mut reader, mut writer) = pipe();
        let payload: Vec<_> = (0..1000000).map(|x| x as u8).collect();
        let write_side = async {
            writer.write_all(&payload).await?;
            // Dropping the writer lets the reader observe end of stream.
            drop(writer);
            Ok(())
        };
        let mut received = Vec::new();
        let read_side = reader.read_to_end(&mut received);
        (read_side, write_side).try_join().await.unwrap();
        assert_eq!(received, payload);
    }

    /// Checks that the writer stalls after 64 messages and can continue
    /// after the reader has read.
    #[async_test]
    async fn test_message_backpressure() {
        let (mut reader, mut writer) = pipe();
        let mut accepted = 0;
        while writer.write(&[0]).now_or_never().is_some() {
            accepted += 1;
        }
        assert_eq!(accepted, 64);
        let mut byte = [0];
        reader.read(&mut byte).now_or_never().unwrap().unwrap();
        writer.write(&[0]).now_or_never().unwrap().unwrap();
    }

    /// Round-trips both halves through a serialized message while data is
    /// in flight and checks that nothing is lost.
    #[async_test]
    async fn test_encoding() {
        let (reader, mut writer) = pipe();
        writer.write_all(b"hello world").await.unwrap();
        let mut reader: ReadPipe = SerializedMessage::from_message(reader)
            .into_message()
            .unwrap();
        let mut writer: WritePipe = SerializedMessage::from_message(writer)
            .into_message()
            .unwrap();
        writer.write_all(b"!").await.unwrap();
        writer.close().await.unwrap();
        let mut received = Vec::new();
        reader.read_to_end(&mut received).await.unwrap();
        assert_eq!(received.as_slice(), b"hello world!");
    }
}