1use super::core::Core;
8use super::core::ReadState;
9use super::core::WriteState;
10use crate::core::PollError;
11use futures::FutureExt;
12use guestmem::AccessError;
13use guestmem::MemoryRead;
14use guestmem::MemoryWrite;
15use guestmem::ranges::PagedRange;
16use inspect::Inspect;
17use ring::OutgoingPacketType;
18use ring::TransferPageRange;
19use smallvec::smallvec;
20use std::future::Future;
21use std::future::poll_fn;
22use std::ops::Deref;
23use std::task::Context;
24use std::task::Poll;
25use std::task::ready;
26use thiserror::Error;
27use vmbus_channel::RawAsyncChannel;
28use vmbus_channel::connected_async_channels;
29use vmbus_ring as ring;
30use vmbus_ring::FlatRingMem;
31use vmbus_ring::IncomingPacketType;
32use vmbus_ring::IncomingRing;
33use vmbus_ring::RingMem;
34use vmbus_ring::gparange::GpnList;
35use vmbus_ring::gparange::MultiPagedRangeBuf;
36use vmbus_ring::gparange::zeroed_gpn_list;
37use zerocopy::FromBytes;
38use zerocopy::FromZeros;
39use zerocopy::IntoBytes;
40
/// An error returned by queue operations.
///
/// The inner error is boxed so that `Error` (and every `Result` carrying it)
/// stays a single pointer wide.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Error(Box<ErrorInner>);
45
46impl From<ErrorInner> for Error {
47 fn from(value: ErrorInner) -> Self {
48 Self(Box::new(value))
49 }
50}
51
impl Error {
    /// Returns true if the error was caused by the channel being closed,
    /// as opposed to a ring corruption or a guest memory access failure.
    pub fn is_closed_error(&self) -> bool {
        matches!(self.0.as_ref(), ErrorInner::ChannelClosed)
    }
}
59
/// The internal error cases behind [`Error`].
#[derive(Debug, Error)]
enum ErrorInner {
    /// Accessing the guest memory backing the ring failed.
    #[error("guest memory access error")]
    Access(#[source] AccessError),
    /// The ring buffer state is invalid.
    #[error("ring buffer error")]
    Ring(#[source] ring::Error),
    /// The channel has been closed.
    #[error("the channel has been closed")]
    ChannelClosed,
}
72
73impl From<PollError> for ErrorInner {
74 fn from(value: PollError) -> Self {
75 match value {
76 PollError::Ring(ring) => Self::Ring(ring),
77 PollError::Closed => Self::ChannelClosed,
78 }
79 }
80}
81
/// An error returned by the non-blocking read operations when no packet is
/// immediately available or the read fails.
#[derive(Debug, Error)]
pub enum TryReadError {
    /// No packets are currently available in the incoming ring.
    #[error("ring is empty")]
    Empty,
    /// The read failed.
    #[error("queue error")]
    Queue(#[source] Error),
}
92
93#[derive(Debug, Error)]
95pub enum TryWriteError {
96 #[error("ring is empty")]
98 Full(usize),
99 #[error("queue error")]
101 Queue(#[source] Error),
102}
103
/// An error returned when reading a data packet's external GPA ranges.
#[derive(Debug, Error)]
pub enum ExternalDataError {
    /// The packet's GPA range descriptors failed validation.
    #[error("invalid gpa ranges")]
    GpaRange(#[source] vmbus_ring::gparange::Error),

    /// The descriptors could not be read from the ring's backing memory.
    #[error("access error")]
    Access(#[source] AccessError),

    /// The packet is a transfer-page packet, so its ranges must be resolved
    /// against a transfer buffer via `read_transfer_ranges` instead.
    #[error("external data should have been read by calling read_transfer_ranges")]
    WrongExternalDataType,
}
120
/// A batch of incoming packets. The consumed ring space is committed (and
/// the opposite endpoint signaled, if necessary) when the batch is dropped.
pub struct ReadBatch<'a, M: RingMem> {
    core: &'a Core<M>,
    read: &'a mut ReadState,
}
126
/// An iterator over the packets of a [`ReadBatch`].
pub struct ReadBatchIter<'a, 'b, M: RingMem>(&'a mut ReadBatch<'b, M>);
129
impl<'a, M: RingMem> ReadBatch<'a, M> {
    /// Parses the next packet from the ring, advancing the batch's read
    /// pointers only if a packet is successfully read and parsed.
    fn next_priv(&mut self) -> Result<Option<IncomingPacket<'a, M>>, Error> {
        // Operate on a copy of the pointers so that a failed or empty read
        // leaves the committed read state untouched.
        let mut ptrs = self.read.ptrs.clone();
        match self.core.in_ring().read(&mut ptrs) {
            Ok(packet) => {
                let packet = IncomingPacket::parse(self.core.in_ring(), packet)?;
                self.read.ptrs = ptrs;
                Ok(Some(packet))
            }
            // No more packets are available right now.
            Err(ring::ReadError::Empty) => Ok(None),
            Err(ring::ReadError::Corrupt(err)) => Err(ErrorInner::Ring(err).into()),
        }
    }

    /// Consumes the batch, returning at most one packet wrapped in a
    /// [`PacketRef`], which commits the packet's ring space on drop.
    fn single_packet(mut self) -> Result<Option<PacketRef<'a, M>>, Error> {
        if let Some(packet) = self.next_priv()? {
            Ok(Some(PacketRef {
                batch: self,
                packet,
            }))
        } else {
            Ok(None)
        }
    }

    /// Returns an iterator over the batch's packets.
    pub fn packets(&mut self) -> ReadBatchIter<'_, 'a, M> {
        ReadBatchIter(self)
    }
}
160
161impl<'a, M: RingMem> Iterator for ReadBatchIter<'a, '_, M> {
162 type Item = Result<IncomingPacket<'a, M>, Error>;
163
164 fn next(&mut self) -> Option<Self::Item> {
165 self.0.next_priv().transpose()
166 }
167}
168
impl<M: RingMem> Drop for ReadBatch<'_, M> {
    fn drop(&mut self) {
        self.read.clear_poll(self.core);
        // Commit the consumed ring space. When `commit_read` reports that a
        // notification is needed, signal the opposite endpoint and count it.
        if self.core.in_ring().commit_read(&mut self.read.ptrs) {
            self.core.signal();
            self.read.signals.increment();
        }
    }
}
178
/// A reference to a single incoming packet, backed by a one-packet
/// [`ReadBatch`]. Dropping it commits the packet's ring space unless
/// [`PacketRef::revert`] was called.
pub struct PacketRef<'a, M: RingMem> {
    batch: ReadBatch<'a, M>,
    packet: IncomingPacket<'a, M>,
}
184
impl<'a, M: RingMem> Deref for PacketRef<'a, M> {
    type Target = IncomingPacket<'a, M>;

