vmbus_async/queue.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! This module implements the `Queue` type, which provides an abstraction over
//! a VmBus channel.
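//!
//! # Example
//!
//! A minimal usage sketch (fenced as `ignore`, so it is not compiled as a
//! doc-test), assuming this module is available as `vmbus_async::queue`:
//!
//! ```ignore
//! use guestmem::MemoryRead;
//! use vmbus_async::queue::{connected_queues, IncomingPacket, OutgoingPacket};
//! use vmbus_ring::OutgoingPacketType;
//!
//! async fn example() -> Result<(), Box<dyn std::error::Error>> {
//!     // Create a connected pair of queues backed by flat (test) ring memory.
//!     let (mut host, mut guest) = connected_queues(16384);
//!
//!     // Write an in-band packet from the guest side.
//!     let (_, mut write) = guest.split();
//!     write
//!         .write(OutgoingPacket {
//!             transaction_id: 0,
//!             packet_type: OutgoingPacketType::InBandNoCompletion,
//!             payload: &[&[1, 2, 3]],
//!         })
//!         .await?;
//!
//!     // Read it back on the host side.
//!     let (mut read, _) = host.split();
//!     let packet = read.read().await?;
//!     if let IncomingPacket::Data(data) = &*packet {
//!         let mut buf = vec![0; data.reader().len()];
//!         data.reader().read(&mut buf)?;
//!     }
//!     Ok(())
//! }
//! ```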
6
7use super::core::Core;
8use super::core::ReadState;
9use super::core::WriteState;
10use crate::core::PollError;
11use futures::FutureExt;
12use guestmem::AccessError;
13use guestmem::MemoryRead;
14use guestmem::MemoryWrite;
15use guestmem::ranges::PagedRange;
16use inspect::Inspect;
17use ring::OutgoingPacketType;
18use ring::TransferPageRange;
19use std::future::Future;
20use std::future::poll_fn;
21use std::ops::Deref;
22use std::task::Context;
23use std::task::Poll;
24use std::task::ready;
25use thiserror::Error;
26use vmbus_channel::RawAsyncChannel;
27use vmbus_channel::connected_async_channels;
28use vmbus_ring as ring;
29use vmbus_ring::FlatRingMem;
30use vmbus_ring::IncomingPacketType;
31use vmbus_ring::IncomingRing;
32use vmbus_ring::RingMem;
33use vmbus_ring::gparange::MultiPagedRangeBuf;
34use zerocopy::IntoBytes;
35
36/// A queue error.
37#[derive(Debug, Error)]
38#[error(transparent)]
39pub struct Error(Box<ErrorInner>);
40
41impl From<ErrorInner> for Error {
42    fn from(value: ErrorInner) -> Self {
43        Self(Box::new(value))
44    }
45}
46
47impl Error {
48    /// Returns true if the error is due to the channel being closed by the
49    /// remote endpoint.
50    pub fn is_closed_error(&self) -> bool {
51        matches!(self.0.as_ref(), ErrorInner::ChannelClosed)
52    }
53}
54
55#[derive(Debug, Error)]
56enum ErrorInner {
57    /// Failed to access guest memory.
58    #[error("guest memory access error")]
59    Access(#[source] AccessError),
60    /// The ring buffer is corrupted.
61    #[error("ring buffer error")]
62    Ring(#[source] ring::Error),
63    /// The channel has been closed.
64    #[error("the channel has been closed")]
65    ChannelClosed,
66}
67
68impl From<PollError> for ErrorInner {
69    fn from(value: PollError) -> Self {
70        match value {
71            PollError::Ring(ring) => Self::Ring(ring),
72            PollError::Closed => Self::ChannelClosed,
73        }
74    }
75}
76
77/// An error returned by `try_read*` methods.
78#[derive(Debug, Error)]
79pub enum TryReadError {
80    /// The ring is empty.
81    #[error("ring is empty")]
82    Empty,
83    /// Underlying queue error.
84    #[error("queue error")]
85    Queue(#[source] Error),
86}
87
88/// An error returned by `try_write*` methods.
89#[derive(Debug, Error)]
90pub enum TryWriteError {
    /// The ring is full. Contains the number of bytes of ring space needed
    /// to write the packet.
    #[error("ring is full")]
    Full(usize),
    /// Underlying queue error.
    #[error("queue error")]
    Queue(#[source] Error),
}

/// An error returned by `read_external_ranges`.
#[derive(Debug, Error)]
pub enum ExternalDataError {
    /// The packet is corrupted in some way (e.g. it does not specify a
    /// reasonable set of GPA ranges).
    #[error("invalid gpa ranges")]
    GpaRange(#[source] vmbus_ring::gparange::Error),

    /// The packet specifies memory that this vmbus instance cannot read.
    #[error("access error")]
    Access(#[source] AccessError),

    /// The caller used `read_external_ranges` on a packet that carries a
    /// transfer buffer ID; `read_transfer_ranges` should have been called
    /// instead.
    #[error("external data should have been read by calling read_transfer_ranges")]
    WrongExternalDataType,
}

/// An incoming packet batch reader.
pub struct ReadBatch<'a, M: RingMem> {
    core: &'a Core<M>,
    read: &'a mut ReadState,
}

/// The packet iterator for [`ReadBatch`].
pub struct ReadBatchIter<'a, 'b, M: RingMem>(&'a mut ReadBatch<'b, M>);

impl<'a, M: RingMem> ReadBatch<'a, M> {
    fn next_priv(&mut self) -> Result<Option<IncomingPacket<'a, M>>, Error> {
        let mut ptrs = self.read.ptrs.clone();
        match self.core.in_ring().read(&mut ptrs) {
            Ok(packet) => {
                let packet = IncomingPacket::parse(self.core.in_ring(), packet)?;
                self.read.ptrs = ptrs;
                Ok(Some(packet))
            }
            Err(ring::ReadError::Empty) => Ok(None),
            Err(ring::ReadError::Corrupt(err)) => Err(ErrorInner::Ring(err).into()),
        }
    }

    fn single_packet(mut self) -> Result<Option<PacketRef<'a, M>>, Error> {
        if let Some(packet) = self.next_priv()? {
            Ok(Some(PacketRef {
                batch: self,
                packet,
            }))
        } else {
            Ok(None)
        }
    }

    /// Returns an iterator of the packets.
    pub fn packets(&mut self) -> ReadBatchIter<'_, 'a, M> {
        ReadBatchIter(self)
    }
}

impl<'a, M: RingMem> Iterator for ReadBatchIter<'a, '_, M> {
    type Item = Result<IncomingPacket<'a, M>, Error>;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next_priv().transpose()
    }
}

impl<M: RingMem> Drop for ReadBatch<'_, M> {
    fn drop(&mut self) {
        self.read.clear_poll(self.core);
        if self.core.in_ring().commit_read(&mut self.read.ptrs) {
            self.core.signal();
            self.read.signals.increment();
        }
    }
}

/// A reference to a single packet that has not been read out of the ring yet.
pub struct PacketRef<'a, M: RingMem> {
    batch: ReadBatch<'a, M>,
    packet: IncomingPacket<'a, M>,
}

impl<'a, M: RingMem> Deref for PacketRef<'a, M> {
    type Target = IncomingPacket<'a, M>;

    fn deref(&self) -> &Self::Target {
        &self.packet
    }
}

impl<'a, M: RingMem> AsRef<IncomingPacket<'a, M>> for PacketRef<'a, M> {
    fn as_ref(&self) -> &IncomingPacket<'a, M> {
        self
    }
}

impl<M: RingMem> PacketRef<'_, M> {
    /// Revert the read pointers, allowing a peek at the next packet.
    ///
    /// Use this with care: a malicious guest could change the packet's
    /// contents next time they are read. Any validation on the packet
    /// needs to be performed again next time the packet is read.
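    ///
    /// # Example
    ///
    /// A hedged sketch of a peek pattern: read a packet, decide it cannot be
    /// processed yet, and revert so the same packet is returned on the next
    /// read. (`can_handle` is a hypothetical helper, not part of this crate.)
    ///
    /// ```ignore
    /// let mut packet = read_half.try_read()?;
    /// if !can_handle(&*packet) {
    ///     // Leave the packet in the ring; it will be re-parsed and must be
    ///     // revalidated the next time it is read.
    ///     packet.revert();
    /// }
    /// ```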
    pub fn revert(&mut self) {
        self.batch.read.ptrs.revert();
    }
}

/// An incoming packet.
pub enum IncomingPacket<'a, T: RingMem> {
    /// A data packet.
    Data(DataPacket<'a, T>),
    /// A completion packet.
    Completion(CompletionPacket<'a, T>),
}

/// An incoming data packet.
pub struct DataPacket<'a, T: RingMem> {
    ring: &'a IncomingRing<T>,
    payload: ring::RingRange,
    transaction_id: Option<u64>,
    buffer_id: Option<u16>,
    external_data: (u32, ring::RingRange),
}

impl<T: RingMem> DataPacket<'_, T> {
    /// A reader for the data payload.
    ///
    /// N.B. This reads the payload in place, so different instantiations of
    /// the reader may see different results if the (malicious) opposite
    /// endpoint is mutating the ring buffer.
    pub fn reader(&self) -> vmbus_ring::RingRangeReader<'_, T> {
        self.payload.reader(self.ring)
    }

    /// The packet's transaction ID. Set if and only if a completion packet was
    /// requested.
    pub fn transaction_id(&self) -> Option<u64> {
        self.transaction_id
    }

    /// The number of GPA direct ranges.
    pub fn external_range_count(&self) -> usize {
        self.external_data.0 as usize
    }

    fn read_transfer_page_ranges(
        &self,
        transfer_buf: PagedRange<'_>,
        result: &mut MultiPagedRangeBuf,
    ) -> Result<(), AccessError> {
        let len = self.external_data.0 as usize;
        let mut reader = self.external_data.1.reader(self.ring);
        let available_count = reader.len() / size_of::<TransferPageRange>();
        if available_count < len {
            return Err(AccessError::OutOfRange(0, 0));
        }

        for _ in 0..len {
            let range = reader.read_plain::<TransferPageRange>()?;
            result.push_range(
                transfer_buf
                    .try_subrange(range.byte_offset as usize, range.byte_count as usize)
                    .ok_or(AccessError::OutOfRange(
                        range.byte_offset as usize,
                        range.byte_count as usize,
                    ))?,
            );
        }
        Ok(())
    }

    /// Reads the GPA direct range descriptors from the packet, extending `buf`.
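    ///
    /// # Example
    ///
    /// A hedged sketch, assuming `data` is the [`DataPacket`] from an
    /// [`IncomingPacket::Data`] that carries GPA direct ranges:
    ///
    /// ```ignore
    /// let mut ranges = MultiPagedRangeBuf::new();
    /// if data.transfer_buffer_id().is_none() {
    ///     data.read_external_ranges(&mut ranges)?;
    ///     for range in ranges.iter() {
    ///         // Each range is a guestmem::ranges::PagedRange describing
    ///         // guest memory referenced by the packet.
    ///         let _ = (range.offset(), range.len(), range.gpns());
    ///     }
    /// }
    /// ```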
    pub fn read_external_ranges(
        &self,
        buf: &mut MultiPagedRangeBuf,
    ) -> Result<(), ExternalDataError> {
        if self.buffer_id.is_some() {
            return Err(ExternalDataError::WrongExternalDataType);
        } else if self.external_data.0 == 0 {
            return Ok(());
        }

        let mut reader = self.external_data.1.reader(self.ring);
        buf.try_extend_with(reader.len() / 8, self.external_data.0 as usize, |b| {
            reader
                .read(b.as_mut_bytes())
                .map_err(ExternalDataError::Access)?;
            Ok(())
        })?
        .map_err(ExternalDataError::GpaRange)?;
        Ok(())
    }

    /// Returns the transfer buffer ID of the packet, or None if this is not a transfer packet.
    pub fn transfer_buffer_id(&self) -> Option<u16> {
        self.buffer_id
    }

    /// Reads the transfer descriptors from the packet using the provided
    /// buffer, extending `result`.
    ///
    /// This buffer should be the one associated with the value returned from
    /// [`Self::transfer_buffer_id`].
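    ///
    /// # Example
    ///
    /// A hedged sketch, assuming `data` is a transfer-page [`DataPacket`],
    /// `gpadl_map` is a `vmbus_channel::gpadl::GpadlMap`, and the transfer
    /// buffer ID doubles as the GPADL ID (as in this module's tests; real
    /// devices define their own mapping):
    ///
    /// ```ignore
    /// if let Some(buffer_id) = data.transfer_buffer_id() {
    ///     let gpadl = gpadl_map.view().map(GpadlId(buffer_id.into())).unwrap();
    ///     let buffer_range = gpadl.first().unwrap();
    ///     let mut ranges = MultiPagedRangeBuf::new();
    ///     data.read_transfer_ranges(buffer_range, &mut ranges)?;
    /// }
    /// ```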
    pub fn read_transfer_ranges(
        &self,
        transfer_buf: PagedRange<'_>,
        result: &mut MultiPagedRangeBuf,
    ) -> Result<(), AccessError> {
        if self.external_data.0 != 0 {
            self.read_transfer_page_ranges(transfer_buf, result)?;
        }
        Ok(())
    }
}

/// A completion packet.
pub struct CompletionPacket<'a, T: RingMem> {
    ring: &'a IncomingRing<T>,
    payload: ring::RingRange,
    transaction_id: u64,
}

impl<T: RingMem> CompletionPacket<'_, T> {
    /// A reader for the completion payload.
    pub fn reader(&self) -> impl MemoryRead + '_ {
        self.payload.reader(self.ring)
    }

    /// The packet's transaction ID.
    pub fn transaction_id(&self) -> u64 {
        self.transaction_id
    }
}

impl<'a, T: RingMem> IncomingPacket<'a, T> {
    fn parse(ring: &'a IncomingRing<T>, packet: ring::IncomingPacket) -> Result<Self, Error> {
        Ok(match packet.typ {
            IncomingPacketType::InBand => IncomingPacket::Data(DataPacket {
                ring,
                payload: packet.payload,
                transaction_id: packet.transaction_id,
                buffer_id: None,
                external_data: (0, ring::RingRange::empty()),
            }),
            IncomingPacketType::GpaDirect(count, ranges) => IncomingPacket::Data(DataPacket {
                ring,
                payload: packet.payload,
                transaction_id: packet.transaction_id,
                buffer_id: None,
                external_data: (count, ranges),
            }),
            IncomingPacketType::Completion => IncomingPacket::Completion(CompletionPacket {
                ring,
                payload: packet.payload,
                transaction_id: packet.transaction_id.unwrap(),
            }),
            IncomingPacketType::TransferPages(id, count, ranges) => {
                IncomingPacket::Data(DataPacket {
                    ring,
                    payload: packet.payload,
                    transaction_id: packet.transaction_id,
                    buffer_id: Some(id),
                    external_data: (count, ranges),
                })
            }
        })
    }
}

/// The reader for the incoming ring buffer of a [`Queue`].
pub struct ReadHalf<'a, M: RingMem> {
    core: &'a Core<M>,
    read: &'a mut ReadState,
}

impl<'a, M: RingMem> ReadHalf<'a, M> {
    /// Polls the incoming ring for more packets.
    ///
    /// This will automatically manage interrupt masking. The queue will keep
    /// interrupts masked until this is called. Once this is called, interrupts
    /// will remain unmasked until this or another poll or async read function
    /// is called again.
    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        ready!(self.read.poll_ready(cx, self.core)).map_err(ErrorInner::from)?;
        Poll::Ready(Ok(()))
    }

    /// Polls the incoming ring for more packets and returns a batch reader for
    /// them.
    ///
    /// This will manage interrupt masking as described in [`Self::poll_ready`].
    pub fn poll_read_batch<'b>(
        &'b mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<ReadBatch<'b, M>, Error>> {
        let batch = loop {
            std::task::ready!(self.poll_ready(cx))?;
            if self
                .core
                .in_ring()
                .can_read(&mut self.read.ptrs)
                .map_err(ErrorInner::Ring)?
            {
                break ReadBatch {
                    core: self.core,
                    read: self.read,
                };
            } else {
                self.read.clear_ready();
            }
        };
        Poll::Ready(Ok(batch))
    }

    /// Tries to get a reader for the next batch of packets.
    pub fn try_read_batch(&mut self) -> Result<ReadBatch<'_, M>, TryReadError> {
        if self
            .core
            .in_ring()
            .can_read(&mut self.read.ptrs)
            .map_err(|err| TryReadError::Queue(Error::from(ErrorInner::Ring(err))))?
        {
            Ok(ReadBatch {
                core: self.core,
                read: self.read,
            })
        } else {
            self.read.clear_ready();
            Err(TryReadError::Empty)
        }
    }

    /// Waits for the next batch of packets to be ready and returns a reader for
    /// them.
    ///
    /// This will manage interrupt masking as described in [`Self::poll_ready`].
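    ///
    /// # Example
    ///
    /// A hedged sketch of draining a batch; the ring pointers are committed
    /// (and the remote endpoint signaled, if needed) when the batch is
    /// dropped:
    ///
    /// ```ignore
    /// let mut batch = read_half.read_batch().await?;
    /// for packet in batch.packets() {
    ///     match packet? {
    ///         IncomingPacket::Data(data) => { /* process data.reader() */ }
    ///         IncomingPacket::Completion(c) => { /* match c.transaction_id() */ }
    ///     }
    /// }
    /// ```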
    pub fn read_batch<'b>(&'b mut self) -> BatchRead<'a, 'b, M> {
        BatchRead(Some(self))
    }

    /// Tries to read the next packet.
    ///
    /// Returns `Err(TryReadError::Empty)` if the ring is empty.
    pub fn try_read(&mut self) -> Result<PacketRef<'_, M>, TryReadError> {
        let batch = self.try_read_batch()?;
        batch
            .single_packet()
            .map_err(TryReadError::Queue)?
            .ok_or(TryReadError::Empty)
    }

    /// Waits for the next packet to be ready and returns it.
    pub fn read<'b>(&'b mut self) -> Read<'a, 'b, M> {
        Read(self.read_batch())
    }

    /// Indicates whether pending send size notification is supported on
    /// the vmbus ring.
    pub fn supports_pending_send_size(&self) -> bool {
        self.core.in_ring().supports_pending_send_size()
    }
}

/// An asynchronous batch read operation.
pub struct BatchRead<'a, 'b, M: RingMem>(Option<&'a mut ReadHalf<'b, M>>);

impl<'a, M: RingMem> Future for BatchRead<'a, '_, M> {
    type Output = Result<ReadBatch<'a, M>, Error>;

    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        // Rebuild the batch below to get the lifetimes right.
        let _ = std::task::ready!(this.0.as_mut().unwrap().poll_read_batch(cx))?;
        let this = this.0.take().unwrap();
        Poll::Ready(Ok(ReadBatch {
            core: this.core,
            read: this.read,
        }))
    }
}

/// An asynchronous read operation.
pub struct Read<'a, 'b, M: RingMem>(BatchRead<'a, 'b, M>);

impl<'a, M: RingMem> Future for Read<'a, '_, M> {
    type Output = Result<PacketRef<'a, M>, Error>;

    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let batch = std::task::ready!(self.0.poll_unpin(cx))?;
        Poll::Ready(
            batch
                .single_packet()
                .transpose()
                .expect("batch was non-empty"),
        )
    }
}

/// An outgoing packet.
pub struct OutgoingPacket<'a, 'b> {
    /// The transaction ID. Ignored for `packet_type` of [`OutgoingPacketType::InBandNoCompletion`].
    pub transaction_id: u64,
    /// The outgoing packet type.
    pub packet_type: OutgoingPacketType<'a>,
    /// The payload, as a list of byte slices.
    pub payload: &'b [&'b [u8]],
}

/// The writer for the outgoing ring buffer of a [`Queue`].
pub struct WriteHalf<'a, M: RingMem> {
    core: &'a Core<M>,
    write: &'a mut WriteState,
}

impl<'a, M: RingMem> WriteHalf<'a, M> {
    /// Polls the outgoing ring for the ability to send a packet of size
    /// `send_size`.
    ///
    /// `send_size` can be computed by calling `try_write` and extracting the
    /// size from `TryWriteError::Full(send_size)`.
    pub fn poll_ready(
        &mut self,
        cx: &mut Context<'_>,
        send_size: usize,
    ) -> Poll<Result<(), Error>> {
        loop {
            std::task::ready!(self.write.poll_ready(cx, self.core, send_size))
                .map_err(ErrorInner::from)?;
            if self.can_write(send_size)? {
                break Poll::Ready(Ok(()));
            }
        }
    }

    /// Waits until there is enough space in the ring to send a packet of size
    /// `send_size`.
    ///
    /// `send_size` can be computed by calling `try_write` and extracting the
    /// size from `TryWriteError::Full(send_size)`.
    pub async fn wait_ready(&mut self, send_size: usize) -> Result<(), Error> {
        poll_fn(|cx| self.poll_ready(cx, send_size)).await
    }

    /// Returns an object for writing multiple packets at once.
    ///
    /// The batch will be committed when the returned object is dropped.
    ///
    /// This reduces the overhead of writing multiple packets by updating the
    /// ring pointers and signaling an interrupt only once, when the batch is
    /// committed.
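    ///
    /// # Example
    ///
    /// A hedged sketch: write two packets but commit the ring pointers and
    /// signal at most once, when the batch goes out of scope.
    ///
    /// ```ignore
    /// {
    ///     let mut batch = write_half.batched();
    ///     for payload in [&b"first"[..], &b"second"[..]] {
    ///         batch.try_write(&OutgoingPacket {
    ///             transaction_id: 0,
    ///             packet_type: OutgoingPacketType::InBandNoCompletion,
    ///             payload: &[payload],
    ///         })?;
    ///     }
    /// } // dropped here: pointers committed, interrupt signaled if needed
    /// ```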
    pub fn batched(&mut self) -> WriteBatch<'_, M> {
        WriteBatch {
            core: self.core,
            write: self.write,
        }
    }

    /// Checks the outgoing ring for the capacity to send a packet of size
    /// `send_size`.
    pub fn can_write(&mut self, send_size: usize) -> Result<bool, Error> {
        self.batched().can_write(send_size)
    }

    /// The ring's capacity in bytes.
    pub fn capacity(&self) -> usize {
        self.core.out_ring().maximum_packet_size()
    }

    /// Tries to write a packet into the outgoing ring.
    ///
    /// Fails with `TryWriteError::Full(send_size)` if the ring is full.
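    ///
    /// # Example
    ///
    /// A hedged sketch of a write-or-wait loop built from `try_write` and
    /// [`Self::wait_ready`], assuming `packet` is an [`OutgoingPacket`] in
    /// scope:
    ///
    /// ```ignore
    /// loop {
    ///     match write_half.try_write(&packet) {
    ///         Ok(()) => break,
    ///         Err(TryWriteError::Full(send_size)) => {
    ///             // Wait until the ring has `send_size` bytes free, then retry.
    ///             write_half.wait_ready(send_size).await?;
    ///         }
    ///         Err(TryWriteError::Queue(err)) => return Err(err.into()),
    ///     }
    /// }
    /// ```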
    pub fn try_write(&mut self, packet: &OutgoingPacket<'_, '_>) -> Result<(), TryWriteError> {
        self.batched().try_write(packet)
    }

    /// Polls the ring for successful write of `packet`.
    pub fn poll_write(
        &mut self,
        cx: &mut Context<'_>,
        packet: &OutgoingPacket<'_, '_>,
    ) -> Poll<Result<(), Error>> {
        let mut send_size = 32;
        let r = loop {
            std::task::ready!(self.write.poll_ready(cx, self.core, send_size))
                .map_err(ErrorInner::from)?;
            match self.try_write(packet) {
                Ok(()) => break Ok(()),
                Err(TryWriteError::Full(len)) => send_size = len,
                Err(TryWriteError::Queue(err)) => break Err(err),
            }
        };
        Poll::Ready(r)
    }

    /// Writes a packet.
    pub fn write<'b, 'c>(&'b mut self, packet: OutgoingPacket<'c, 'b>) -> Write<'a, 'b, 'c, M> {
        Write {
            write: self,
            packet,
        }
    }
}

/// A batch writer, returned by [`WriteHalf::batched`].
pub struct WriteBatch<'a, M: RingMem> {
    core: &'a Core<M>,
    write: &'a mut WriteState,
}

impl<'a, M: RingMem> WriteBatch<'a, M> {
    /// Checks the outgoing ring for the capacity to send a packet of size
    /// `send_size`.
    pub fn can_write(&mut self, send_size: usize) -> Result<bool, Error> {
        let can_write = self
            .core
            .out_ring()
            .can_write(&mut self.write.ptrs, send_size)
            .map_err(ErrorInner::Ring)?;

        // Ensure that poll_write will check again.
        if !can_write {
            self.write.clear_ready();
        }
        Ok(can_write)
    }

    /// Tries to write a packet into the outgoing ring.
    ///
    /// Fails with `TryWriteError::Full(send_size)` if the ring is full.
    pub fn try_write(&mut self, packet: &OutgoingPacket<'_, '_>) -> Result<(), TryWriteError> {
        let size = packet.payload.iter().fold(0, |a, p| a + p.len());
        let ring_packet = ring::OutgoingPacket {
            transaction_id: packet.transaction_id,
            size,
            typ: packet.packet_type,
        };
        let mut builder = self.try_start_write(&ring_packet)?;
        let mut writer = builder.writer();
        for &p in packet.payload {
            writer
                .write(p)
                .map_err(|err| TryWriteError::Queue(ErrorInner::Access(err).into()))?;
        }
        builder.finish();
        Ok(())
    }

    /// Tries to write a packet with an aligned u64 payload into the outgoing
    /// ring.
    ///
    /// This is more efficient than `try_write` when the payload is a contiguous
    /// multiple of 8 bytes.
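    ///
    /// # Example
    ///
    /// A hedged sketch: send three u64 values (24 bytes of payload) without
    /// going through the slice-of-slices path.
    ///
    /// ```ignore
    /// batch.try_write_aligned(
    ///     0,
    ///     OutgoingPacketType::InBandNoCompletion,
    ///     &[1u64, 2, 3],
    /// )?;
    /// ```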
    pub fn try_write_aligned(
        &mut self,
        transaction_id: u64,
        packet_type: OutgoingPacketType<'_>,
        data: &[u64],
    ) -> Result<(), TryWriteError> {
        let size = data.len() * 8;
        let ring_packet = ring::OutgoingPacket {
            transaction_id,
            size,
            typ: packet_type,
        };
        let mut builder = self.try_start_write(&ring_packet)?;
        builder.write_aligned_full(data);
        builder.finish();
        Ok(())
    }

    fn try_start_write(
        &mut self,
        packet: &ring::OutgoingPacket<'_>,
    ) -> Result<WritePacketBuilder<'_, 'a, M>, TryWriteError> {
        let mut ptrs = self.write.ptrs.clone();
        match self.core.out_ring().write(&mut ptrs, packet) {
            Ok(range) => Ok(WritePacketBuilder {
                batch: self,
                range,
                ptrs,
            }),
            Err(ring::WriteError::Full(n)) => {
                self.write.clear_ready();
                Err(TryWriteError::Full(n))
            }
            Err(ring::WriteError::Corrupt(err)) => {
                Err(TryWriteError::Queue(ErrorInner::Ring(err).into()))
            }
        }
    }
}

struct WritePacketBuilder<'a, 'b, M: RingMem> {
    batch: &'a mut WriteBatch<'b, M>,
    range: ring::RingRange,
    ptrs: ring::OutgoingOffset,
}

impl<M: RingMem> WritePacketBuilder<'_, '_, M> {
    fn writer(&mut self) -> vmbus_ring::RingRangeWriter<'_, M> {
        self.range.writer(self.batch.core.out_ring())
    }

    fn write_aligned_full(&mut self, data: &[u64]) {
        self.range
            .write_aligned_full(self.batch.core.out_ring(), data)
    }

    fn finish(self) {
        self.batch.write.clear_poll(self.batch.core);
        self.batch.write.ptrs = self.ptrs;
    }
}

impl<M: RingMem> Drop for WriteBatch<'_, M> {
    fn drop(&mut self) {
        if self.core.out_ring().commit_write(&mut self.write.ptrs) {
            self.core.signal();
            self.write.signals.increment();
        }
    }
}

/// An asynchronous packet write operation.
#[must_use]
pub struct Write<'a, 'b, 'c, M: RingMem> {
    write: &'b mut WriteHalf<'a, M>,
    packet: OutgoingPacket<'c, 'b>,
}

impl<M: RingMem> Future for Write<'_, '_, '_, M> {
    type Output = Result<(), Error>;

    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.write.poll_write(cx, &this.packet)
    }
}

/// An abstraction over an open VmBus channel that provides methods to read and
/// write packets from the ring, as well as poll the ring for readiness.
///
/// This is useful when you need to operate on external data packets or send or
/// receive packets in batch. Otherwise, consider the `Channel` type.
pub struct Queue<M: RingMem> {
    core: Core<M>,
    read: ReadState,
    write: WriteState,
}

impl<M: RingMem> Inspect for Queue<M> {
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .merge(&self.core)
            .field("incoming_ring", &self.read)
            .field("outgoing_ring", &self.write);
    }
}

impl<M: RingMem> Queue<M> {
    /// Constructs a `Queue` over the given raw channel.
    pub fn new(raw: RawAsyncChannel<M>) -> Result<Self, Error> {
        let incoming = raw.in_ring.incoming().map_err(ErrorInner::Ring)?;
        let outgoing = raw.out_ring.outgoing().map_err(ErrorInner::Ring)?;
        let core = Core::new(raw);
        let read = ReadState::new(incoming);
        let write = WriteState::new(outgoing);

        Ok(Self { core, read, write })
    }

    /// Splits the queue into a read half and write half that can be operated on
    /// independently.
    pub fn split(&mut self) -> (ReadHalf<'_, M>, WriteHalf<'_, M>) {
        (
            ReadHalf {
                core: &self.core,
                read: &mut self.read,
            },
            WriteHalf {
                core: &self.core,
                write: &mut self.write,
            },
        )
    }
}

/// Returns a pair of connected queues. Useful for testing.
pub fn connected_queues(ring_size: usize) -> (Queue<FlatRingMem>, Queue<FlatRingMem>) {
    let (host, guest) = connected_async_channels(ring_size);
    (Queue::new(host).unwrap(), Queue::new(guest).unwrap())
}

#[cfg(test)]
mod tests {
    use super::*;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::timer::PolledTimer;
    use ring::OutgoingPacketType;
    use std::future::poll_fn;
    use std::time::Duration;
    use vmbus_channel::gpadl::GpadlId;
    use vmbus_channel::gpadl::GpadlMap;

    #[async_test]
    async fn test_gpa_direct() {
        use guestmem::ranges::PagedRange;

        let (mut host_queue, mut guest_queue) = connected_queues(16384);

        let gpa1: Vec<u64> = vec![4096, 8192];
        let gpa2: Vec<u64> = vec![8192];
        let gpas = vec![
            PagedRange::new(20, 4096, &gpa1).unwrap(),
            PagedRange::new(0, 200, &gpa2).unwrap(),
        ];

        let payload: &[u8] = &[0xf; 24];
        guest_queue
            .split()
            .1
            .write(OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::GpaDirect(&gpas),
                payload: &[payload],
            })
            .await
            .unwrap();
        host_queue
            .split()
            .0
            .read_batch()
            .await
            .unwrap()
            .packets()
            .next()
            .map(|p| match p.unwrap() {
                IncomingPacket::Data(data) => {
                    // Check the payload
                    let mut in_payload = [0_u8; 24];
                    assert_eq!(payload.len(), data.reader().len());
                    data.reader().read(&mut in_payload).unwrap();
                    assert_eq!(in_payload, payload);

                    // Check the external ranges
                    assert_eq!(data.external_range_count(), 2);
                    let mut external_data = MultiPagedRangeBuf::new();
                    data.read_external_ranges(&mut external_data).unwrap();
                    let in_gpas: Vec<PagedRange<'_>> = external_data.iter().collect();
                    assert_eq!(in_gpas.len(), gpas.len());

                    for (p, q) in in_gpas.iter().zip(gpas) {
                        assert_eq!(p.offset(), q.offset());
                        assert_eq!(p.len(), q.len());
                        assert_eq!(p.gpns(), q.gpns());
                    }
                    Ok(())
                }
                _ => Err("should be data"),
            })
            .unwrap()
            .unwrap();
    }

    #[async_test]
    async fn test_gpa_direct_empty_external_data() {
        use guestmem::ranges::PagedRange;

        let (mut host_queue, mut guest_queue) = connected_queues(16384);

        let gpa1: Vec<u64> = vec![];
        let gpas = vec![PagedRange::new(0, 0, &gpa1).unwrap()];

        let payload: &[u8] = &[0xf; 24];
        guest_queue
            .split()
            .1
            .write(OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::GpaDirect(&gpas),
                payload: &[payload],
            })
            .await
            .unwrap();
        host_queue
            .split()
            .0
            .read_batch()
            .await
            .unwrap()
            .packets()
            .next()
            .map(|p| match p.unwrap() {
                IncomingPacket::Data(data) => {
                    // Check the payload
                    let mut in_payload = [0_u8; 24];
                    assert_eq!(payload.len(), data.reader().len());
                    data.reader().read(&mut in_payload).unwrap();
                    assert_eq!(in_payload, payload);

                    // Check the external ranges
                    assert_eq!(data.external_range_count(), 1);
                    let mut external_data = MultiPagedRangeBuf::new();
                    let external_data_result = data.read_external_ranges(&mut external_data);
                    match external_data_result {
                        Err(ExternalDataError::GpaRange(_)) => Ok(()),
                        _ => Err("should be out of range"),
                    }
                }
                _ => Err("should be data"),
            })
            .unwrap()
            .unwrap();
    }

    #[async_test]
    async fn test_transfer_pages() {
        use guestmem::ranges::PagedRange;

        let (mut host_queue, mut guest_queue) = connected_queues(16384);

        let gpadl_map = GpadlMap::new();
        let buf = vec![0x3000_u64, 1, 2, 3];
        gpadl_map.add(
            GpadlId(13),
            MultiPagedRangeBuf::from_range_buffer(1, buf).unwrap(),
        );

        let ranges = vec![
            TransferPageRange {
                byte_count: 0x10,
                byte_offset: 0x10,
            },
            TransferPageRange {
                byte_count: 0x10,
                byte_offset: 0xfff,
            },
            TransferPageRange {
                byte_count: 0x10,
                byte_offset: 0x1000,
            },
        ];

        let payload: &[u8] = &[0xf; 24];
        guest_queue
            .split()
            .1
            .write(OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::TransferPages(13, &ranges),
                payload: &[payload],
            })
            .await
            .unwrap();
        host_queue
            .split()
            .0
            .read_batch()
            .await
            .unwrap()
            .packets()
            .next()
            .map(|p| match p.unwrap() {
                IncomingPacket::Data(data) => {
                    // Check the payload
                    let mut in_payload = [0_u8; 24];
                    assert_eq!(payload.len(), data.reader().len());
                    data.reader().read(&mut in_payload).unwrap();
                    assert_eq!(in_payload, payload);

                    // Check the external ranges
                    assert_eq!(data.external_range_count(), 3);
                    let gpadl_map_view = gpadl_map.view();
                    assert_eq!(data.transfer_buffer_id().unwrap(), 13);
                    let buffer_gpadl = gpadl_map_view.map(GpadlId(13)).unwrap();
                    let buffer_range = buffer_gpadl.first().unwrap();
                    let mut external_data = MultiPagedRangeBuf::new();
                    data.read_transfer_ranges(buffer_range, &mut external_data)
                        .unwrap();
                    let in_ranges: Vec<PagedRange<'_>> = external_data.iter().collect();
                    assert_eq!(in_ranges.len(), ranges.len());
                    assert_eq!(in_ranges[0].offset(), 0x10);
                    assert_eq!(in_ranges[0].len(), 0x10);
                    assert_eq!(in_ranges[0].gpns().len(), 1);
                    assert_eq!(in_ranges[0].gpns()[0], 1);

                    assert_eq!(in_ranges[1].offset(), 0xfff);
                    assert_eq!(in_ranges[1].len(), 0x10);
                    assert_eq!(in_ranges[1].gpns().len(), 2);
                    assert_eq!(in_ranges[1].gpns()[0], 1);
                    assert_eq!(in_ranges[1].gpns()[1], 2);

                    assert_eq!(in_ranges[2].offset(), 0);
                    assert_eq!(in_ranges[2].len(), 0x10);
                    assert_eq!(in_ranges[2].gpns().len(), 1);
                    assert_eq!(in_ranges[2].gpns()[0], 2);

                    Ok(())
                }
                _ => Err("should be data"),
            })
            .unwrap()
            .unwrap();
    }

    #[async_test]
    async fn test_ring_full(driver: DefaultDriver) {
        let (mut host_queue, mut guest_queue) = connected_queues(4096);

        assert!(
            poll_fn(|cx| host_queue.split().1.poll_ready(cx, 4000))
                .now_or_never()
                .is_some()
        );

        host_queue
            .split()
            .1
            .try_write(&OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::InBandNoCompletion,
                payload: &[&[0u8; 4000]],
            })
            .unwrap();

        let n = match host_queue
            .split()
            .1
            .try_write(&OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::InBandNoCompletion,
                payload: &[&[0u8; 4000]],
            })
            .unwrap_err()
        {
            TryWriteError::Full(n) => n,
            _ => unreachable!(),
        };

        let mut poll = async move {
            let mut host_queue = host_queue;
            poll_fn(|cx| host_queue.split().1.poll_ready(cx, n))
                .await
                .unwrap();
            host_queue
        }
        .boxed();

        assert!(futures::poll!(&mut poll).is_pending());
        let poll = driver.spawn("test", poll);

        PolledTimer::new(&driver)
            .sleep(Duration::from_millis(50))
            .await;

        guest_queue.split().0.read().await.unwrap();
        assert!(guest_queue.split().0.try_read().is_err());

        let mut host_queue = poll.await;

        host_queue
            .split()
            .1
            .try_write(&OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::InBandNoCompletion,
                payload: &[&[0u8; 4000]],
            })
            .unwrap();
    }
}
1071}