1use super::core::Core;
7use super::core::ReadState;
8use super::core::WriteState;
9use crate::async_dgram::AsyncRecv;
10use crate::async_dgram::AsyncSend;
11use crate::core::PollError;
12use futures::AsyncRead;
13use futures::AsyncWrite;
14use guestmem::AccessError;
15use guestmem::MemoryRead;
16use guestmem::MemoryWrite;
17use inspect::InspectMut;
18use std::cmp;
19use std::future::poll_fn;
20use std::io;
21use std::io::IoSlice;
22use std::io::IoSliceMut;
23use std::pin::Pin;
24use std::task::Context;
25use std::task::Poll;
26use std::task::ready;
27use thiserror::Error;
28use vmbus_channel::RawAsyncChannel;
29use vmbus_channel::connected_async_channels;
30use vmbus_ring as ring;
31use vmbus_ring::FlatRingMem;
32use vmbus_ring::RingMem;
33use zerocopy::FromZeros;
34use zerocopy::IntoBytes;
35
/// Errors produced by pipe read and write operations.
#[derive(Debug, Error)]
enum Error {
    /// The opposite endpoint closed the channel.
    #[error("the channel has been closed")]
    ChannelClosed,
    /// A message exceeds the pipe's maximum payload length (on receive, also
    /// when it exceeds the caller's buffers).
    #[error("packet is too large for the ring")]
    PacketTooLarge,
    /// An incoming ring packet was not an in-band packet.
    #[error("unexpected ring packet type")]
    UnexpectedRingPacketType,
    /// A pipe header carried a packet type this implementation does not handle.
    #[error("invalid pipe packet type {0:#x}")]
    InvalidPipePacketType(u32),
    /// Error reported by the ring buffer layer.
    #[error("ring buffer error")]
    Ring(#[from] ring::Error),
    /// Failure while reading or writing ring memory.
    #[error("memory access error")]
    Access(#[from] AccessError),
    /// A partial-packet header whose offset is not below its length.
    #[error("partial packet offset is too large")]
    PartialPacketOffsetTooLarge,
    /// A wrapped I/O error.
    #[error(transparent)]
    Io(#[from] io::Error),
}
55
56impl From<PollError> for Error {
57 fn from(value: PollError) -> Self {
58 match value {
59 PollError::Ring(err) => Self::Ring(err),
60 PollError::Closed => Self::ChannelClosed,
61 }
62 }
63}
64
65impl From<Error> for io::Error {
66 fn from(err: Error) -> Self {
67 match err {
68 Error::ChannelClosed => {
69 io::Error::new(io::ErrorKind::ConnectionReset, Error::ChannelClosed)
70 }
71 err => io::Error::other(err),
72 }
73 }
74}
75
/// Result of a non-blocking read attempt.
#[derive(Debug)]
enum TryReadError {
    /// No packet is available yet; the caller should wait and retry.
    Empty,
    /// A hard failure that should be reported to the caller.
    Pipe(Error),
}
81
82impl From<ring::ReadError> for TryReadError {
83 fn from(e: ring::ReadError) -> Self {
84 match e {
85 ring::ReadError::Empty => Self::Empty,
86 ring::ReadError::Corrupt(e) => Self::Pipe(e.into()),
87 }
88 }
89}
90
91impl<T> From<T> for TryReadError
92where
93 Error: From<T>,
94{
95 fn from(e: T) -> Self {
96 Self::Pipe(e.into())
97 }
98}
99
/// Result of a non-blocking write attempt.
#[derive(Debug)]
enum TryWriteError {
    /// The ring lacks space. The value is the size the retry loop passes as
    /// `send_size` to `poll_ready` before trying again.
    Full(usize),
    /// A hard failure that should be reported to the caller.
    Pipe(Error),
}
105
106impl From<ring::WriteError> for TryWriteError {
107 fn from(e: ring::WriteError) -> Self {
108 match e {
109 ring::WriteError::Full(n) => Self::Full(n),
110 ring::WriteError::Corrupt(e) => Self::Pipe(e.into()),
111 }
112 }
113}
114
115impl<T> From<T> for TryWriteError
116where
117 Error: From<T>,
118{
119 fn from(e: T) -> Self {
120 Self::Pipe(e.into())
121 }
122}
123
124impl From<TryWriteError> for io::Error {
125 fn from(e: TryWriteError) -> Self {
126 match e {
127 TryWriteError::Full(_) => {
128 io::Error::new(io::ErrorKind::WouldBlock, "the ring buffer is full")
129 }
130 TryWriteError::Pipe(e) => e.into(),
131 }
132 }
133}
134
/// Write-side state of a pipe.
#[derive(Debug)]
struct PipeWriteState {
    /// Outgoing ring cursor (`ptrs`) plus polling and signal bookkeeping.
    state: WriteState,
    /// When true, messages are written without a `ring::PipeHeader`.
    raw: bool,
    /// Largest payload, in bytes, that `try_write_message` accepts.
    max_payload_len: usize,
}
141
142impl PipeWriteState {
143 fn new(ptrs: ring::OutgoingOffset, raw: bool, max_payload_len: usize) -> Self {
144 Self {
145 state: WriteState::new(ptrs),
146 raw,
147 max_payload_len,
148 }
149 }
150
151 fn writer<'a, M: RingMem>(&'a mut self, core: &'a Core<M>) -> PipeWriter<'a, M> {
152 PipeWriter { write: self, core }
153 }
154}
155
/// Short-lived view combining write-side state with the channel core.
struct PipeWriter<'a, M: RingMem> {
    write: &'a mut PipeWriteState,
    core: &'a Core<M>,
}
160
161impl<M: RingMem> PipeWriter<'_, M> {
162 fn try_write_message(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize, TryWriteError> {
165 let len = bufs.iter().map(|x| x.len()).sum();
166 let mut packet_len = len;
167 if len > self.write.max_payload_len {
168 return Err(TryWriteError::Pipe(Error::PacketTooLarge));
169 }
170 if !self.write.raw {
171 packet_len += size_of::<ring::PipeHeader>();
172 }
173 let mut outgoing = self.write.state.ptrs.clone();
174 let range = self.core.out_ring().write(
175 &mut outgoing,
176 &ring::OutgoingPacket {
177 transaction_id: 0,
178 size: packet_len,
179 typ: ring::OutgoingPacketType::InBandNoCompletion,
180 },
181 )?;
182 let mut writer = range.writer(self.core.out_ring());
183 if !self.write.raw {
184 writer.write(
185 ring::PipeHeader {
186 packet_type: ring::PIPE_PACKET_TYPE_DATA,
187 len: len as u32,
188 }
189 .as_bytes(),
190 )?;
191 }
192 for buf in bufs {
193 writer.write(buf)?;
194 }
195 self.write.state.clear_poll(self.core);
196 if self.core.out_ring().commit_write(&mut outgoing) {
197 self.core.signal();
198 self.write.state.signals.increment();
199 }
200 self.write.state.ptrs = outgoing;
201 Ok(len)
202 }
203
204 fn try_write_bytes(&mut self, buf: &[u8]) -> Result<usize, TryWriteError> {
208 if buf.is_empty() {
209 return Ok(0);
210 }
211
212 const CHUNK_SIZE: usize = 2048;
213 let mut written = 0;
216 let mut outgoing = self.write.state.ptrs.clone();
217 for buf in buf.chunks(CHUNK_SIZE) {
218 match self.core.out_ring().write(
219 &mut outgoing,
220 &ring::OutgoingPacket {
221 transaction_id: 0,
222 size: buf.len() + size_of::<ring::PipeHeader>(),
223 typ: ring::OutgoingPacketType::InBandNoCompletion,
224 },
225 ) {
226 Ok(range) => {
227 let mut writer = range.writer(self.core.out_ring());
228 writer.write(
229 ring::PipeHeader {
230 packet_type: ring::PIPE_PACKET_TYPE_DATA,
231 len: buf.len() as u32,
232 }
233 .as_bytes(),
234 )?;
235 writer.write(buf)?;
236 written += buf.len();
237 }
238 Err(ring::WriteError::Full(n)) => {
239 if written > 0 {
240 break;
241 } else {
242 return Err(TryWriteError::Full(n));
243 }
244 }
245 Err(ring::WriteError::Corrupt(err)) => return Err(TryWriteError::Pipe(err.into())),
246 }
247 }
248 assert!(written > 0);
249 if self.core.out_ring().commit_write(&mut outgoing) {
250 self.core.signal();
251 self.write.state.signals.increment();
252 }
253 self.write.state.ptrs = outgoing;
254 Ok(written)
255 }
256
257 fn try_shutdown_writes(&mut self) -> Result<(), TryWriteError> {
261 if !self.write.raw {
262 match self.try_write_message(&[]) {
266 Ok(_) => {}
267 Err(err) => return Err(err),
268 }
269 }
270 Ok(())
271 }
272
273 fn poll_op<F, R>(&mut self, cx: &mut Context<'_>, mut f: F) -> Poll<Result<R, Error>>
274 where
275 F: FnMut(&mut Self) -> Result<R, TryWriteError>,
276 {
277 let mut send_size = 32;
279 loop {
280 std::task::ready!(self.write.state.poll_ready(cx, self.core, send_size))?;
281 match f(self) {
282 Ok(r) => break Poll::Ready(Ok(r)),
283 Err(TryWriteError::Full(len)) => {
284 send_size = len;
285 self.write.state.clear_ready();
286 }
287 Err(TryWriteError::Pipe(e)) => break Poll::Ready(Err(e)),
288 }
289 }
290 }
291
292 fn poll_write_bytes(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
293 self.poll_op(cx, |this| this.try_write_bytes(buf))
294 }
295
296 fn poll_write_message(
297 &mut self,
298 cx: &mut Context<'_>,
299 bufs: &[IoSlice<'_>],
300 ) -> Poll<Result<usize, Error>> {
301 self.poll_op(cx, |this| this.try_write_message(bufs))
302 }
303
304 fn poll_shutdown_writes(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
308 match self.poll_op(cx, |this| this.try_shutdown_writes()) {
309 Poll::Ready(Err(Error::ChannelClosed)) => {
310 Poll::Ready(Ok(()))
312 }
313 r => r,
314 }
315 }
316}
317
/// Read-side state of a pipe.
#[derive(Debug)]
struct PipeReadState {
    /// Incoming ring cursor (`ptrs`) plus polling and signal bookkeeping.
    read: ReadState,
    /// Largest message, in bytes, that `try_read_message` accepts.
    max_payload_len: usize,
    /// When true, messages have no `ring::PipeHeader`; the whole packet
    /// payload is the message.
    raw: bool,
    /// Set once `try_read_bytes` consumes a zero-length data packet. After
    /// that, byte reads return 0.
    eof: bool,
}
325
326impl PipeReadState {
327 fn new(ptrs: ring::IncomingOffset, raw: bool, max_payload_len: usize) -> Self {
328 Self {
329 read: ReadState::new(ptrs),
330 raw,
331 max_payload_len,
332 eof: false,
333 }
334 }
335
336 fn reader<'a, M: RingMem>(&'a mut self, core: &'a Core<M>) -> PipeReader<'a, M> {
337 PipeReader { state: self, core }
338 }
339}
340
/// Short-lived view combining read-side state with the channel core.
struct PipeReader<'a, M: RingMem> {
    state: &'a mut PipeReadState,
    core: &'a Core<M>,
}
345
impl<M: RingMem> PipeReader<'_, M> {
    /// Reads one message from the incoming ring into `bufs`, without waiting.
    ///
    /// In non-raw mode the packet payload must start with a
    /// `PIPE_PACKET_TYPE_DATA` header, whose `len` is the message length. In
    /// raw mode the whole packet payload is the message. Returns the message
    /// length.
    ///
    /// # Errors
    ///
    /// - `TryReadError::Empty` if no packet is available.
    /// - `Error::PacketTooLarge` if the message is longer than the total length
    ///   of `bufs` or than `max_payload_len`. The packet is not consumed in
    ///   that case.
    /// - `Error::InvalidPipePacketType` for a non-data pipe header.
    /// - `Error::UnexpectedRingPacketType` for a non in-band ring packet.
    fn try_read_message(&mut self, bufs: &mut [IoSliceMut<'_>]) -> Result<usize, TryReadError> {
        // Total capacity of the caller's buffers.
        let len = bufs.iter().map(|x| x.len()).sum();
        // Read against a copy of the cursor; it is stored back only after a
        // successful commit.
        let mut incoming = self.state.read.ptrs.clone();
        match self.core.in_ring().read(&mut incoming) {
            Ok(ring::IncomingPacket {
                typ: ring::IncomingPacketType::InBand,
                payload,
                ..
            }) => {
                let mut reader = payload.reader(self.core.in_ring());
                let bytes_read = if !self.state.raw {
                    let mut header = ring::PipeHeader::new_zeroed();
                    reader.read(header.as_mut_bytes())?;
                    if header.packet_type != ring::PIPE_PACKET_TYPE_DATA {
                        return Err(TryReadError::Pipe(Error::InvalidPipePacketType(
                            header.packet_type,
                        )));
                    }
                    header.len as usize
                } else {
                    payload.len()
                };
                // Returning here leaves `ptrs` untouched, so the oversized
                // message stays in the ring.
                if bytes_read > cmp::min(len, self.state.max_payload_len) {
                    return Err(TryReadError::Pipe(Error::PacketTooLarge));
                }
                // Scatter the message across `bufs` in order.
                let mut remaining = bytes_read;
                for buf in bufs {
                    if remaining == 0 {
                        break;
                    }
                    let this_len = cmp::min(remaining, buf.len());
                    remaining -= this_len;
                    reader.read(&mut buf[..this_len])?;
                }
                self.state.read.clear_poll(self.core);
                if self.core.in_ring().commit_read(&mut incoming) {
                    self.core.signal();
                    self.state.read.signals.increment();
                }
                self.state.read.ptrs = incoming;
                Ok(bytes_read)
            }
            Ok(_) => Err(TryReadError::Pipe(Error::UnexpectedRingPacketType)),
            Err(err) => Err(err.into()),
        }
    }

    /// Reads as many stream bytes into `buf` as are available, without
    /// waiting.
    ///
    /// Returns `Ok(0)` for an empty `buf` or after end of stream. End of
    /// stream is a zero-length `PIPE_PACKET_TYPE_DATA` packet. Returns
    /// `TryReadError::Empty` if nothing was read and end of stream was not
    /// reached.
    fn try_read_bytes(&mut self, buf: &mut [u8]) -> Result<usize, TryReadError> {
        if buf.is_empty() || self.state.eof {
            return Ok(0);
        }
        // `incoming` advances past every packet examined. `commit` advances
        // only past packets that were fully consumed, and it is what gets
        // committed at the end.
        let mut incoming = self.state.read.ptrs.clone();
        let mut commit = incoming.clone();
        let mut total_read = 0;
        while total_read < buf.len() {
            match self.core.in_ring().read(&mut incoming) {
                Ok(ring::IncomingPacket {
                    typ: ring::IncomingPacketType::InBand,
                    payload,
                    ..
                }) => {
                    let mut reader = payload.reader(self.core.in_ring());
                    let mut header = ring::PipeHeader::new_zeroed();
                    reader.read(header.as_mut_bytes())?;
                    // (offset into the payload data, number of unread bytes)
                    let (off, len) = match header.packet_type {
                        ring::PIPE_PACKET_TYPE_DATA => {
                            if header.len == 0 {
                                // End of stream. Consume the marker only if no
                                // data was returned by this call. Otherwise
                                // leave it for the next read, which then
                                // reports EOF.
                                if total_read == 0 {
                                    self.state.eof = true;
                                    commit = incoming.clone();
                                }
                                break;
                            }
                            (0, header.len as usize)
                        }
                        ring::PIPE_PACKET_TYPE_PARTIAL => {
                            // A packet this reader previously consumed in part.
                            // The high 16 bits of `len` hold the bytes already
                            // read; the low 16 bits hold the original length.
                            let off = header.len >> 16;
                            let len = header.len & 0xffff;
                            if off >= len {
                                return Err(TryReadError::Pipe(Error::PartialPacketOffsetTooLarge));
                            }
                            (off as usize, (len - off) as usize)
                        }
                        n => return Err(TryReadError::Pipe(Error::InvalidPipePacketType(n))),
                    };
                    reader.skip(off)?;
                    let read = cmp::min(len, buf.len() - total_read);
                    reader.read(&mut buf[total_read..total_read + read])?;
                    if read < len {
                        // Only part of the packet fit. Rewrite its header in
                        // place as PARTIAL with the advanced offset, and do not
                        // advance `commit`, so the rest is read next time.
                        // NOTE(review): this encoding assumes the data length
                        // fits in 16 bits. A DATA packet with `len > 0xffff`
                        // would not round-trip if read partially. TODO confirm
                        // peers never send one.
                        header.packet_type = ring::PIPE_PACKET_TYPE_PARTIAL;
                        header.len += (read as u32) << 16;
                        let mut writer = payload.writer(self.core.in_ring());
                        writer.write(header.as_bytes())?;
                    } else {
                        // Packet fully consumed.
                        commit = incoming.clone();
                    }
                    total_read += read;
                }
                Ok(_) => return Err(TryReadError::Pipe(Error::UnexpectedRingPacketType)),
                Err(ring::ReadError::Empty) => break,
                Err(ring::ReadError::Corrupt(err)) => return Err(err.into()),
            }
        }
        if total_read > 0 || self.state.eof {
            self.state.read.clear_poll(self.core);
            if self.core.in_ring().commit_read(&mut commit) {
                self.core.signal();
                self.state.read.signals.increment();
            }
            self.state.read.ptrs = commit;
            Ok(total_read)
        } else {
            Err(TryReadError::Empty)
        }
    }

    /// Drives the non-blocking read `f` to completion.
    ///
    /// Waits through `poll_ready`, retries when `f` reports `Empty`, and stops
    /// on success or on a hard error.
    fn poll_op<F, R>(&mut self, cx: &mut Context<'_>, mut f: F) -> Poll<Result<R, Error>>
    where
        F: FnMut(&mut Self) -> Result<R, TryReadError>,
    {
        loop {
            std::task::ready!(self.state.read.poll_ready(cx, self.core))?;
            match f(self) {
                Ok(r) => break Poll::Ready(Ok(r)),
                Err(TryReadError::Empty) => self.state.read.clear_ready(),
                Err(TryReadError::Pipe(err)) => break Poll::Ready(Err(err)),
            }
        }
    }
    /// Polls to read stream bytes into `buf`; see `try_read_bytes`.
    fn poll_read_bytes(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, Error>> {
        self.poll_op(cx, |this| this.try_read_bytes(buf))
    }

    /// Polls to read one message into `bufs`; see `try_read_message`.
    fn poll_read_message(
        &mut self,
        cx: &mut Context<'_>,
        bufs: &mut [IoSliceMut<'_>],
    ) -> Poll<Result<usize, Error>> {
        self.poll_op(cx, |this| this.try_read_message(bufs))
    }
}
515
/// A message-oriented pipe over a vmbus channel.
///
/// Each send writes one ring packet and each receive consumes one packet.
pub struct MessagePipe<M: RingMem>(Pipe<M>);
518
519impl<M: RingMem> InspectMut for MessagePipe<M> {
520 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
521 self.0.inspect_mut(req)
522 }
523}
524
/// A byte-stream pipe over a vmbus channel, exposed through `AsyncRead` and
/// `AsyncWrite`.
pub struct BytePipe<M: RingMem>(Pipe<M>);
528
529impl<M: RingMem> InspectMut for BytePipe<M> {
530 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
531 self.0.inspect_mut(req)
532 }
533}
534
/// State shared by [`MessagePipe`] and [`BytePipe`]: the channel core plus
/// per-direction state.
struct Pipe<M: RingMem> {
    core: Core<M>,
    read: PipeReadState,
    write: PipeWriteState,
}
541
542impl<M: RingMem> InspectMut for Pipe<M> {
543 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
544 req.respond()
545 .merge(&self.core)
546 .field("incoming_ring", &self.read.read)
547 .field("outgoing_ring", &self.write.state);
548 }
549}
550
/// Receive half of a split [`MessagePipe`]; borrows the pipe.
pub struct MessageReadHalf<'a, M: RingMem> {
    core: &'a Core<M>,
    read: &'a mut PipeReadState,
}
556
/// Send half of a split [`MessagePipe`]; borrows the pipe.
pub struct MessageWriteHalf<'a, M: RingMem> {
    core: &'a Core<M>,
    write: &'a mut PipeWriteState,
}
562
/// Read half of a split [`BytePipe`]; borrows the pipe.
pub struct ByteReadHalf<'a, M: RingMem> {
    core: &'a Core<M>,
    read: &'a mut PipeReadState,
}
568
/// Write half of a split [`BytePipe`]; borrows the pipe.
pub struct ByteWriteHalf<'a, M: RingMem> {
    core: &'a Core<M>,
    write: &'a mut PipeWriteState,
}
574
575impl<M: RingMem> MessagePipe<M> {
576 pub fn new(channel: RawAsyncChannel<M>) -> io::Result<Self> {
578 Self::new_inner(channel, false)
579 }
580
581 pub fn new_raw(channel: RawAsyncChannel<M>) -> io::Result<Self> {
587 Self::new_inner(channel, true)
588 }
589
590 fn new_inner(channel: RawAsyncChannel<M>, raw: bool) -> io::Result<Self> {
591 let max_payload_len = if raw {
592 channel.out_ring.maximum_packet_size() - ring::PacketSize::in_band(0)
595 } else {
596 cmp::min(
598 ring::MAXIMUM_PIPE_PACKET_SIZE,
599 channel.out_ring.maximum_packet_size()
600 - ring::PacketSize::in_band(size_of::<ring::PipeHeader>()),
601 )
602 };
603
604 let incoming = channel.in_ring.incoming().map_err(Error::Ring)?;
605 let outgoing = channel.out_ring.outgoing().map_err(Error::Ring)?;
606
607 Ok(Self(Pipe {
608 core: Core::new(channel),
609 read: PipeReadState::new(incoming, raw, max_payload_len),
610 write: PipeWriteState::new(outgoing, raw, max_payload_len),
611 }))
612 }
613
614 pub fn split(&mut self) -> (MessageReadHalf<'_, M>, MessageWriteHalf<'_, M>) {
617 (
618 MessageReadHalf {
619 core: &self.0.core,
620 read: &mut self.0.read,
621 },
622 MessageWriteHalf {
623 core: &self.0.core,
624 write: &mut self.0.write,
625 },
626 )
627 }
628
629 pub async fn wait_write_ready(&mut self, send_size: usize) -> io::Result<()> {
632 self.split().1.wait_ready(send_size).await
633 }
634
635 pub fn try_send(&mut self, buf: &[u8]) -> io::Result<()> {
638 self.split().1.try_send(buf)
639 }
640
641 pub fn try_send_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<()> {
644 self.split().1.try_send_vectored(bufs)
645 }
646}
647
648impl<M: RingMem> BytePipe<M> {
649 pub fn new(channel: RawAsyncChannel<M>) -> io::Result<Self> {
651 let incoming = channel.in_ring.incoming().map_err(Error::Ring)?;
652 let outgoing = channel.out_ring.outgoing().map_err(Error::Ring)?;
653
654 Ok(Self(Pipe {
655 core: Core::new(channel),
656 read: PipeReadState::new(incoming, false, 0),
657 write: PipeWriteState::new(outgoing, false, 0),
658 }))
659 }
660
661 pub fn split(&mut self) -> (ByteReadHalf<'_, M>, ByteWriteHalf<'_, M>) {
664 (
665 ByteReadHalf {
666 core: &self.0.core,
667 read: &mut self.0.read,
668 },
669 ByteWriteHalf {
670 core: &self.0.core,
671 write: &mut self.0.write,
672 },
673 )
674 }
675}
676
677impl<M: RingMem + Unpin> AsyncRead for BytePipe<M> {
678 fn poll_read(
679 self: Pin<&mut Self>,
680 cx: &mut Context<'_>,
681 buf: &mut [u8],
682 ) -> Poll<io::Result<usize>> {
683 let this = self.get_mut();
684 this.0
685 .read
686 .reader(&this.0.core)
687 .poll_read_bytes(cx, buf)
688 .map_err(Into::into)
689 }
690}
691
692impl<M: RingMem + Unpin> AsyncRead for ByteReadHalf<'_, M> {
693 fn poll_read(
694 self: Pin<&mut Self>,
695 cx: &mut Context<'_>,
696 buf: &mut [u8],
697 ) -> Poll<io::Result<usize>> {
698 let this = self.get_mut();
699 this.read
700 .reader(this.core)
701 .poll_read_bytes(cx, buf)
702 .map_err(Into::into)
703 }
704}
705
706impl<M: RingMem + Unpin> AsyncWrite for BytePipe<M> {
707 fn poll_write(
708 self: Pin<&mut Self>,
709 cx: &mut Context<'_>,
710 buf: &[u8],
711 ) -> Poll<io::Result<usize>> {
712 let this = self.get_mut();
713 this.0
714 .write
715 .writer(&this.0.core)
716 .poll_write_bytes(cx, buf)
717 .map_err(Into::into)
718 }
719
720 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
721 Poll::Ready(Ok(()))
722 }
723
724 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
725 let this = self.get_mut();
726 this.0
727 .write
728 .writer(&this.0.core)
729 .poll_shutdown_writes(cx)
730 .map_err(Into::into)
731 }
732}
733
734impl<M: RingMem + Unpin> AsyncWrite for ByteWriteHalf<'_, M> {
735 fn poll_write(
736 self: Pin<&mut Self>,
737 cx: &mut Context<'_>,
738 buf: &[u8],
739 ) -> Poll<io::Result<usize>> {
740 let this = self.get_mut();
741 this.write
742 .writer(this.core)
743 .poll_write_bytes(cx, buf)
744 .map_err(Into::into)
745 }
746
747 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
748 Poll::Ready(Ok(()))
749 }
750
751 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
752 let this = self.get_mut();
753 this.write
754 .writer(this.core)
755 .poll_shutdown_writes(cx)
756 .map_err(Into::into)
757 }
758}
759
760impl<M: RingMem> AsyncRecv for MessagePipe<M> {
761 fn poll_recv(
762 &mut self,
763 cx: &mut Context<'_>,
764 bufs: &mut [IoSliceMut<'_>],
765 ) -> Poll<io::Result<usize>> {
766 self.0
767 .read
768 .reader(&self.0.core)
769 .poll_read_message(cx, bufs)
770 .map_err(Into::into)
771 }
772}
773
774impl<M: RingMem> AsyncSend for MessagePipe<M> {
775 fn poll_send(&mut self, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<io::Result<()>> {
776 ready!(
777 self.0
778 .write
779 .writer(&self.0.core)
780 .poll_write_message(cx, bufs)
781 )?;
782
783 Poll::Ready(Ok(()))
784 }
785}
786
787impl<M: RingMem> AsyncRecv for MessageReadHalf<'_, M> {
788 fn poll_recv(
789 &mut self,
790 cx: &mut Context<'_>,
791 bufs: &mut [IoSliceMut<'_>],
792 ) -> Poll<io::Result<usize>> {
793 self.read
794 .reader(self.core)
795 .poll_read_message(cx, bufs)
796 .map_err(Into::into)
797 }
798}
799
800impl<M: RingMem> MessageWriteHalf<'_, M> {
801 pub fn poll_ready(&mut self, cx: &mut Context<'_>, send_size: usize) -> Poll<io::Result<()>> {
807 let send_size = if self.write.raw {
808 send_size
809 } else {
810 send_size + size_of::<ring::PipeHeader>()
811 };
812 self.poll_for_ring_space(cx, ring::PacketSize::in_band(send_size))
813 }
814
815 pub fn poll_empty(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
818 self.poll_for_ring_space(cx, self.core.out_ring().maximum_packet_size())
819 }
820
821 fn poll_for_ring_space(&mut self, cx: &mut Context<'_>, size: usize) -> Poll<io::Result<()>> {
822 loop {
823 std::task::ready!(self.write.state.poll_ready(cx, self.core, size))
824 .map_err(Error::from)?;
825 if self
826 .core
827 .out_ring()
828 .can_write(&mut self.write.state.ptrs, size)
829 .map_err(Error::from)?
830 {
831 break;
832 }
833 self.write.state.clear_ready();
834 }
835 Poll::Ready(Ok(()))
836 }
837
838 pub async fn wait_ready(&mut self, send_size: usize) -> io::Result<()> {
844 poll_fn(|cx| self.poll_ready(cx, send_size)).await
845 }
846
847 pub async fn wait_empty(&mut self) -> io::Result<()> {
850 poll_fn(|cx| self.poll_empty(cx)).await
851 }
852
853 pub fn try_send(&mut self, buf: &[u8]) -> io::Result<()> {
856 self.write
857 .writer(self.core)
858 .try_write_message(&[IoSlice::new(buf)])?;
859 Ok(())
860 }
861
862 pub fn try_send_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<()> {
865 self.write.writer(self.core).try_write_message(bufs)?;
866 Ok(())
867 }
868}
869
870impl<M: RingMem> AsyncSend for MessageWriteHalf<'_, M> {
871 fn poll_send(&mut self, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<io::Result<()>> {
872 ready!(self.write.writer(self.core).poll_write_message(cx, bufs))?;
873
874 Poll::Ready(Ok(()))
875 }
876}
877
878pub fn connected_message_pipes(
880 ring_size: usize,
881) -> (MessagePipe<FlatRingMem>, MessagePipe<FlatRingMem>) {
882 let (host, guest) = connected_async_channels(ring_size);
883 (
884 MessagePipe::new(host).unwrap(),
885 MessagePipe::new(guest).unwrap(),
886 )
887}
888
889pub fn connected_raw_message_pipes(
892 ring_size: usize,
893) -> (MessagePipe<FlatRingMem>, MessagePipe<FlatRingMem>) {
894 let (host, guest) = connected_async_channels(ring_size);
895 (
896 MessagePipe::new_raw(host).unwrap(),
897 MessagePipe::new_raw(guest).unwrap(),
898 )
899}
900
901pub fn connected_byte_pipes(ring_size: usize) -> (BytePipe<FlatRingMem>, BytePipe<FlatRingMem>) {
903 let (host, guest) = connected_async_channels(ring_size);
904 (BytePipe::new(host).unwrap(), BytePipe::new(guest).unwrap())
905}
906
#[cfg(test)]
mod tests {
    use crate::async_dgram::AsyncRecvExt;
    use crate::async_dgram::AsyncSendExt;
    use crate::pipe::connected_byte_pipes;
    use crate::pipe::connected_message_pipes;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::timer::PolledTimer;
    use std::io::ErrorKind;
    use std::time::Duration;
    use zerocopy::IntoBytes;

    /// Dropping one endpoint makes a receive on the other fail with
    /// `ConnectionReset`.
    #[async_test]
    async fn test_async_channel_close() {
        let (mut host, guest) = connected_message_pipes(4096);
        let mut b = [0];
        // Nothing has been sent yet, so the receive cannot complete.
        assert!(futures::poll!(host.recv(&mut b)).is_pending());
        drop(guest);
        assert_eq!(
            host.recv(&mut b).await.unwrap_err().kind(),
            ErrorKind::ConnectionReset
        );
    }

    /// A receive issued before any data arrives stays pending, then completes
    /// with the message once the host sends it.
    #[async_test]
    async fn test_async_read(driver: DefaultDriver) {
        let (mut host, mut guest) = connected_message_pipes(4096);
        let guest_read = async {
            let mut b = [0; 3];
            let mut read = guest.recv(&mut b);
            assert!(futures::poll!(&mut read).is_pending());
            assert_eq!(read.await.unwrap(), 3);
            assert_eq!(&b, b"abc");
        };
        let host_write = async {
            // Delay so the guest's receive is polled before data exists.
            let mut timer = PolledTimer::new(&driver);
            timer.sleep(Duration::from_millis(200)).await;
            host.send(b"abc").await.unwrap();
        };
        futures::future::join(guest_read, host_write).await;
    }

    /// With a 4000-byte message already in the ring, a second 4000-byte send
    /// is expected to stay pending until the host drains the first. Both
    /// messages must then arrive intact and in order.
    #[async_test]
    async fn test_async_write(driver: DefaultDriver) {
        let (mut host, mut guest) = connected_message_pipes(4096);
        let v: Vec<_> = (0..2000_u16).collect();
        guest.send(v.as_bytes()).await.unwrap();
        let guest_write = async {
            let v: Vec<_> = (2000..4000_u16).collect();
            let mut write = guest.send(v.as_bytes());
            assert!(futures::poll!(&mut write).is_pending());
            write.await.unwrap();
        };
        let host_read = async {
            let mut timer = PolledTimer::new(&driver);
            timer.sleep(Duration::from_millis(200)).await;
            let mut v = [0_u16; 2000];
            let n = host.recv(v.as_mut_bytes()).await.unwrap();
            assert_eq!(n, v.as_bytes().len());
            assert!(v.iter().copied().eq(0..2000_u16));
            let n = host.recv(v.as_mut_bytes()).await.unwrap();
            assert_eq!(n, v.as_bytes().len());
            assert!(v.iter().copied().eq(2000..4000_u16));
        };
        futures::future::join(guest_write, host_read).await;
    }

    /// A 20000-byte `write_all` cannot finish until the host reads. The host
    /// must then receive the stream byte-for-byte.
    #[async_test]
    async fn test_byte_pipe(driver: DefaultDriver) {
        let (mut host, mut guest) = connected_byte_pipes(4096);
        let guest_write = async {
            let v: Vec<_> = (0..10000_u16).collect();
            let mut write = guest.write_all(v.as_bytes());
            assert!(futures::poll!(&mut write).is_pending());
            write.await.unwrap();
        };
        let host_read = async {
            let mut timer = PolledTimer::new(&driver);
            timer.sleep(Duration::from_millis(200)).await;
            let mut v = [0_u16; 10000];
            host.read_exact(v.as_mut_bytes()).await.unwrap();
            assert!(v.iter().copied().eq(0..10000_u16));
        };
        futures::future::join(guest_write, host_read).await;
    }
}