    /// Dereferences to the wrapped incoming packet.
    fn deref(&self) -> &Self::Target {
        &self.packet
    }
}
192
impl<'a, M: RingMem> AsRef<IncomingPacket<'a, M>> for PacketRef<'a, M> {
    /// Borrows the wrapped incoming packet (via `Deref`).
    fn as_ref(&self) -> &IncomingPacket<'a, M> {
        self
    }
}
198
impl<M: RingMem> PacketRef<'_, M> {
    /// Reverts the batch's read pointers so that this packet is not consumed
    /// when the reference is dropped; a subsequent read will see it again.
    pub fn revert(&mut self) {
        self.batch.read.ptrs.revert();
    }
}
209
/// An incoming packet.
pub enum IncomingPacket<'a, T: RingMem> {
    /// A data packet, possibly referencing external guest memory.
    Data(DataPacket<'a, T>),
    /// A completion for a previously written packet.
    Completion(CompletionPacket<'a, T>),
}
217
/// An incoming data packet.
pub struct DataPacket<'a, T: RingMem> {
    ring: &'a IncomingRing<T>,
    /// Ring region holding the in-band payload.
    payload: ring::RingRange,
    /// Set when the sender requested a completion.
    transaction_id: Option<u64>,
    /// Set for transfer-page packets; identifies the transfer buffer.
    buffer_id: Option<u16>,
    /// (declared range count, ring region holding the range descriptors).
    external_data: (u32, ring::RingRange),
}
226
227impl<T: RingMem> DataPacket<'_, T> {
228 pub fn reader(&self) -> impl MemoryRead + '_ {
234 self.payload.reader(self.ring)
235 }
236
237 pub fn transaction_id(&self) -> Option<u64> {
240 self.transaction_id
241 }
242
243 pub fn external_range_count(&self) -> usize {
245 self.external_data.0 as usize
246 }
247
248 fn read_transfer_page_ranges(
249 &self,
250 transfer_buf: &MultiPagedRangeBuf<GpnList>,
251 ) -> Result<MultiPagedRangeBuf<GpnList>, AccessError> {
252 let len = self.external_data.0 as usize;
253 let mut reader = self.external_data.1.reader(self.ring);
254 let available_count = reader.len() / size_of::<TransferPageRange>();
255 if available_count < len {
256 return Err(AccessError::OutOfRange(0, 0));
257 }
258
259 let mut buf: GpnList = smallvec![FromZeros::new_zeroed(); len];
260 reader.read(buf.as_mut_bytes())?;
261
262 let transfer_buf: GpnList = buf
265 .iter()
266 .map(|range| {
267 let range_data = TransferPageRange::read_from_prefix(range.as_bytes())
268 .unwrap()
269 .0; let sub_range = transfer_buf
271 .subrange(
272 range_data.byte_offset as usize,
273 range_data.byte_count as usize,
274 )
275 .map_err(|_| {
276 AccessError::OutOfRange(
277 range_data.byte_offset as usize,
278 range_data.byte_count as usize,
279 )
280 })?;
281 Ok(sub_range.into_buffer())
282 })
283 .collect::<Result<Vec<GpnList>, AccessError>>()?
284 .into_iter()
285 .flatten()
286 .collect();
287 Ok(MultiPagedRangeBuf::new(len, transfer_buf).unwrap())
288 }
289
290 pub fn read_external_ranges(&self) -> Result<MultiPagedRangeBuf<GpnList>, ExternalDataError> {
292 if self.buffer_id.is_some() {
293 return Err(ExternalDataError::WrongExternalDataType);
294 } else if self.external_data.0 == 0 {
295 return Ok(MultiPagedRangeBuf::empty());
296 }
297
298 let mut reader = self.external_data.1.reader(self.ring);
299 let len = reader.len() / 8;
300 let mut buf = zeroed_gpn_list(len);
301 reader
302 .read(buf.as_mut_bytes())
303 .map_err(ExternalDataError::Access)?;
304 MultiPagedRangeBuf::new(self.external_data.0 as usize, buf)
305 .map_err(ExternalDataError::GpaRange)
306 }
307
308 pub fn transfer_buffer_id(&self) -> Option<u16> {
310 self.buffer_id
311 }
312
313 pub fn read_transfer_ranges<'a, I>(
316 &self,
317 transfer_buf: I,
318 ) -> Result<MultiPagedRangeBuf<GpnList>, AccessError>
319 where
320 I: Iterator<Item = PagedRange<'a>>,
321 {
322 if self.external_data.0 == 0 {
323 return Ok(MultiPagedRangeBuf::empty());
324 }
325
326 let buf: MultiPagedRangeBuf<GpnList> = transfer_buf.collect();
327 self.read_transfer_page_ranges(&buf)
328 }
329}
330
/// An incoming completion packet.
pub struct CompletionPacket<'a, T: RingMem> {
    ring: &'a IncomingRing<T>,
    /// Ring region holding the completion payload.
    payload: ring::RingRange,
    transaction_id: u64,
}
337
impl<T: RingMem> CompletionPacket<'_, T> {
    /// Returns a reader over the completion payload.
    pub fn reader(&self) -> impl MemoryRead + '_ {
        self.payload.reader(self.ring)
    }

    /// The transaction ID of the packet this completes.
    pub fn transaction_id(&self) -> u64 {
        self.transaction_id
    }
}
349
impl<'a, T: RingMem> IncomingPacket<'a, T> {
    /// Converts a raw ring packet into a typed incoming packet.
    fn parse(ring: &'a IncomingRing<T>, packet: ring::IncomingPacket) -> Result<Self, Error> {
        Ok(match packet.typ {
            // Plain in-band packet: no external data.
            IncomingPacketType::InBand => IncomingPacket::Data(DataPacket {
                ring,
                payload: packet.payload,
                transaction_id: packet.transaction_id,
                buffer_id: None,
                external_data: (0, ring::RingRange::empty()),
            }),
            // GPA direct packet: carries its own GPA range descriptors.
            IncomingPacketType::GpaDirect(count, ranges) => IncomingPacket::Data(DataPacket {
                ring,
                payload: packet.payload,
                transaction_id: packet.transaction_id,
                buffer_id: None,
                external_data: (count, ranges),
            }),
            // Completion packet. The unwrap assumes the ring layer always
            // provides a transaction ID for completions — TODO(review):
            // confirm that invariant in the ring parser.
            IncomingPacketType::Completion => IncomingPacket::Completion(CompletionPacket {
                ring,
                payload: packet.payload,
                transaction_id: packet.transaction_id.unwrap(),
            }),
            // Transfer-page packet: references a transfer buffer by ID.
            IncomingPacketType::TransferPages(id, count, ranges) => {
                IncomingPacket::Data(DataPacket {
                    ring,
                    payload: packet.payload,
                    transaction_id: packet.transaction_id,
                    buffer_id: Some(id),
                    external_data: (count, ranges),
                })
            }
        })
    }
}
384
/// The read half of a queue, returned by [`Queue::split`].
pub struct ReadHalf<'a, M: RingMem> {
    core: &'a Core<M>,
    read: &'a mut ReadState,
}
390
impl<'a, M: RingMem> ReadHalf<'a, M> {
    /// Polls the incoming ring for readiness. Readiness does not guarantee
    /// a packet is available; callers must recheck the ring.
    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        ready!(self.read.poll_ready(cx, self.core)).map_err(ErrorInner::from)?;
        Poll::Ready(Ok(()))
    }

    /// Polls for a batch of incoming packets, looping until the ring
    /// actually has something to read.
    pub fn poll_read_batch<'b>(
        &'b mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<ReadBatch<'b, M>, Error>> {
        let batch = loop {
            std::task::ready!(self.poll_ready(cx))?;
            if self
                .core
                .in_ring()
                .can_read(&mut self.read.ptrs)
                .map_err(ErrorInner::Ring)?
            {
                break ReadBatch {
                    core: self.core,
                    read: self.read,
                };
            } else {
                // Spurious wakeup: clear readiness and wait again.
                self.read.clear_ready();
            }
        };
        Poll::Ready(Ok(batch))
    }

    /// Returns a batch of packets if any are immediately available,
    /// failing with [`TryReadError::Empty`] otherwise.
    pub fn try_read_batch(&mut self) -> Result<ReadBatch<'_, M>, TryReadError> {
        if self
            .core
            .in_ring()
            .can_read(&mut self.read.ptrs)
            .map_err(|err| TryReadError::Queue(Error::from(ErrorInner::Ring(err))))?
        {
            Ok(ReadBatch {
                core: self.core,
                read: self.read,
            })
        } else {
            self.read.clear_ready();
            Err(TryReadError::Empty)
        }
    }

    /// Returns a future that resolves to a non-empty batch of packets.
    pub fn read_batch<'b>(&'b mut self) -> BatchRead<'a, 'b, M> {
        BatchRead(Some(self))
    }

    /// Reads a single packet without waiting, failing with
    /// [`TryReadError::Empty`] if none is available.
    pub fn try_read(&mut self) -> Result<PacketRef<'_, M>, TryReadError> {
        let batch = self.try_read_batch()?;
        batch
            .single_packet()
            .map_err(TryReadError::Queue)?
            .ok_or(TryReadError::Empty)
    }

    /// Returns a future that resolves to a single packet.
    pub fn read<'b>(&'b mut self) -> Read<'a, 'b, M> {
        Read(self.read_batch())
    }

    /// Reports whether the incoming ring supports the pending-send-size
    /// mechanism.
    pub fn supports_pending_send_size(&self) -> bool {
        self.core.in_ring().supports_pending_send_size()
    }
}
478
/// A future for reading a batch of packets, returned by
/// [`ReadHalf::read_batch`]. The `Option` is taken on completion.
pub struct BatchRead<'a, 'b, M: RingMem>(Option<&'a mut ReadHalf<'b, M>>);
481
impl<'a, M: RingMem> Future for BatchRead<'a, '_, M> {
    type Output = Result<ReadBatch<'a, M>, Error>;

    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        // Poll through the borrowed half first; the batch returned here only
        // borrows for the duration of the call, so discard it, take the half
        // out of the Option, and rebuild a batch with the full 'a lifetime.
        let _ = std::task::ready!(this.0.as_mut().unwrap().poll_read_batch(cx))?;
        let this = this.0.take().unwrap();
        Poll::Ready(Ok(ReadBatch {
            core: this.core,
            read: this.read,
        }))
    }
}
496
/// A future for reading a single packet, returned by [`ReadHalf::read`].
pub struct Read<'a, 'b, M: RingMem>(BatchRead<'a, 'b, M>);
499
impl<'a, M: RingMem> Future for Read<'a, '_, M> {
    type Output = Result<PacketRef<'a, M>, Error>;

    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let batch = std::task::ready!(self.0.poll_unpin(cx))?;
        Poll::Ready(
            // poll_read_batch only completes when the ring has data, so the
            // batch contains at least one packet.
            batch
                .single_packet()
                .transpose()
                .expect("batch was non-empty"),
        )
    }
}
513
/// An outgoing packet.
pub struct OutgoingPacket<'a, 'b> {
    /// The transaction ID.
    pub transaction_id: u64,
    /// The packet type.
    pub packet_type: OutgoingPacketType<'a>,
    /// The payload, as a list of byte slices written back to back.
    pub payload: &'b [&'b [u8]],
}
523
/// The write half of a queue, returned by [`Queue::split`].
pub struct WriteHalf<'a, M: RingMem> {
    core: &'a Core<M>,
    write: &'a mut WriteState,
}
529
impl<'a, M: RingMem> WriteHalf<'a, M> {
    /// Polls the outgoing ring until at least `send_size` bytes of ring
    /// space are available.
    pub fn poll_ready(
        &mut self,
        cx: &mut Context<'_>,
        send_size: usize,
    ) -> Poll<Result<(), Error>> {
        loop {
            std::task::ready!(self.write.poll_ready(cx, self.core, send_size))
                .map_err(ErrorInner::from)?;
            // A wakeup does not guarantee space; recheck and loop.
            if self.can_write(send_size)? {
                break Poll::Ready(Ok(()));
            }
        }
    }

