vmbus_async/
queue.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! This module implements the `Queue` type, which provides an abstraction over
5//! a VmBus channel.
6
7use super::core::Core;
8use super::core::ReadState;
9use super::core::WriteState;
10use crate::core::PollError;
11use futures::FutureExt;
12use guestmem::AccessError;
13use guestmem::MemoryRead;
14use guestmem::MemoryWrite;
15use guestmem::ranges::PagedRange;
16use inspect::Inspect;
17use ring::OutgoingPacketType;
18use ring::TransferPageRange;
19use smallvec::smallvec;
20use std::future::Future;
21use std::future::poll_fn;
22use std::ops::Deref;
23use std::task::Context;
24use std::task::Poll;
25use std::task::ready;
26use thiserror::Error;
27use vmbus_channel::RawAsyncChannel;
28use vmbus_channel::connected_async_channels;
29use vmbus_ring as ring;
30use vmbus_ring::FlatRingMem;
31use vmbus_ring::IncomingPacketType;
32use vmbus_ring::IncomingRing;
33use vmbus_ring::RingMem;
34use vmbus_ring::gparange::GpnList;
35use vmbus_ring::gparange::MultiPagedRangeBuf;
36use vmbus_ring::gparange::zeroed_gpn_list;
37use zerocopy::FromBytes;
38use zerocopy::FromZeros;
39use zerocopy::IntoBytes;
40
41/// A queue error.
42#[derive(Debug, Error)]
43#[error(transparent)]
44pub struct Error(Box<ErrorInner>);
45
46impl From<ErrorInner> for Error {
47    fn from(value: ErrorInner) -> Self {
48        Self(Box::new(value))
49    }
50}
51
52impl Error {
53    /// Returns true if the error is due to the channel being closed by the
54    /// remote endpoint.
55    pub fn is_closed_error(&self) -> bool {
56        matches!(self.0.as_ref(), ErrorInner::ChannelClosed)
57    }
58}
59
60#[derive(Debug, Error)]
61enum ErrorInner {
62    /// Failed to access guest memory.
63    #[error("guest memory access error")]
64    Access(#[source] AccessError),
65    /// The ring buffer is corrupted.
66    #[error("ring buffer error")]
67    Ring(#[source] ring::Error),
68    /// The channel has been closed.
69    #[error("the channel has been closed")]
70    ChannelClosed,
71}
72
73impl From<PollError> for ErrorInner {
74    fn from(value: PollError) -> Self {
75        match value {
76            PollError::Ring(ring) => Self::Ring(ring),
77            PollError::Closed => Self::ChannelClosed,
78        }
79    }
80}
81
/// An error returned by `try_read*` methods (e.g. [`ReadHalf::try_read`] and
/// [`ReadHalf::try_read_batch`]).
#[derive(Debug, Error)]
pub enum TryReadError {
    /// The ring is empty: no packet is currently available to read.
    #[error("ring is empty")]
    Empty,
    /// Underlying queue error, such as a corrupted incoming ring.
    #[error("queue error")]
    Queue(#[source] Error),
}
92
93/// An error returned by `try_write*` methods.
94#[derive(Debug, Error)]
95pub enum TryWriteError {
96    /// The ring is empty.
97    #[error("ring is empty")]
98    Full(usize),
99    /// Underlying queue error.
100    #[error("queue error")]
101    Queue(#[source] Error),
102}
103
/// An error returned by [`DataPacket::read_external_ranges`].
#[derive(Debug, Error)]
pub enum ExternalDataError {
    /// The packet is corrupted in some way: its range descriptors do not form a
    /// valid set of GPA ranges (rejected by `MultiPagedRangeBuf::new`).
    #[error("invalid gpa ranges")]
    GpaRange(#[source] vmbus_ring::gparange::Error),

    /// Reading the packet's range descriptors out of the incoming ring failed.
    #[error("access error")]
    Access(#[source] AccessError),

    /// Caller used `read_external_ranges` when the packet contains a buffer id
    /// (a transfer page packet); the caller should have called
    /// `read_transfer_ranges` instead.
    #[error("external data should have been read by calling read_transfer_ranges")]
    WrongExternalDataType,
}
120
121/// An incoming packet batch reader.
122pub struct ReadBatch<'a, M: RingMem> {
123    core: &'a Core<M>,
124    read: &'a mut ReadState,
125}
126
127/// The packet iterator for [`ReadBatch`].
128pub struct ReadBatchIter<'a, 'b, M: RingMem>(&'a mut ReadBatch<'b, M>);
129
130impl<'a, M: RingMem> ReadBatch<'a, M> {
131    fn next_priv(&mut self) -> Result<Option<IncomingPacket<'a, M>>, Error> {
132        let mut ptrs = self.read.ptrs.clone();
133        match self.core.in_ring().read(&mut ptrs) {
134            Ok(packet) => {
135                let packet = IncomingPacket::parse(self.core.in_ring(), packet)?;
136                self.read.ptrs = ptrs;
137                Ok(Some(packet))
138            }
139            Err(ring::ReadError::Empty) => Ok(None),
140            Err(ring::ReadError::Corrupt(err)) => Err(ErrorInner::Ring(err).into()),
141        }
142    }
143
144    fn single_packet(mut self) -> Result<Option<PacketRef<'a, M>>, Error> {
145        if let Some(packet) = self.next_priv()? {
146            Ok(Some(PacketRef {
147                batch: self,
148                packet,
149            }))
150        } else {
151            Ok(None)
152        }
153    }
154
155    /// Returns an iterator of the packets.
156    pub fn packets(&mut self) -> ReadBatchIter<'_, 'a, M> {
157        ReadBatchIter(self)
158    }
159}
160
161impl<'a, M: RingMem> Iterator for ReadBatchIter<'a, '_, M> {
162    type Item = Result<IncomingPacket<'a, M>, Error>;
163
164    fn next(&mut self) -> Option<Self::Item> {
165        self.0.next_priv().transpose()
166    }
167}
168
169impl<M: RingMem> Drop for ReadBatch<'_, M> {
170    fn drop(&mut self) {
171        self.read.clear_poll(self.core);
172        if self.core.in_ring().commit_read(&mut self.read.ptrs) {
173            self.core.signal();
174            self.read.signals.increment();
175        }
176    }
177}
178
179/// A reference to a single packet that has not been read out of the ring yet.
180pub struct PacketRef<'a, M: RingMem> {
181    batch: ReadBatch<'a, M>,
182    packet: IncomingPacket<'a, M>,
183}
184
185impl<'a, M: RingMem> Deref for PacketRef<'a, M> {
186    type Target = IncomingPacket<'a, M>;
187
188    fn deref(&self) -> &Self::Target {
189        &self.packet
190    }
191}
192
193impl<'a, M: RingMem> AsRef<IncomingPacket<'a, M>> for PacketRef<'a, M> {
194    fn as_ref(&self) -> &IncomingPacket<'a, M> {
195        self
196    }
197}
198
199impl<M: RingMem> PacketRef<'_, M> {
200    /// Revert the read pointers, allowing a peek at the next packet.
201    ///
202    /// Use this with care: a malicious guest could change the packet's
203    /// contents next time they are read. Any validation on the packet
204    /// needs to be performed again next time the packet is read.
205    pub fn revert(&mut self) {
206        self.batch.read.ptrs.revert();
207    }
208}
209
/// An incoming packet, borrowed from the incoming ring.
///
/// Produced by iterating a [`ReadBatch`] or by dereferencing a [`PacketRef`].
pub enum IncomingPacket<'a, T: RingMem> {
    /// A data packet: in-band, GPA direct, or transfer pages.
    Data(DataPacket<'a, T>),
    /// A completion packet.
    Completion(CompletionPacket<'a, T>),
}
217
/// An incoming data packet.
///
/// The payload and any external range descriptors are not copied out; the
/// accessors read them in place from the incoming ring.
pub struct DataPacket<'a, T: RingMem> {
    // The incoming ring the payload and descriptors live in.
    ring: &'a IncomingRing<T>,
    // Location of the packet payload within the ring.
    payload: ring::RingRange,
    // Set if and only if a completion was requested.
    transaction_id: Option<u64>,
    // Transfer buffer ID; `Some` only for transfer page packets.
    buffer_id: Option<u16>,
    // (descriptor count, location of the descriptors within the ring);
    // `(0, empty)` for plain in-band packets.
    external_data: (u32, ring::RingRange),
}
226
227impl<T: RingMem> DataPacket<'_, T> {
228    /// A reader for the data payload.
229    ///
230    /// N.B. This reads the payload in place, so multiple instantiations of the
231    /// reader may see multiple different results if the (malicious) opposite
232    /// endpoint is mutating the ring buffer.
233    pub fn reader(&self) -> impl MemoryRead + '_ {
234        self.payload.reader(self.ring)
235    }
236
237    /// The packet's transaction ID. Set if and only if a completion packet was
238    /// requested.
239    pub fn transaction_id(&self) -> Option<u64> {
240        self.transaction_id
241    }
242
243    /// The number of GPA direct ranges.
244    pub fn external_range_count(&self) -> usize {
245        self.external_data.0 as usize
246    }
247
248    fn read_transfer_page_ranges(
249        &self,
250        transfer_buf: &MultiPagedRangeBuf<GpnList>,
251    ) -> Result<MultiPagedRangeBuf<GpnList>, AccessError> {
252        let len = self.external_data.0 as usize;
253        let mut reader = self.external_data.1.reader(self.ring);
254        let available_count = reader.len() / size_of::<TransferPageRange>();
255        if available_count < len {
256            return Err(AccessError::OutOfRange(0, 0));
257        }
258
259        let mut buf: GpnList = smallvec![FromZeros::new_zeroed(); len];
260        reader.read(buf.as_mut_bytes())?;
261
262        // Construct an array of the form [#1 offset/length][page1][page2][...][#2 offset/length][page1][page2]...
263        // See MultiPagedRangeIter for more details.
264        let transfer_buf: GpnList = buf
265            .iter()
266            .map(|range| {
267                let range_data = TransferPageRange::read_from_prefix(range.as_bytes())
268                    .unwrap()
269                    .0; // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
270                let sub_range = transfer_buf
271                    .subrange(
272                        range_data.byte_offset as usize,
273                        range_data.byte_count as usize,
274                    )
275                    .map_err(|_| {
276                        AccessError::OutOfRange(
277                            range_data.byte_offset as usize,
278                            range_data.byte_count as usize,
279                        )
280                    })?;
281                Ok(sub_range.into_buffer())
282            })
283            .collect::<Result<Vec<GpnList>, AccessError>>()?
284            .into_iter()
285            .flatten()
286            .collect();
287        Ok(MultiPagedRangeBuf::new(len, transfer_buf).unwrap())
288    }
289
290    /// Reads the GPA direct range descriptors from the packet.
291    pub fn read_external_ranges(&self) -> Result<MultiPagedRangeBuf<GpnList>, ExternalDataError> {
292        if self.buffer_id.is_some() {
293            return Err(ExternalDataError::WrongExternalDataType);
294        } else if self.external_data.0 == 0 {
295            return Ok(MultiPagedRangeBuf::empty());
296        }
297
298        let mut reader = self.external_data.1.reader(self.ring);
299        let len = reader.len() / 8;
300        let mut buf = zeroed_gpn_list(len);
301        reader
302            .read(buf.as_mut_bytes())
303            .map_err(ExternalDataError::Access)?;
304        MultiPagedRangeBuf::new(self.external_data.0 as usize, buf)
305            .map_err(ExternalDataError::GpaRange)
306    }
307
308    /// Reads the transfer buffer ID from the packet, or None if this is not a transfer packet.
309    pub fn transfer_buffer_id(&self) -> Option<u16> {
310        self.buffer_id
311    }
312
313    /// Reads the transfer descriptors from the packet using the provided buffer. This buffer should be the one
314    /// associated with the value returned from transfer_buffer_id().
315    pub fn read_transfer_ranges<'a, I>(
316        &self,
317        transfer_buf: I,
318    ) -> Result<MultiPagedRangeBuf<GpnList>, AccessError>
319    where
320        I: Iterator<Item = PagedRange<'a>>,
321    {
322        if self.external_data.0 == 0 {
323            return Ok(MultiPagedRangeBuf::empty());
324        }
325
326        let buf: MultiPagedRangeBuf<GpnList> = transfer_buf.collect();
327        self.read_transfer_page_ranges(&buf)
328    }
329}
330
331/// A completion packet.
332pub struct CompletionPacket<'a, T: RingMem> {
333    ring: &'a IncomingRing<T>,
334    payload: ring::RingRange,
335    transaction_id: u64,
336}
337
338impl<T: RingMem> CompletionPacket<'_, T> {
339    /// A reader for the completion payload.
340    pub fn reader(&self) -> impl MemoryRead + '_ {
341        self.payload.reader(self.ring)
342    }
343
344    /// The packet's transaction ID.
345    pub fn transaction_id(&self) -> u64 {
346        self.transaction_id
347    }
348}
349
350impl<'a, T: RingMem> IncomingPacket<'a, T> {
351    fn parse(ring: &'a IncomingRing<T>, packet: ring::IncomingPacket) -> Result<Self, Error> {
352        Ok(match packet.typ {
353            IncomingPacketType::InBand => IncomingPacket::Data(DataPacket {
354                ring,
355                payload: packet.payload,
356                transaction_id: packet.transaction_id,
357                buffer_id: None,
358                external_data: (0, ring::RingRange::empty()),
359            }),
360            IncomingPacketType::GpaDirect(count, ranges) => IncomingPacket::Data(DataPacket {
361                ring,
362                payload: packet.payload,
363                transaction_id: packet.transaction_id,
364                buffer_id: None,
365                external_data: (count, ranges),
366            }),
367            IncomingPacketType::Completion => IncomingPacket::Completion(CompletionPacket {
368                ring,
369                payload: packet.payload,
370                transaction_id: packet.transaction_id.unwrap(),
371            }),
372            IncomingPacketType::TransferPages(id, count, ranges) => {
373                IncomingPacket::Data(DataPacket {
374                    ring,
375                    payload: packet.payload,
376                    transaction_id: packet.transaction_id,
377                    buffer_id: Some(id),
378                    external_data: (count, ranges),
379                })
380            }
381        })
382    }
383}
384
385/// The reader for the incoming ring buffer of a [`Queue`].
386pub struct ReadHalf<'a, M: RingMem> {
387    core: &'a Core<M>,
388    read: &'a mut ReadState,
389}
390
391impl<'a, M: RingMem> ReadHalf<'a, M> {
392    /// Polls the incoming ring for more packets.
393    ///
394    /// This will automatically manage interrupt masking. The queue will keep
395    /// interrupts masked until this is called. Once this is called, interrupts
396    /// will remain unmasked until this or another poll or async read function
397    /// is called again.
398    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
399        ready!(self.read.poll_ready(cx, self.core)).map_err(ErrorInner::from)?;
400        Poll::Ready(Ok(()))
401    }
402
403    /// Polls the incoming ring for more packets and returns a batch reader for
404    /// them.
405    ///
406    /// This will manage interrupt masking as described in [`Self::poll_ready`].
407    pub fn poll_read_batch<'b>(
408        &'b mut self,
409        cx: &mut Context<'_>,
410    ) -> Poll<Result<ReadBatch<'b, M>, Error>> {
411        let batch = loop {
412            std::task::ready!(self.poll_ready(cx))?;
413            if self
414                .core
415                .in_ring()
416                .can_read(&mut self.read.ptrs)
417                .map_err(ErrorInner::Ring)?
418            {
419                break ReadBatch {
420                    core: self.core,
421                    read: self.read,
422                };
423            } else {
424                self.read.clear_ready();
425            }
426        };
427        Poll::Ready(Ok(batch))
428    }
429
430    /// Tries to get a reader for the next batch of packets.
431    pub fn try_read_batch(&mut self) -> Result<ReadBatch<'_, M>, TryReadError> {
432        if self
433            .core
434            .in_ring()
435            .can_read(&mut self.read.ptrs)
436            .map_err(|err| TryReadError::Queue(Error::from(ErrorInner::Ring(err))))?
437        {
438            Ok(ReadBatch {
439                core: self.core,
440                read: self.read,
441            })
442        } else {
443            self.read.clear_ready();
444            Err(TryReadError::Empty)
445        }
446    }
447
448    /// Waits for the next batch of packets to be ready and returns a reader for
449    /// them.
450    ///
451    /// This will manage interrupt masking as described in [`Self::poll_ready`].
452    pub fn read_batch<'b>(&'b mut self) -> BatchRead<'a, 'b, M> {
453        BatchRead(Some(self))
454    }
455
456    /// Tries to read the next packet.
457    ///
458    /// Returns `Err(TryReadError::Empty)` if the ring is empty.
459    pub fn try_read(&mut self) -> Result<PacketRef<'_, M>, TryReadError> {
460        let batch = self.try_read_batch()?;
461        batch
462            .single_packet()
463            .map_err(TryReadError::Queue)?
464            .ok_or(TryReadError::Empty)
465    }
466
467    /// Waits for the next packet to be ready and returns it.
468    pub fn read<'b>(&'b mut self) -> Read<'a, 'b, M> {
469        Read(self.read_batch())
470    }
471
472    /// Indicates whether pending send size notification is supported on
473    /// the vmbus ring.
474    pub fn supports_pending_send_size(&self) -> bool {
475        self.core.in_ring().supports_pending_send_size()
476    }
477}
478
/// An asynchronous batch read operation, returned by [`ReadHalf::read_batch`].
///
/// Holds the borrowed [`ReadHalf`] until a batch is ready. Polling again after
/// completion panics, because the `ReadHalf` has already been taken.
pub struct BatchRead<'a, 'b, M: RingMem>(Option<&'a mut ReadHalf<'b, M>>);

