vmbus_async/
pipe.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Asynchronous vmbus pipe channels.
5
6use super::core::Core;
7use super::core::ReadState;
8use super::core::WriteState;
9use crate::async_dgram::AsyncRecv;
10use crate::async_dgram::AsyncSend;
11use crate::core::PollError;
12use futures::AsyncRead;
13use futures::AsyncWrite;
14use guestmem::AccessError;
15use guestmem::MemoryRead;
16use guestmem::MemoryWrite;
17use inspect::InspectMut;
18use std::cmp;
19use std::future::poll_fn;
20use std::io;
21use std::io::IoSlice;
22use std::io::IoSliceMut;
23use std::pin::Pin;
24use std::task::Context;
25use std::task::Poll;
26use std::task::ready;
27use thiserror::Error;
28use vmbus_channel::RawAsyncChannel;
29use vmbus_channel::connected_async_channels;
30use vmbus_ring as ring;
31use vmbus_ring::FlatRingMem;
32use vmbus_ring::RingMem;
33use zerocopy::FromZeros;
34use zerocopy::IntoBytes;
35
36#[derive(Debug, Error)]
37enum Error {
38    #[error("the channel has been closed")]
39    ChannelClosed,
40    #[error("packet is too large for the ring")]
41    PacketTooLarge,
42    #[error("unexpected ring packet type")]
43    UnexpectedRingPacketType,
44    #[error("invalid pipe packet type {0:#x}")]
45    InvalidPipePacketType(u32),
46    #[error("ring buffer error")]
47    Ring(#[from] ring::Error),
48    #[error("memory access error")]
49    Access(#[from] AccessError),
50    #[error("partial packet offset is too large")]
51    PartialPacketOffsetTooLarge,
52    #[error(transparent)]
53    Io(#[from] io::Error),
54}
55
56impl From<PollError> for Error {
57    fn from(value: PollError) -> Self {
58        match value {
59            PollError::Ring(err) => Self::Ring(err),
60            PollError::Closed => Self::ChannelClosed,
61        }
62    }
63}
64
65impl From<Error> for io::Error {
66    fn from(err: Error) -> Self {
67        match err {
68            Error::ChannelClosed => {
69                io::Error::new(io::ErrorKind::ConnectionReset, Error::ChannelClosed)
70            }
71            err => io::Error::other(err),
72        }
73    }
74}
75
76#[derive(Debug)]
77enum TryReadError {
78    Empty,
79    Pipe(Error),
80}
81
82impl From<ring::ReadError> for TryReadError {
83    fn from(e: ring::ReadError) -> Self {
84        match e {
85            ring::ReadError::Empty => Self::Empty,
86            ring::ReadError::Corrupt(e) => Self::Pipe(e.into()),
87        }
88    }
89}
90
91impl<T> From<T> for TryReadError
92where
93    Error: From<T>,
94{
95    fn from(e: T) -> Self {
96        Self::Pipe(e.into())
97    }
98}
99
100#[derive(Debug)]
101enum TryWriteError {
102    Full(usize),
103    Pipe(Error),
104}
105
106impl From<ring::WriteError> for TryWriteError {
107    fn from(e: ring::WriteError) -> Self {
108        match e {
109            ring::WriteError::Full(n) => Self::Full(n),
110            ring::WriteError::Corrupt(e) => Self::Pipe(e.into()),
111        }
112    }
113}
114
115impl<T> From<T> for TryWriteError
116where
117    Error: From<T>,
118{
119    fn from(e: T) -> Self {
120        Self::Pipe(e.into())
121    }
122}
123
124impl From<TryWriteError> for io::Error {
125    fn from(e: TryWriteError) -> Self {
126        match e {
127            TryWriteError::Full(_) => {
128                io::Error::new(io::ErrorKind::WouldBlock, "the ring buffer is full")
129            }
130            TryWriteError::Pipe(e) => e.into(),
131        }
132    }
133}
134
135#[derive(Debug)]
136struct PipeWriteState {
137    state: WriteState,
138    raw: bool,
139    max_payload_len: usize,
140}
141
142impl PipeWriteState {
143    fn new(ptrs: ring::OutgoingOffset, raw: bool, max_payload_len: usize) -> Self {
144        Self {
145            state: WriteState::new(ptrs),
146            raw,
147            max_payload_len,
148        }
149    }
150
151    fn writer<'a, M: RingMem>(&'a mut self, core: &'a Core<M>) -> PipeWriter<'a, M> {
152        PipeWriter { write: self, core }
153    }
154}
155
156struct PipeWriter<'a, M: RingMem> {
157    write: &'a mut PipeWriteState,
158    core: &'a Core<M>,
159}
160
161impl<M: RingMem> PipeWriter<'_, M> {
162    /// Tries to write a full message as a ring packet, returning
163    /// Err(TryWriteError::Full(_)) if the ring is full.
164    fn try_write_message(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize, TryWriteError> {
165        let len = bufs.iter().map(|x| x.len()).sum();
166        let mut packet_len = len;
167        if len > self.write.max_payload_len {
168            return Err(TryWriteError::Pipe(Error::PacketTooLarge));
169        }
170        if !self.write.raw {
171            packet_len += size_of::<ring::PipeHeader>();
172        }
173        let mut outgoing = self.write.state.ptrs.clone();
174        let range = self.core.out_ring().write(
175            &mut outgoing,
176            &ring::OutgoingPacket {
177                transaction_id: 0,
178                size: packet_len,
179                typ: ring::OutgoingPacketType::InBandNoCompletion,
180            },
181        )?;
182        let mut writer = range.writer(self.core.out_ring());
183        if !self.write.raw {
184            writer.write(
185                ring::PipeHeader {
186                    packet_type: ring::PIPE_PACKET_TYPE_DATA,
187                    len: len as u32,
188                }
189                .as_bytes(),
190            )?;
191        }
192        for buf in bufs {
193            writer.write(buf)?;
194        }
195        self.write.state.clear_poll(self.core);
196        if self.core.out_ring().commit_write(&mut outgoing) {
197            self.core.signal();
198            self.write.state.signals.increment();
199        }
200        self.write.state.ptrs = outgoing;
201        Ok(len)
202    }
203
204    /// Tries to write `buf` into the ring as a series of packets, possibly
205    /// sending only a portion of the bytes. Returns `Ok(None)` if the ring is
206    /// full.
207    fn try_write_bytes(&mut self, buf: &[u8]) -> Result<usize, TryWriteError> {
208        if buf.is_empty() {
209            return Ok(0);
210        }
211
212        const CHUNK_SIZE: usize = 2048;
213        // Write in packets of CHUNK_SIZE bytes so that the opposite endpoint can remove
214        // packets as it reads data, freeing up more space for writes.
215        let mut written = 0;
216        let mut outgoing = self.write.state.ptrs.clone();
217        for buf in buf.chunks(CHUNK_SIZE) {
218            match self.core.out_ring().write(
219                &mut outgoing,
220                &ring::OutgoingPacket {
221                    transaction_id: 0,
222                    size: buf.len() + size_of::<ring::PipeHeader>(),
223                    typ: ring::OutgoingPacketType::InBandNoCompletion,
224                },
225            ) {
226                Ok(range) => {
227                    let mut writer = range.writer(self.core.out_ring());
228                    writer.write(
229                        ring::PipeHeader {
230                            packet_type: ring::PIPE_PACKET_TYPE_DATA,
231                            len: buf.len() as u32,
232                        }
233                        .as_bytes(),
234                    )?;
235                    writer.write(buf)?;
236                    written += buf.len();
237                }
238                Err(ring::WriteError::Full(n)) => {
239                    if written > 0 {
240                        break;
241                    } else {
242                        return Err(TryWriteError::Full(n));
243                    }
244                }
245                Err(ring::WriteError::Corrupt(err)) => return Err(TryWriteError::Pipe(err.into())),
246            }
247        }
248        assert!(written > 0);
249        if self.core.out_ring().commit_write(&mut outgoing) {
250            self.core.signal();
251            self.write.state.signals.increment();
252        }
253        self.write.state.ptrs = outgoing;
254        Ok(written)
255    }
256
257    /// Notifies the opposite endpoint that no more data will be written
258    /// (similar to TCP's FIN). Requires ring buffer space, so may fail if this
259    /// would block.
260    fn try_shutdown_writes(&mut self) -> Result<(), TryWriteError> {
261        if !self.write.raw {
262            // Write a zero-byte message. Ignore ChannelClosed since the operation
263            // has already succeeded in some sense--the opposite endpoint has
264            // stopped reading data.
265            match self.try_write_message(&[]) {
266                Ok(_) => {}
267                Err(err) => return Err(err),
268            }
269        }
270        Ok(())
271    }
272
273    fn poll_op<F, R>(&mut self, cx: &mut Context<'_>, mut f: F) -> Poll<Result<R, Error>>
274    where
275        F: FnMut(&mut Self) -> Result<R, TryWriteError>,
276    {
277        // Estimate the required send size. Update it later if the send actually fails.
278        let mut send_size = 32;
279        loop {
280            std::task::ready!(self.write.state.poll_ready(cx, self.core, send_size))?;
281            match f(self) {
282                Ok(r) => break Poll::Ready(Ok(r)),
283                Err(TryWriteError::Full(len)) => {
284                    send_size = len;
285                    self.write.state.clear_ready();
286                }
287                Err(TryWriteError::Pipe(e)) => break Poll::Ready(Err(e)),
288            }
289        }
290    }
291
292    fn poll_write_bytes(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
293        self.poll_op(cx, |this| this.try_write_bytes(buf))
294    }
295
296    fn poll_write_message(
297        &mut self,
298        cx: &mut Context<'_>,
299        bufs: &[IoSlice<'_>],
300    ) -> Poll<Result<usize, Error>> {
301        self.poll_op(cx, |this| this.try_write_message(bufs))
302    }
303
304    /// Notifies the opposite endpoint that no more data will be written
305    /// (similar to TCP's FIN). Requires ring buffer space, so may fail if this
306    /// would block.
307    fn poll_shutdown_writes(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
308        match self.poll_op(cx, |this| this.try_shutdown_writes()) {
309            Poll::Ready(Err(Error::ChannelClosed)) => {
310                // Treat a closed pipe as a successful shutdown.
311                Poll::Ready(Ok(()))
312            }
313            r => r,
314        }
315    }
316}
317
318#[derive(Debug)]
319struct PipeReadState {
320    read: ReadState,
321    max_payload_len: usize,
322    raw: bool,
323    eof: bool,
324}
325
326impl PipeReadState {
327    fn new(ptrs: ring::IncomingOffset, raw: bool, max_payload_len: usize) -> Self {
328        Self {
329            read: ReadState::new(ptrs),
330            raw,
331            max_payload_len,
332            eof: false,
333        }
334    }
335
336    fn reader<'a, M: RingMem>(&'a mut self, core: &'a Core<M>) -> PipeReader<'a, M> {
337        PipeReader { state: self, core }
338    }
339}
340
341struct PipeReader<'a, M: RingMem> {
342    state: &'a mut PipeReadState,
343    core: &'a Core<M>,
344}
345
346impl<M: RingMem> PipeReader<'_, M> {
347    /// Tries to read the full contents of a single message packet into `bufs`,
348    /// returning `Err(TryReadError::Empty)` if the ring is empty.
349    fn try_read_message(&mut self, bufs: &mut [IoSliceMut<'_>]) -> Result<usize, TryReadError> {
350        let len = bufs.iter().map(|x| x.len()).sum();
351        let mut incoming = self.state.read.ptrs.clone();
352        match self.core.in_ring().read(&mut incoming) {
353            Ok(ring::IncomingPacket {
354                typ: ring::IncomingPacketType::InBand,
355                payload,
356                ..
357            }) => {
358                let mut reader = payload.reader(self.core.in_ring());
359                let bytes_read = if !self.state.raw {
360                    let mut header = ring::PipeHeader::new_zeroed();
361                    reader.read(header.as_mut_bytes())?;
362                    if header.packet_type != ring::PIPE_PACKET_TYPE_DATA {
363                        return Err(TryReadError::Pipe(Error::InvalidPipePacketType(
364                            header.packet_type,
365                        )));
366                    }
367                    header.len as usize // validated by call to payload.reader.read below
368                } else {
369                    payload.len()
370                };
371                if bytes_read > cmp::min(len, self.state.max_payload_len) {
372                    return Err(TryReadError::Pipe(Error::PacketTooLarge));
373                }
374                let mut remaining = bytes_read;
375                for buf in bufs {
376                    if remaining == 0 {
377                        break;
378                    }
379                    let this_len = cmp::min(remaining, buf.len());
380                    remaining -= this_len;
381                    reader.read(&mut buf[..this_len])?;
382                }
383                self.state.read.clear_poll(self.core);
384                if self.core.in_ring().commit_read(&mut incoming) {
385                    self.core.signal();
386                    self.state.read.signals.increment();
387                }
388                self.state.read.ptrs = incoming;
389                Ok(bytes_read)
390            }
391            Ok(_) => Err(TryReadError::Pipe(Error::UnexpectedRingPacketType)),
392            Err(err) => Err(err.into()),
393        }
394    }
395
396    /// Tries to fill `buf` with bytes from the channel, consuming partial or
397    /// full packets. Returns `Err(TryReadError::Empty)` if the ring is empty.
398    fn try_read_bytes(&mut self, buf: &mut [u8]) -> Result<usize, TryReadError> {
399        if buf.is_empty() || self.state.eof {
400            return Ok(0);
401        }
402        let mut incoming = self.state.read.ptrs.clone();
403        let mut commit = incoming.clone();
404        let mut total_read = 0;
405        while total_read < buf.len() {
406            match self.core.in_ring().read(&mut incoming) {
407                Ok(ring::IncomingPacket {
408                    typ: ring::IncomingPacketType::InBand,
409                    payload,
410                    ..
411                }) => {
412                    let mut reader = payload.reader(self.core.in_ring());
413                    let mut header = ring::PipeHeader::new_zeroed();
414                    reader.read(header.as_mut_bytes())?;
415                    let (off, len) = match header.packet_type {
416                        ring::PIPE_PACKET_TYPE_DATA => {
417                            // A zero-byte packet indicates EOF--the opposite
418                            // endpoint will not write any more data. Consume
419                            // the packet only if no other data is being
420                            // returned so that if the channel is saved after
421                            // seeing the zero-byte packet but before a
422                            // zero-byte read is returned, the EOF signal is not
423                            // lost.
424                            //
425                            // Another solution to this would be to leave the
426                            // EOF packet in the ring, but jstarks told the
427                            // Linux kernel devs that they could wait for this
428                            // packet to be consumed to know whether it is safe
429                            // to tear down the ring buffer on the guest side.
430                            if header.len == 0 {
431                                if total_read == 0 {
432                                    self.state.eof = true;
433                                    commit = incoming.clone();
434                                }
435                                break;
436                            }
437                            (0, header.len as usize)
438                        }
439                        ring::PIPE_PACKET_TYPE_PARTIAL => {
440                            // The read offset is stored in the high 16 bits.
441                            // There should be at least one byte remaining;
442                            // otherwise, the packet would have been removed.
443                            let off = header.len >> 16;
444                            let len = header.len & 0xffff;
445                            if off >= len {
446                                return Err(TryReadError::Pipe(Error::PartialPacketOffsetTooLarge));
447                            }
448                            (off as usize, (len - off) as usize)
449                        }
450                        n => return Err(TryReadError::Pipe(Error::InvalidPipePacketType(n))),
451                    };
452                    reader.skip(off)?;
453                    let read = cmp::min(len, buf.len() - total_read);
454                    reader.read(&mut buf[total_read..total_read + read])?;
455                    if read < len {
456                        // Update the ring with the partial packet information.
457                        header.packet_type = ring::PIPE_PACKET_TYPE_PARTIAL;
458                        header.len += (read as u32) << 16;
459                        let mut writer = payload.writer(self.core.in_ring());
460                        writer.write(header.as_bytes())?;
461                    } else {
462                        // The whole packet has been consumed.
463                        commit = incoming.clone();
464                    }
465                    total_read += read;
466                }
467                Ok(_) => return Err(TryReadError::Pipe(Error::UnexpectedRingPacketType)),
468                Err(ring::ReadError::Empty) => break,
469                Err(ring::ReadError::Corrupt(err)) => return Err(err.into()),
470            }
471        }
472        if total_read > 0 || self.state.eof {
473            self.state.read.clear_poll(self.core);
474            if self.core.in_ring().commit_read(&mut commit) {
475                self.core.signal();
476                self.state.read.signals.increment();
477            }
478            self.state.read.ptrs = commit;
479            Ok(total_read)
480        } else {
481            // Need to block to get more data.
482            Err(TryReadError::Empty)
483        }
484    }
485
486    fn poll_op<F, R>(&mut self, cx: &mut Context<'_>, mut f: F) -> Poll<Result<R, Error>>
487    where
488        F: FnMut(&mut Self) -> Result<R, TryReadError>,
489    {
490        loop {
491            std::task::ready!(self.state.read.poll_ready(cx, self.core))?;
492            match f(self) {
493                Ok(r) => break Poll::Ready(Ok(r)),
494                Err(TryReadError::Empty) => self.state.read.clear_ready(),
495                Err(TryReadError::Pipe(err)) => break Poll::Ready(Err(err)),
496            }
497        }
498    }
499    fn poll_read_bytes(
500        &mut self,
501        cx: &mut Context<'_>,
502        buf: &mut [u8],
503    ) -> Poll<Result<usize, Error>> {
504        self.poll_op(cx, |this| this.try_read_bytes(buf))
505    }
506
507    fn poll_read_message(
508        &mut self,
509        cx: &mut Context<'_>,
510        bufs: &mut [IoSliceMut<'_>],
511    ) -> Poll<Result<usize, Error>> {
512        self.poll_op(cx, |this| this.try_read_message(bufs))
513    }
514}
515
516/// An open vmbus pipe in message mode, which can send and receive datagrams.
517pub struct MessagePipe<M: RingMem>(Pipe<M>);
518
519impl<M: RingMem> InspectMut for MessagePipe<M> {
520    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
521        self.0.inspect_mut(req)
522    }
523}
524
525/// An open vmbus pipe in byte mode, which can be read from and written to as a
526/// byte stream.
527pub struct BytePipe<M: RingMem>(Pipe<M>);
528
529impl<M: RingMem> InspectMut for BytePipe<M> {
530    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
531        self.0.inspect_mut(req)
532    }
533}
534
535/// An open pipe.
536struct Pipe<M: RingMem> {
537    core: Core<M>,
538    read: PipeReadState,
539    write: PipeWriteState,
540}
541
542impl<M: RingMem> InspectMut for Pipe<M> {
543    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
544        req.respond()
545            .merge(&self.core)
546            .field("incoming_ring", &self.read.read)
547            .field("outgoing_ring", &self.write.state);
548    }
549}
550
551/// The read half of a pipe.
552pub struct MessageReadHalf<'a, M: RingMem> {
553    core: &'a Core<M>,
554    read: &'a mut PipeReadState,
555}
556
557/// The write half of a pipe.
558pub struct MessageWriteHalf<'a, M: RingMem> {
559    core: &'a Core<M>,
560    write: &'a mut PipeWriteState,
561}
562
563/// The read half of a pipe.
564pub struct ByteReadHalf<'a, M: RingMem> {
565    core: &'a Core<M>,
566    read: &'a mut PipeReadState,
567}
568
569/// The write half of a pipe.
570pub struct ByteWriteHalf<'a, M: RingMem> {
571    core: &'a Core<M>,
572    write: &'a mut PipeWriteState,
573}
574
575impl<M: RingMem> MessagePipe<M> {
576    /// Creates a new pipe from an open channel.
577    pub fn new(channel: RawAsyncChannel<M>) -> io::Result<Self> {
578        Self::new_inner(channel, false)
579    }
580
581    /// Creates a new raw pipe from an open channel.
582    ///
583    /// A raw pipe has no additional framing and sends and receives vmbus
584    /// packets directly. As a result, packet sizes will be rounded up to an
585    /// 8-byte multiple.
586    pub fn new_raw(channel: RawAsyncChannel<M>) -> io::Result<Self> {
587        Self::new_inner(channel, true)
588    }
589
590    fn new_inner(channel: RawAsyncChannel<M>, raw: bool) -> io::Result<Self> {
591        let max_payload_len = if raw {
592            // There is no inherent maximum packet size for non-pipe rings.
593            // Fall back to the ring size.
594            channel.out_ring.maximum_packet_size() - ring::PacketSize::in_band(0)
595        } else {
596            // There is a protocol-specified maximum size.
597            cmp::min(
598                ring::MAXIMUM_PIPE_PACKET_SIZE,
599                channel.out_ring.maximum_packet_size()
600                    - ring::PacketSize::in_band(size_of::<ring::PipeHeader>()),
601            )
602        };
603
604        let incoming = channel.in_ring.incoming().map_err(Error::Ring)?;
605        let outgoing = channel.out_ring.outgoing().map_err(Error::Ring)?;
606
607        Ok(Self(Pipe {
608            core: Core::new(channel),
609            read: PipeReadState::new(incoming, raw, max_payload_len),
610            write: PipeWriteState::new(outgoing, raw, max_payload_len),
611        }))
612    }
613
614    /// Splits the pipe into read and write halves so that reads and writes may
615    /// be concurrently issued.
616    pub fn split(&mut self) -> (MessageReadHalf<'_, M>, MessageWriteHalf<'_, M>) {
617        (
618            MessageReadHalf {
619                core: &self.0.core,
620                read: &mut self.0.read,
621            },
622            MessageWriteHalf {
623                core: &self.0.core,
624                write: &mut self.0.write,
625            },
626        )
627    }
628
629    /// Waits for the outgoing ring buffer to have enough space to write a
630    /// message of size `send_size`.
631    pub async fn wait_write_ready(&mut self, send_size: usize) -> io::Result<()> {
632        self.split().1.wait_ready(send_size).await
633    }
634
635    /// Tries to send a datagram, failing with [`io::ErrorKind::WouldBlock`] if
636    /// there is not enough space in the ring.
637    pub fn try_send(&mut self, buf: &[u8]) -> io::Result<()> {
638        self.split().1.try_send(buf)
639    }
640
641    /// Tries to send a datagram, failing with [`io::ErrorKind::WouldBlock`] if
642    /// there is not enough space in the ring.
643    pub fn try_send_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<()> {
644        self.split().1.try_send_vectored(bufs)
645    }
646}
647
648impl<M: RingMem> BytePipe<M> {
649    /// Creates a new pipe from an open channel.
650    pub fn new(channel: RawAsyncChannel<M>) -> io::Result<Self> {
651        let incoming = channel.in_ring.incoming().map_err(Error::Ring)?;
652        let outgoing = channel.out_ring.outgoing().map_err(Error::Ring)?;
653
654        Ok(Self(Pipe {
655            core: Core::new(channel),
656            read: PipeReadState::new(incoming, false, 0),
657            write: PipeWriteState::new(outgoing, false, 0),
658        }))
659    }
660
661    /// Splits the pipe into read and write halves so that reads and writes may
662    /// be concurrently issued.
663    pub fn split(&mut self) -> (ByteReadHalf<'_, M>, ByteWriteHalf<'_, M>) {
664        (
665            ByteReadHalf {
666                core: &self.0.core,
667                read: &mut self.0.read,
668            },
669            ByteWriteHalf {
670                core: &self.0.core,
671                write: &mut self.0.write,
672            },
673        )
674    }
675}
676
677impl<M: RingMem + Unpin> AsyncRead for BytePipe<M> {
678    fn poll_read(
679        self: Pin<&mut Self>,
680        cx: &mut Context<'_>,
681        buf: &mut [u8],
682    ) -> Poll<io::Result<usize>> {
683        let this = self.get_mut();
684        this.0
685            .read
686            .reader(&this.0.core)
687            .poll_read_bytes(cx, buf)
688            .map_err(Into::into)
689    }
690}
691
692impl<M: RingMem + Unpin> AsyncRead for ByteReadHalf<'_, M> {
693    fn poll_read(
694        self: Pin<&mut Self>,
695        cx: &mut Context<'_>,
696        buf: &mut [u8],
697    ) -> Poll<io::Result<usize>> {
698        let this = self.get_mut();
699        this.read
700            .reader(this.core)
701            .poll_read_bytes(cx, buf)
702            .map_err(Into::into)
703    }
704}
705
706impl<M: RingMem + Unpin> AsyncWrite for BytePipe<M> {
707    fn poll_write(
708        self: Pin<&mut Self>,
709        cx: &mut Context<'_>,
710        buf: &[u8],
711    ) -> Poll<io::Result<usize>> {
712        let this = self.get_mut();
713        this.0
714            .write
715            .writer(&this.0.core)
716            .poll_write_bytes(cx, buf)
717            .map_err(Into::into)
718    }
719
720    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
721        Poll::Ready(Ok(()))
722    }
723
724    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
725        let this = self.get_mut();
726        this.0
727            .write
728            .writer(&this.0.core)
729            .poll_shutdown_writes(cx)
730            .map_err(Into::into)
731    }
732}
733
734impl<M: RingMem + Unpin> AsyncWrite for ByteWriteHalf<'_, M> {
735    fn poll_write(
736        self: Pin<&mut Self>,
737        cx: &mut Context<'_>,
738        buf: &[u8],
739    ) -> Poll<io::Result<usize>> {
740        let this = self.get_mut();
741        this.write
742            .writer(this.core)
743            .poll_write_bytes(cx, buf)
744            .map_err(Into::into)
745    }
746
747    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
748        Poll::Ready(Ok(()))
749    }
750
751    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
752        let this = self.get_mut();
753        this.write
754            .writer(this.core)
755            .poll_shutdown_writes(cx)
756            .map_err(Into::into)
757    }
758}
759
760impl<M: RingMem> AsyncRecv for MessagePipe<M> {
761    fn poll_recv(
762        &mut self,
763        cx: &mut Context<'_>,
764        bufs: &mut [IoSliceMut<'_>],
765    ) -> Poll<io::Result<usize>> {
766        self.0
767            .read
768            .reader(&self.0.core)
769            .poll_read_message(cx, bufs)
770            .map_err(Into::into)
771    }
772}
773
774impl<M: RingMem> AsyncSend for MessagePipe<M> {
775    fn poll_send(&mut self, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<io::Result<()>> {
776        ready!(
777            self.0
778                .write
779                .writer(&self.0.core)
780                .poll_write_message(cx, bufs)
781        )?;
782
783        Poll::Ready(Ok(()))
784    }
785}
786
787impl<M: RingMem> AsyncRecv for MessageReadHalf<'_, M> {
788    fn poll_recv(
789        &mut self,
790        cx: &mut Context<'_>,
791        bufs: &mut [IoSliceMut<'_>],
792    ) -> Poll<io::Result<usize>> {
793        self.read
794            .reader(self.core)
795            .poll_read_message(cx, bufs)
796            .map_err(Into::into)
797    }
798}
799
800impl<M: RingMem> MessageWriteHalf<'_, M> {
801    /// Polls the outgoing ring for the ability to send a packet of size
802    /// `send_size`.
803    ///
804    /// `send_size` can be computed by calling `try_write` and extracting the
805    /// size from `TryReadError::Full(send_size)`.
806    pub fn poll_ready(&mut self, cx: &mut Context<'_>, send_size: usize) -> Poll<io::Result<()>> {
807        let send_size = if self.write.raw {
808            send_size
809        } else {
810            send_size + size_of::<ring::PipeHeader>()
811        };
812        self.poll_for_ring_space(cx, ring::PacketSize::in_band(send_size))
813    }
814
815    /// Polls the outgoing ring for being completely empty, indicating that the
816    /// other endpoint has read everything.
817    pub fn poll_empty(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
818        self.poll_for_ring_space(cx, self.core.out_ring().maximum_packet_size())
819    }
820
821    fn poll_for_ring_space(&mut self, cx: &mut Context<'_>, size: usize) -> Poll<io::Result<()>> {
822        loop {
823            std::task::ready!(self.write.state.poll_ready(cx, self.core, size))
824                .map_err(Error::from)?;
825            if self
826                .core
827                .out_ring()
828                .can_write(&mut self.write.state.ptrs, size)
829                .map_err(Error::from)?
830            {
831                break;
832            }
833            self.write.state.clear_ready();
834        }
835        Poll::Ready(Ok(()))
836    }
837
838    /// Waits until there is enough space in the ring to send a packet of size
839    /// `send_size`.
840    ///
841    /// `send_size` can be computed by calling `try_write` and extracting the
842    /// size from `TryReadError::Full(send_size)`.
843    pub async fn wait_ready(&mut self, send_size: usize) -> io::Result<()> {
844        poll_fn(|cx| self.poll_ready(cx, send_size)).await
845    }
846
847    /// Waits until the ring is completely empty, indicating that the other
848    /// endpoint has read everything.
849    pub async fn wait_empty(&mut self) -> io::Result<()> {
850        poll_fn(|cx| self.poll_empty(cx)).await
851    }
852
853    /// Tries to send a datagram, failing with [`io::ErrorKind::WouldBlock`] if
854    /// there is not enough space in the ring.
855    pub fn try_send(&mut self, buf: &[u8]) -> io::Result<()> {
856        self.write
857            .writer(self.core)
858            .try_write_message(&[IoSlice::new(buf)])?;
859        Ok(())
860    }
861
862    /// Tries to send a datagram, failing with [`io::ErrorKind::WouldBlock`] if
863    /// there is not enough space in the ring.
864    pub fn try_send_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<()> {
865        self.write.writer(self.core).try_write_message(bufs)?;
866        Ok(())
867    }
868}
869
870impl<M: RingMem> AsyncSend for MessageWriteHalf<'_, M> {
871    fn poll_send(&mut self, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<io::Result<()>> {
872        ready!(self.write.writer(self.core).poll_write_message(cx, bufs))?;
873
874        Poll::Ready(Ok(()))
875    }
876}
877
878/// Creates a pair of async connected message pipes. Useful for testing.
879pub fn connected_message_pipes(
880    ring_size: usize,
881) -> (MessagePipe<FlatRingMem>, MessagePipe<FlatRingMem>) {
882    let (host, guest) = connected_async_channels(ring_size);
883    (
884        MessagePipe::new(host).unwrap(),
885        MessagePipe::new(guest).unwrap(),
886    )
887}
888
889/// Creates a pair of async connected pipes in raw mode (with no vmbus pipe
890/// framing on the packets). Useful for testing.
891pub fn connected_raw_message_pipes(
892    ring_size: usize,
893) -> (MessagePipe<FlatRingMem>, MessagePipe<FlatRingMem>) {
894    let (host, guest) = connected_async_channels(ring_size);
895    (
896        MessagePipe::new_raw(host).unwrap(),
897        MessagePipe::new_raw(guest).unwrap(),
898    )
899}
900
901/// Creates a pair of async connected byte pipes. Useful for testing.
902pub fn connected_byte_pipes(ring_size: usize) -> (BytePipe<FlatRingMem>, BytePipe<FlatRingMem>) {
903    let (host, guest) = connected_async_channels(ring_size);
904    (BytePipe::new(host).unwrap(), BytePipe::new(guest).unwrap())
905}
906
#[cfg(test)]
mod tests {
    use crate::async_dgram::AsyncRecvExt;
    use crate::async_dgram::AsyncSendExt;
    use crate::pipe::connected_byte_pipes;
    use crate::pipe::connected_message_pipes;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::timer::PolledTimer;
    use std::io::ErrorKind;
    use std::time::Duration;
    use zerocopy::IntoBytes;

    /// Once the peer end of a message pipe is dropped, `recv` on the
    /// remaining end fails with `ErrorKind::ConnectionReset`.
    #[async_test]
    async fn test_async_channel_close() {
        let (mut host, guest) = connected_message_pipes(4096);
        let mut b = [0];
        // Nothing has been sent yet, so a first poll of `recv` must not complete.
        assert!(futures::poll!(host.recv(&mut b)).is_pending());
        drop(guest);
        assert_eq!(
            host.recv(&mut b).await.unwrap_err().kind(),
            ErrorKind::ConnectionReset
        );
    }

    /// A `recv` started before any data is available stays pending, then
    /// completes with the full message once the host sends it after a delay.
    #[async_test]
    async fn test_async_read(driver: DefaultDriver) {
        let (mut host, mut guest) = connected_message_pipes(4096);
        let guest_read = async {
            let mut b = [0; 3];
            let mut read = guest.recv(&mut b);
            // The host has not sent yet, so the read must be pending here.
            assert!(futures::poll!(&mut read).is_pending());
            assert_eq!(read.await.unwrap(), 3);
            assert_eq!(&b, b"abc");
        };
        let host_write = async {
            // Delay the send so the guest's read is observed pending first.
            let mut timer = PolledTimer::new(&driver);
            timer.sleep(Duration::from_millis(200)).await;
            host.send(b"abc").await.unwrap();
        };
        futures::future::join(guest_read, host_write).await;
    }

    /// With one 4000-byte message (2000 `u16`s) already queued, a second
    /// message of the same size is expected not to fit with a 4096 ring size.
    /// Its send must stay pending until the host drains the first message.
    /// Both messages must then arrive intact and in order.
    #[async_test]
    async fn test_async_write(driver: DefaultDriver) {
        let (mut host, mut guest) = connected_message_pipes(4096);
        let v: Vec<_> = (0..2000_u16).collect();
        guest.send(v.as_bytes()).await.unwrap();
        let guest_write = async {
            let v: Vec<_> = (2000..4000_u16).collect();
            let mut write = guest.send(v.as_bytes());
            // The first message still occupies the ring, so this cannot complete yet.
            assert!(futures::poll!(&mut write).is_pending());
            write.await.unwrap();
        };
        let host_read = async {
            // Delay reading so the guest's second send is observed pending first.
            let mut timer = PolledTimer::new(&driver);
            timer.sleep(Duration::from_millis(200)).await;
            let mut v = [0_u16; 2000];
            let n = host.recv(v.as_mut_bytes()).await.unwrap();
            assert_eq!(n, v.as_bytes().len());
            assert!(v.iter().copied().eq(0..2000_u16));
            let n = host.recv(v.as_mut_bytes()).await.unwrap();
            assert_eq!(n, v.as_bytes().len());
            assert!(v.iter().copied().eq(2000..4000_u16));
        };
        futures::future::join(guest_write, host_read).await;
    }

    /// Streams 20000 bytes (10000 `u16`s) through a byte pipe created with ring
    /// size 4096. `write_all` is expected to be pending until the host begins
    /// reading after a delay. The host must then read back the exact byte
    /// stream.
    #[async_test]
    async fn test_byte_pipe(driver: DefaultDriver) {
        let (mut host, mut guest) = connected_byte_pipes(4096);
        let guest_write = async {
            let v: Vec<_> = (0..10000_u16).collect();
            let mut write = guest.write_all(v.as_bytes());
            // Nothing is reading yet, so the full write cannot complete.
            assert!(futures::poll!(&mut write).is_pending());
            write.await.unwrap();
        };
        let host_read = async {
            // Delay reading so the guest's write is observed pending first.
            let mut timer = PolledTimer::new(&driver);
            timer.sleep(Duration::from_millis(200)).await;
            let mut v = [0_u16; 10000];
            host.read_exact(v.as_mut_bytes()).await.unwrap();
            assert!(v.iter().copied().eq(0..10000_u16));
        };
        futures::future::join(guest_write, host_read).await;
    }
}