    /// Waits until `send_size` bytes of ring space are available.
    pub async fn wait_ready(&mut self, send_size: usize) -> Result<(), Error> {
        poll_fn(|cx| self.poll_ready(cx, send_size)).await
    }

    /// Returns a batch writer. Written packets are committed (and the
    /// opposite endpoint signaled, if necessary) when the batch is dropped.
    pub fn batched(&mut self) -> WriteBatch<'_, M> {
        WriteBatch {
            core: self.core,
            write: self.write,
        }
    }

    /// Returns whether `send_size` bytes can currently be written.
    pub fn can_write(&mut self, send_size: usize) -> Result<bool, Error> {
        self.batched().can_write(send_size)
    }

    /// The maximum packet size the outgoing ring can hold.
    pub fn capacity(&self) -> usize {
        self.core.out_ring().maximum_packet_size()
    }

    /// Writes `packet` without waiting, failing with
    /// [`TryWriteError::Full`] if there is insufficient ring space.
    pub fn try_write(&mut self, packet: &OutgoingPacket<'_, '_>) -> Result<(), TryWriteError> {
        self.batched().try_write(packet)
    }

    /// Polls to write `packet`, waiting for ring space as needed.
    pub fn poll_write(
        &mut self,
        cx: &mut Context<'_>,
        packet: &OutgoingPacket<'_, '_>,
    ) -> Poll<Result<(), Error>> {
        // Start with a small size guess; if the write reports Full, retry
        // waiting for the exact size the ring said it needs.
        let mut send_size = 32;
        let r = loop {
            std::task::ready!(self.write.poll_ready(cx, self.core, send_size))
                .map_err(ErrorInner::from)?;
            match self.try_write(packet) {
                Ok(()) => break Ok(()),
                Err(TryWriteError::Full(len)) => send_size = len,
                Err(TryWriteError::Queue(err)) => break Err(err),
            }
        };
        Poll::Ready(r)
    }

    /// Returns a future that writes `packet`, waiting for ring space.
    pub fn write<'b, 'c>(&'b mut self, packet: OutgoingPacket<'c, 'b>) -> Write<'a, 'b, 'c, M> {
        Write {
            write: self,
            packet,
        }
    }
}
618
/// A batch writer for outgoing packets. Written packets are committed (and
/// the opposite endpoint signaled, if necessary) when the batch is dropped.
pub struct WriteBatch<'a, M: RingMem> {
    core: &'a Core<M>,
    write: &'a mut WriteState,
}
624
impl<M: RingMem> WriteBatch<'_, M> {
    /// Returns whether `send_size` bytes can currently be written, clearing
    /// readiness when they cannot so the next poll waits for space.
    pub fn can_write(&mut self, send_size: usize) -> Result<bool, Error> {
        let can_write = self
            .core
            .out_ring()
            .can_write(&mut self.write.ptrs, send_size)
            .map_err(ErrorInner::Ring)?;

        if !can_write {
            self.write.clear_ready();
        }
        Ok(can_write)
    }

    /// Writes a packet to the ring, failing with `TryWriteError::Full(n)`
    /// (where `n` is the needed send size) if there is not enough space.
    pub fn try_write(&mut self, packet: &OutgoingPacket<'_, '_>) -> Result<(), TryWriteError> {
        // Total size is the sum of all payload slices.
        let size = packet.payload.iter().fold(0, |a, p| a + p.len());
        let ring_packet = ring::OutgoingPacket {
            transaction_id: packet.transaction_id,
            size,
            typ: packet.packet_type,
        };
        // Write through a copy of the pointers so a failed write leaves the
        // committed write state unchanged.
        let mut ptrs = self.write.ptrs.clone();
        match self.core.out_ring().write(&mut ptrs, &ring_packet) {
            Ok(range) => {
                let mut writer = range.writer(self.core.out_ring());
                for p in packet.payload.iter().copied() {
                    writer.write(p).map_err(|err| {
                        TryWriteError::Queue(Error::from(ErrorInner::Access(err)))
                    })?;
                }
                self.write.clear_poll(self.core);
                self.write.ptrs = ptrs;
                Ok(())
            }
            Err(ring::WriteError::Full(n)) => {
                self.write.clear_ready();
                Err(TryWriteError::Full(n))
            }
            Err(ring::WriteError::Corrupt(err)) => {
                Err(TryWriteError::Queue(ErrorInner::Ring(err).into()))
            }
        }
    }
}
675
impl<M: RingMem> Drop for WriteBatch<'_, M> {
    fn drop(&mut self) {
        // Commit the written packets. When `commit_write` reports that a
        // notification is needed, signal the opposite endpoint and count it.
        if self.core.out_ring().commit_write(&mut self.write.ptrs) {
            self.core.signal();
            self.write.signals.increment();
        }
    }
}
684
/// A future for writing a packet, returned by [`WriteHalf::write`].
#[must_use]
pub struct Write<'a, 'b, 'c, M: RingMem> {
    write: &'b mut WriteHalf<'a, M>,
    packet: OutgoingPacket<'c, 'b>,
}
691
692impl<M: RingMem> Future for Write<'_, '_, '_, M> {
693 type Output = Result<(), Error>;
694
695 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
696 let this = self.get_mut();
697 this.write.poll_write(cx, &this.packet)
698 }
699}
700
/// A queue over an open vmbus channel, providing packet read and write
/// access to the incoming and outgoing rings.
pub struct Queue<M: RingMem> {
    core: Core<M>,
    read: ReadState,
    write: WriteState,
}
711
impl<M: RingMem> Inspect for Queue<M> {
    /// Exposes the queue's core plus per-ring state for diagnostics.
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .merge(&self.core)
            .field("incoming_ring", &self.read)
            .field("outgoing_ring", &self.write);
    }
}
720
721impl<M: RingMem> Queue<M> {
722 pub fn new(raw: RawAsyncChannel<M>) -> Result<Self, Error> {
725 let incoming = raw.in_ring.incoming().map_err(ErrorInner::Ring)?;
726 let outgoing = raw.out_ring.outgoing().map_err(ErrorInner::Ring)?;
727 let core = Core::new(raw);
728 let read = ReadState::new(incoming);
729 let write = WriteState::new(outgoing);
730
731 Ok(Self { core, read, write })
732 }
733
734 pub fn split(&mut self) -> (ReadHalf<'_, M>, WriteHalf<'_, M>) {
737 (
738 ReadHalf {
739 core: &self.core,
740 read: &mut self.read,
741 },
742 WriteHalf {
743 core: &self.core,
744 write: &mut self.write,
745 },
746 )
747 }
748}
749
750pub fn connected_queues(ring_size: usize) -> (Queue<FlatRingMem>, Queue<FlatRingMem>) {
752 let (host, guest) = connected_async_channels(ring_size);
753 (Queue::new(host).unwrap(), Queue::new(guest).unwrap())
754}
755
#[cfg(test)]
mod tests {
    use super::*;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::timer::PolledTimer;
    use ring::OutgoingPacketType;
    use std::future::poll_fn;
    use std::time::Duration;
    use vmbus_channel::gpadl::GpadlId;
    use vmbus_channel::gpadl::GpadlMap;