impl<'a, M: RingMem> Future for BatchRead<'a, '_, M> {
    type Output = Result<ReadBatch<'a, M>, Error>;

    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        // Rebuild the batch below to get the lifetimes right.
        // The batch from `poll_read_batch` only borrows through `this.0` for
        // the duration of this call. It is discarded here, which runs
        // `ReadBatch`'s `Drop` (clear_poll + commit_read). Then a new batch
        // with lifetime 'a is built from the taken `&'a mut ReadHalf`.
        let _ = std::task::ready!(this.0.as_mut().unwrap().poll_read_batch(cx))?;
        let this = this.0.take().unwrap();
        Poll::Ready(Ok(ReadBatch {
            core: this.core,
            read: this.read,
        }))
    }
}
496
497/// An asynchronous read operation.
498pub struct Read<'a, 'b, M: RingMem>(BatchRead<'a, 'b, M>);
499
500impl<'a, M: RingMem> Future for Read<'a, '_, M> {
501    type Output = Result<PacketRef<'a, M>, Error>;
502
503    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
504        let batch = std::task::ready!(self.0.poll_unpin(cx))?;
505        Poll::Ready(
506            batch
507                .single_packet()
508                .transpose()
509                .expect("batch was non-empty"),
510        )
511    }
512}
513
/// An outgoing packet, passed to [`WriteHalf::try_write`], [`WriteHalf::write`],
/// and the related write methods.
pub struct OutgoingPacket<'a, 'b> {
    /// The transaction ID. Ignored for `packet_type` of [`OutgoingPacketType::InBandNoCompletion`].
    pub transaction_id: u64,
    /// The outgoing packet type.
    pub packet_type: OutgoingPacketType<'a>,
    /// The payload, as a list of byte slices. The slices are written
    /// back-to-back, so the payload size is the sum of their lengths.
    pub payload: &'b [&'b [u8]],
}
523
524/// The writer for the outgoing ring buffer of a [`Queue`].
525pub struct WriteHalf<'a, M: RingMem> {
526    core: &'a Core<M>,
527    write: &'a mut WriteState,
528}
529
530impl<'a, M: RingMem> WriteHalf<'a, M> {
531    /// Polls the outgoing ring for the ability to send a packet of size
532    /// `send_size`.
533    ///
534    /// `send_size` can be computed by calling `try_write` and extracting the
535    /// size from `TryReadError::Full(send_size)`.
536    pub fn poll_ready(
537        &mut self,
538        cx: &mut Context<'_>,
539        send_size: usize,
540    ) -> Poll<Result<(), Error>> {
541        loop {
542            std::task::ready!(self.write.poll_ready(cx, self.core, send_size))
543                .map_err(ErrorInner::from)?;
544            if self.can_write(send_size)? {
545                break Poll::Ready(Ok(()));
546            }
547        }
548    }
549
550    /// Waits until there is enough space in the ring to send a packet of size
551    /// `send_size`.
552    ///
553    /// `send_size` can be computed by calling `try_write` and extracting the
554    /// size from `TryReadError::Full(send_size)`.
555    pub async fn wait_ready(&mut self, send_size: usize) -> Result<(), Error> {
556        poll_fn(|cx| self.poll_ready(cx, send_size)).await
557    }
558
559    /// Returns an object for writing multiple packets at once.
560    ///
561    /// The batch will be committed when the returned object is dropped.
562    ///
563    /// This reduces the overhead of writing multiple packets by updating the
564    /// ring pointers and signaling an interrupt only once, when the batch is
565    /// committed.
566    pub fn batched(&mut self) -> WriteBatch<'_, M> {
567        WriteBatch {
568            core: self.core,
569            write: self.write,
570        }
571    }
572
573    /// Checks the outgiong ring for the capacity to send a packet of size
574    /// `send_size`.
575    pub fn can_write(&mut self, send_size: usize) -> Result<bool, Error> {
576        self.batched().can_write(send_size)
577    }
578
579    /// The ring's capacity in bytes.
580    pub fn capacity(&self) -> usize {
581        self.core.out_ring().maximum_packet_size()
582    }
583
584    /// Tries to write a packet into the outgoing ring.
585    ///
586    /// Fails with `TryReadError::Full(send_size)` if the ring is full.
587    pub fn try_write(&mut self, packet: &OutgoingPacket<'_, '_>) -> Result<(), TryWriteError> {
588        self.batched().try_write(packet)
589    }
590
591    /// Polls the ring for successful write of `packet`.
592    pub fn poll_write(
593        &mut self,
594        cx: &mut Context<'_>,
595        packet: &OutgoingPacket<'_, '_>,
596    ) -> Poll<Result<(), Error>> {
597        let mut send_size = 32;
598        let r = loop {
599            std::task::ready!(self.write.poll_ready(cx, self.core, send_size))
600                .map_err(ErrorInner::from)?;
601            match self.try_write(packet) {
602                Ok(()) => break Ok(()),
603                Err(TryWriteError::Full(len)) => send_size = len,
604                Err(TryWriteError::Queue(err)) => break Err(err),
605            }
606        };
607        Poll::Ready(r)
608    }
609
610    /// Writes a packet.
611    pub fn write<'b, 'c>(&'b mut self, packet: OutgoingPacket<'c, 'b>) -> Write<'a, 'b, 'c, M> {
612        Write {
613            write: self,
614            packet,
615        }
616    }
617}
618
619/// A batch writer, returned by [`WriteHalf::batched`].
620pub struct WriteBatch<'a, M: RingMem> {
621    core: &'a Core<M>,
622    write: &'a mut WriteState,
623}
624
625impl<M: RingMem> WriteBatch<'_, M> {
626    /// Checks the outgiong ring for the capacity to send a packet of size
627    /// `send_size`.
628    pub fn can_write(&mut self, send_size: usize) -> Result<bool, Error> {
629        let can_write = self
630            .core
631            .out_ring()
632            .can_write(&mut self.write.ptrs, send_size)
633            .map_err(ErrorInner::Ring)?;
634
635        // Ensure that poll_write will check again.
636        if !can_write {
637            self.write.clear_ready();
638        }
639        Ok(can_write)
640    }
641
642    /// Tries to write a packet into the outgoing ring.
643    ///
644    /// Fails with `TryReadError::Full(send_size)` if the ring is full.
645    pub fn try_write(&mut self, packet: &OutgoingPacket<'_, '_>) -> Result<(), TryWriteError> {
646        let size = packet.payload.iter().fold(0, |a, p| a + p.len());
647        let ring_packet = ring::OutgoingPacket {
648            transaction_id: packet.transaction_id,
649            size,
650            typ: packet.packet_type,
651        };
652        let mut ptrs = self.write.ptrs.clone();
653        match self.core.out_ring().write(&mut ptrs, &ring_packet) {
654            Ok(range) => {
655                let mut writer = range.writer(self.core.out_ring());
656                for p in packet.payload.iter().copied() {
657                    writer.write(p).map_err(|err| {
658                        TryWriteError::Queue(Error::from(ErrorInner::Access(err)))
659                    })?;
660                }
661                self.write.clear_poll(self.core);
662                self.write.ptrs = ptrs;
663                Ok(())
664            }
665            Err(ring::WriteError::Full(n)) => {
666                self.write.clear_ready();
667                Err(TryWriteError::Full(n))
668            }
669            Err(ring::WriteError::Corrupt(err)) => {
670                Err(TryWriteError::Queue(ErrorInner::Ring(err).into()))
671            }
672        }
673    }
674}
675
676impl<M: RingMem> Drop for WriteBatch<'_, M> {
677    fn drop(&mut self) {
678        if self.core.out_ring().commit_write(&mut self.write.ptrs) {
679            self.core.signal();
680            self.write.signals.increment();
681        }
682    }
683}
684
685/// An asynchronous packet write operation.
686#[must_use]
687pub struct Write<'a, 'b, 'c, M: RingMem> {
688    write: &'b mut WriteHalf<'a, M>,
689    packet: OutgoingPacket<'c, 'b>,
690}
691
692impl<M: RingMem> Future for Write<'_, '_, '_, M> {
693    type Output = Result<(), Error>;
694
695    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
696        let this = self.get_mut();
697        this.write.poll_write(cx, &this.packet)
698    }
699}
700
701/// An abstraction over an open VmBus channel that provides methods to read and
702/// write packets from the ring, as well as poll the ring for readiness.
703///
704/// This is useful when you need to operate on external data packets or send or
705/// receive packets in batch. Otherwise, consider the `Channel` type.
706pub struct Queue<M: RingMem> {
707    core: Core<M>,
708    read: ReadState,
709    write: WriteState,
710}
711
712impl<M: RingMem> Inspect for Queue<M> {
713    fn inspect(&self, req: inspect::Request<'_>) {
714        req.respond()
715            .merge(&self.core)
716            .field("incoming_ring", &self.read)
717            .field("outgoing_ring", &self.write);
718    }
719}
720
721impl<M: RingMem> Queue<M> {
722    /// Constructs a `Queue` object with the given raw channel and given
723    /// configuration.
724    pub fn new(raw: RawAsyncChannel<M>) -> Result<Self, Error> {
725        let incoming = raw.in_ring.incoming().map_err(ErrorInner::Ring)?;
726        let outgoing = raw.out_ring.outgoing().map_err(ErrorInner::Ring)?;
727        let core = Core::new(raw);
728        let read = ReadState::new(incoming);
729        let write = WriteState::new(outgoing);
730
731        Ok(Self { core, read, write })
732    }
733
734    /// Splits the queue into a read half and write half that can be operated on
735    /// independently.
736    pub fn split(&mut self) -> (ReadHalf<'_, M>, WriteHalf<'_, M>) {
737        (
738            ReadHalf {
739                core: &self.core,
740                read: &mut self.read,
741            },
742            WriteHalf {
743                core: &self.core,
744                write: &mut self.write,
745            },
746        )
747    }
748}
749
750/// Returns a pair of connected queues. Useful for testing.
751pub fn connected_queues(ring_size: usize) -> (Queue<FlatRingMem>, Queue<FlatRingMem>) {
752    let (host, guest) = connected_async_channels(ring_size);
753    (Queue::new(host).unwrap(), Queue::new(guest).unwrap())
754}
755
756#[cfg(test)]
757mod tests {
758    use super::*;
759    use pal_async::DefaultDriver;
760    use pal_async::async_test;
761    use pal_async::task::Spawn;
762    use pal_async::timer::PolledTimer;
763    use ring::OutgoingPacketType;
764    use std::future::poll_fn;
765    use std::time::Duration;
766    use vmbus_channel::gpadl::GpadlId;
767    use vmbus_channel::gpadl::GpadlMap;
768
769    #[async_test]
770    async fn test_gpa_direct() {
771        use guestmem::ranges::PagedRange;
772
773        let (mut host_queue, mut guest_queue) = connected_queues(16384);
774
775        let gpa1: Vec<u64> = vec![4096, 8192];
776        let gpa2: Vec<u64> = vec![8192];
777        let gpas = vec![
778            PagedRange::new(20, 4096, &gpa1).unwrap(),
779            PagedRange::new(0, 200, &gpa2).unwrap(),
780        ];
781
782        let payload: &[u8] = &[0xf; 24];
783        guest_queue
784            .split()
785            .1
786            .write(OutgoingPacket {
787                transaction_id: 0,
788                packet_type: OutgoingPacketType::GpaDirect(&gpas),
789                payload: &[payload],
790            })
791            .await
792            .unwrap();
793        host_queue
794            .split()
795            .0
796            .read_batch()
797            .await
798            .unwrap()
799            .packets()
800            .next()
801            .map(|p| match p.unwrap() {
802                IncomingPacket::Data(data) => {
803                    // Check the payload
804                    let mut in_payload = [0_u8; 24];
805                    assert_eq!(payload.len(), data.reader().len());
806                    data.reader().read(&mut in_payload).unwrap();
807                    assert_eq!(in_payload, payload);
808
809                    // Check the external ranges
810                    assert_eq!(data.external_range_count(), 2);
811                    let external_data = data.read_external_ranges().unwrap();
812                    let in_gpas: Vec<PagedRange<'_>> = external_data.iter().collect();
813                    assert_eq!(in_gpas.len(), gpas.len());
814
815                    for (p, q) in in_gpas.iter().zip(gpas) {
816                        assert_eq!(p.offset(), q.offset());
817                        assert_eq!(p.len(), q.len());
818                        assert_eq!(p.gpns(), q.gpns());
819                    }
820                    Ok(())
821                }
822                _ => Err("should be data"),
823            })
824            .unwrap()
825            .unwrap();
826    }
827
828    #[async_test]
829    async fn test_gpa_direct_empty_external_data() {
830        use guestmem::ranges::PagedRange;
831
832        let (mut host_queue, mut guest_queue) = connected_queues(16384);
833
834        let gpa1: Vec<u64> = vec![];
835        let gpas = vec![PagedRange::new(0, 0, &gpa1).unwrap()];
836
837        let payload: &[u8] = &[0xf; 24];
838        guest_queue
839            .split()
840            .1
841            .write(OutgoingPacket {
842                transaction_id: 0,
843                packet_type: OutgoingPacketType::GpaDirect(&gpas),
844                payload: &[payload],
845            })
846            .await
847            .unwrap();
848        host_queue
849            .split()
850            .0
851            .read_batch()
852            .await
853            .unwrap()
854            .packets()
855            .next()
856            .map(|p| match p.unwrap() {
857                IncomingPacket::Data(data) => {
858                    // Check the payload
859                    let mut in_payload = [0_u8; 24];
860                    assert_eq!(payload.len(), data.reader().len());
861                    data.reader().read(&mut in_payload).unwrap();
862                    assert_eq!(in_payload, payload);
863
864                    // Check the external ranges
865                    assert_eq!(data.external_range_count(), 1);
866                    let external_data_result = data.read_external_ranges();
867                    assert_eq!(data.read_external_ranges().is_err(), true);
868                    match external_data_result {
869                        Err(ExternalDataError::GpaRange(_)) => Ok(()),
870                        _ => Err("should be out of range"),
871                    }
872                }
873                _ => Err("should be data"),
874            })
875            .unwrap()
876            .unwrap();
877    }
878
879    #[async_test]
880    async fn test_transfer_pages() {
881        use guestmem::ranges::PagedRange;
882
883        let (mut host_queue, mut guest_queue) = connected_queues(16384);
884
885        let gpadl_map = GpadlMap::new();
886        let buf = vec![0x3000_u64, 1, 2, 3];
887        gpadl_map.add(GpadlId(13), MultiPagedRangeBuf::new(1, buf).unwrap());
888
889        let ranges = vec![
890            TransferPageRange {
891                byte_count: 0x10,
892                byte_offset: 0x10,
893            },
894            TransferPageRange {
895                byte_count: 0x10,
896                byte_offset: 0xfff,
897            },
898            TransferPageRange {
899                byte_count: 0x10,
900                byte_offset: 0x1000,
901            },
902        ];
903
904        let payload: &[u8] = &[0xf; 24];
905        guest_queue
906            .split()
907            .1
908            .write(OutgoingPacket {
909                transaction_id: 0,
910                packet_type: OutgoingPacketType::TransferPages(13, &ranges),
911                payload: &[payload],
912            })
913            .await
914            .unwrap();
915        host_queue
916            .split()
917            .0
918            .read_batch()
919            .await
920            .unwrap()
921            .packets()
922            .next()
923            .map(|p| match p.unwrap() {
924                IncomingPacket::Data(data) => {
925                    // Check the payload
926                    let mut in_payload = [0_u8; 24];
927                    assert_eq!(payload.len(), data.reader().len());
928                    data.reader().read(&mut in_payload).unwrap();
929                    assert_eq!(in_payload, payload);
930
931                    // Check the external ranges
932                    assert_eq!(data.external_range_count(), 3);
933                    let gpadl_map_view = gpadl_map.view();
934                    assert_eq!(data.transfer_buffer_id().unwrap(), 13);
935                    let buffer_range = gpadl_map_view.map(GpadlId(13)).unwrap();
936                    let external_data = data.read_transfer_ranges(buffer_range.iter()).unwrap();
937                    let in_ranges: Vec<PagedRange<'_>> = external_data.iter().collect();
938                    assert_eq!(in_ranges.len(), ranges.len());
939                    assert_eq!(in_ranges[0].offset(), 0x10);
940                    assert_eq!(in_ranges[0].len(), 0x10);
941                    assert_eq!(in_ranges[0].gpns().len(), 1);
942                    assert_eq!(in_ranges[0].gpns()[0], 1);
943
944                    assert_eq!(in_ranges[1].offset(), 0xfff);
945                    assert_eq!(in_ranges[1].len(), 0x10);
946                    assert_eq!(in_ranges[1].gpns().len(), 2);
947                    assert_eq!(in_ranges[1].gpns()[0], 1);
948                    assert_eq!(in_ranges[1].gpns()[1], 2);
949
950                    assert_eq!(in_ranges[2].offset(), 0);
951                    assert_eq!(in_ranges[2].len(), 0x10);
952                    assert_eq!(in_ranges[2].gpns().len(), 1);
953                    assert_eq!(in_ranges[2].gpns()[0], 2);
954
955                    Ok(())
956                }
957                _ => Err("should be data"),
958            })
959            .unwrap()
960            .unwrap();
961    }
962
963    #[async_test]
964    async fn test_ring_full(driver: DefaultDriver) {
965        let (mut host_queue, mut guest_queue) = connected_queues(4096);
966
967        assert!(
968            poll_fn(|cx| host_queue.split().1.poll_ready(cx, 4000))
969                .now_or_never()
970                .is_some()
971        );
972
973        host_queue
974            .split()
975            .1
976            .try_write(&OutgoingPacket {
977                transaction_id: 0,
978                packet_type: OutgoingPacketType::InBandNoCompletion,
979                payload: &[&[0u8; 4000]],
980            })
981            .unwrap();
982
983        let n = match host_queue
984            .split()
985            .1
986            .try_write(&OutgoingPacket {
987                transaction_id: 0,
988                packet_type: OutgoingPacketType::InBandNoCompletion,
989                payload: &[&[0u8; 4000]],
990            })
991            .unwrap_err()
992        {
993            TryWriteError::Full(n) => n,
994            _ => unreachable!(),
995        };
996
997        let mut poll = async move {
998            let mut host_queue = host_queue;
999            poll_fn(|cx| host_queue.split().1.poll_ready(cx, n))
1000                .await
1001                .unwrap();
1002            host_queue
1003        }
1004        .boxed();
1005
1006        assert!(futures::poll!(&mut poll).is_pending());
1007        let poll = driver.spawn("test", poll);
1008
1009        PolledTimer::new(&driver)
1010            .sleep(Duration::from_millis(50))
1011            .await;
1012
1013        guest_queue.split().0.read().await.unwrap();
1014        assert!(guest_queue.split().0.try_read().is_err());
1015
1016        let mut host_queue = poll.await;
1017
1018        host_queue
1019            .split()
1020            .1
1021            .try_write(&OutgoingPacket {
1022                transaction_id: 0,
1023                packet_type: OutgoingPacketType::InBandNoCompletion,
1024                payload: &[&[0u8; 4000]],
1025            })
1026            .unwrap();
1027    }
1028}