    /// A GPA direct packet round-trips its payload and its GPA ranges.
    #[async_test]
    async fn test_gpa_direct() {
        use guestmem::ranges::PagedRange;

        let (mut host_queue, mut guest_queue) = connected_queues(16384);

        let gpa1: Vec<u64> = vec![4096, 8192];
        let gpa2: Vec<u64> = vec![8192];
        let gpas = vec![
            PagedRange::new(20, 4096, &gpa1).unwrap(),
            PagedRange::new(0, 200, &gpa2).unwrap(),
        ];

        let payload: &[u8] = &[0xf; 24];
        guest_queue
            .split()
            .1
            .write(OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::GpaDirect(&gpas),
                payload: &[payload],
            })
            .await
            .unwrap();
        host_queue
            .split()
            .0
            .read_batch()
            .await
            .unwrap()
            .packets()
            .next()
            .map(|p| match p.unwrap() {
                IncomingPacket::Data(data) => {
                    // The in-band payload must arrive intact.
                    let mut in_payload = [0_u8; 24];
                    assert_eq!(payload.len(), data.reader().len());
                    data.reader().read(&mut in_payload).unwrap();
                    assert_eq!(in_payload, payload);

                    // The external ranges must match what was sent.
                    assert_eq!(data.external_range_count(), 2);
                    let external_data = data.read_external_ranges().unwrap();
                    let in_gpas: Vec<PagedRange<'_>> = external_data.iter().collect();
                    assert_eq!(in_gpas.len(), gpas.len());

                    for (p, q) in in_gpas.iter().zip(gpas) {
                        assert_eq!(p.offset(), q.offset());
                        assert_eq!(p.len(), q.len());
                        assert_eq!(p.gpns(), q.gpns());
                    }
                    Ok(())
                }
                _ => Err("should be data"),
            })
            .unwrap()
            .unwrap();
    }

    /// A GPA direct packet with an empty range fails range validation on
    /// the receiving side.
    #[async_test]
    async fn test_gpa_direct_empty_external_data() {
        use guestmem::ranges::PagedRange;

        let (mut host_queue, mut guest_queue) = connected_queues(16384);

        let gpa1: Vec<u64> = vec![];
        let gpas = vec![PagedRange::new(0, 0, &gpa1).unwrap()];

        let payload: &[u8] = &[0xf; 24];
        guest_queue
            .split()
            .1
            .write(OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::GpaDirect(&gpas),
                payload: &[payload],
            })
            .await
            .unwrap();
        host_queue
            .split()
            .0
            .read_batch()
            .await
            .unwrap()
            .packets()
            .next()
            .map(|p| match p.unwrap() {
                IncomingPacket::Data(data) => {
                    let mut in_payload = [0_u8; 24];
                    assert_eq!(payload.len(), data.reader().len());
                    data.reader().read(&mut in_payload).unwrap();
                    assert_eq!(in_payload, payload);

                    assert_eq!(data.external_range_count(), 1);
                    // The match below both asserts the failure and checks the
                    // specific error variant (previously this also made a
                    // redundant second call just to assert `is_err()`).
                    match data.read_external_ranges() {
                        Err(ExternalDataError::GpaRange(_)) => Ok(()),
                        _ => Err("should be out of range"),
                    }
                }
                _ => Err("should be data"),
            })
            .unwrap()
            .unwrap();
    }

    /// Transfer-page ranges resolve correctly against a GPADL-mapped
    /// transfer buffer, including ranges that straddle a page boundary.
    #[async_test]
    async fn test_transfer_pages() {
        use guestmem::ranges::PagedRange;

        let (mut host_queue, mut guest_queue) = connected_queues(16384);

        let gpadl_map = GpadlMap::new();
        let buf = vec![0x3000_u64, 1, 2, 3];
        gpadl_map.add(GpadlId(13), MultiPagedRangeBuf::new(1, buf).unwrap());

        let ranges = vec![
            TransferPageRange {
                byte_count: 0x10,
                byte_offset: 0x10,
            },
            // Straddles the first/second page boundary.
            TransferPageRange {
                byte_count: 0x10,
                byte_offset: 0xfff,
            },
            TransferPageRange {
                byte_count: 0x10,
                byte_offset: 0x1000,
            },
        ];

        let payload: &[u8] = &[0xf; 24];
        guest_queue
            .split()
            .1
            .write(OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::TransferPages(13, &ranges),
                payload: &[payload],
            })
            .await
            .unwrap();
        host_queue
            .split()
            .0
            .read_batch()
            .await
            .unwrap()
            .packets()
            .next()
            .map(|p| match p.unwrap() {
                IncomingPacket::Data(data) => {
                    let mut in_payload = [0_u8; 24];
                    assert_eq!(payload.len(), data.reader().len());
                    data.reader().read(&mut in_payload).unwrap();
                    assert_eq!(in_payload, payload);

                    assert_eq!(data.external_range_count(), 3);
                    let gpadl_map_view = gpadl_map.view();
                    assert_eq!(data.transfer_buffer_id().unwrap(), 13);
                    let buffer_range = gpadl_map_view.map(GpadlId(13)).unwrap();
                    let external_data = data.read_transfer_ranges(buffer_range.iter()).unwrap();
                    let in_ranges: Vec<PagedRange<'_>> = external_data.iter().collect();
                    assert_eq!(in_ranges.len(), ranges.len());
                    assert_eq!(in_ranges[0].offset(), 0x10);
                    assert_eq!(in_ranges[0].len(), 0x10);
                    assert_eq!(in_ranges[0].gpns().len(), 1);
                    assert_eq!(in_ranges[0].gpns()[0], 1);

                    // The straddling range spans two pages.
                    assert_eq!(in_ranges[1].offset(), 0xfff);
                    assert_eq!(in_ranges[1].len(), 0x10);
                    assert_eq!(in_ranges[1].gpns().len(), 2);
                    assert_eq!(in_ranges[1].gpns()[0], 1);
                    assert_eq!(in_ranges[1].gpns()[1], 2);

                    assert_eq!(in_ranges[2].offset(), 0);
                    assert_eq!(in_ranges[2].len(), 0x10);
                    assert_eq!(in_ranges[2].gpns().len(), 1);
                    assert_eq!(in_ranges[2].gpns()[0], 2);

                    Ok(())
                }
                _ => Err("should be data"),
            })
            .unwrap()
            .unwrap();
    }

    /// Writes block when the ring is full and resume once the reader
    /// drains a packet.
    #[async_test]
    async fn test_ring_full(driver: DefaultDriver) {
        let (mut host_queue, mut guest_queue) = connected_queues(4096);

        assert!(
            poll_fn(|cx| host_queue.split().1.poll_ready(cx, 4000))
                .now_or_never()
                .is_some()
        );

        host_queue
            .split()
            .1
            .try_write(&OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::InBandNoCompletion,
                payload: &[&[0u8; 4000]],
            })
            .unwrap();

        // The second identical write must fail and report the needed size.
        let n = match host_queue
            .split()
            .1
            .try_write(&OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::InBandNoCompletion,
                payload: &[&[0u8; 4000]],
            })
            .unwrap_err()
        {
            TryWriteError::Full(n) => n,
            _ => unreachable!(),
        };

        let mut poll = async move {
            let mut host_queue = host_queue;
            poll_fn(|cx| host_queue.split().1.poll_ready(cx, n))
                .await
                .unwrap();
            host_queue
        }
        .boxed();

        // The write side must stay pending until the reader makes room.
        assert!(futures::poll!(&mut poll).is_pending());
        let poll = driver.spawn("test", poll);

        PolledTimer::new(&driver)
            .sleep(Duration::from_millis(50))
            .await;

        guest_queue.split().0.read().await.unwrap();
        assert!(guest_queue.split().0.try_read().is_err());

        let mut host_queue = poll.await;

        host_queue
            .split()
            .1
            .try_write(&OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::InBandNoCompletion,
                payload: &[&[0u8; 4000]],
            })
            .unwrap();
    }
}