1use super::core::Core;
8use super::core::ReadState;
9use super::core::WriteState;
10use crate::core::PollError;
11use futures::FutureExt;
12use guestmem::AccessError;
13use guestmem::MemoryRead;
14use guestmem::MemoryWrite;
15use guestmem::ranges::PagedRange;
16use inspect::Inspect;
17use ring::OutgoingPacketType;
18use ring::TransferPageRange;
19use std::future::Future;
20use std::future::poll_fn;
21use std::ops::Deref;
22use std::task::Context;
23use std::task::Poll;
24use std::task::ready;
25use thiserror::Error;
26use vmbus_channel::RawAsyncChannel;
27use vmbus_channel::connected_async_channels;
28use vmbus_ring as ring;
29use vmbus_ring::FlatRingMem;
30use vmbus_ring::IncomingPacketType;
31use vmbus_ring::IncomingRing;
32use vmbus_ring::RingMem;
33use vmbus_ring::gparange::MultiPagedRangeBuf;
34use zerocopy::IntoBytes;
35
/// An error returned by queue operations.
///
/// The payload is boxed so that `Result<T, Error>` stays small on the success
/// path.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Error(Box<ErrorInner>);
40
41impl From<ErrorInner> for Error {
42 fn from(value: ErrorInner) -> Self {
43 Self(Box::new(value))
44 }
45}
46
47impl Error {
48 pub fn is_closed_error(&self) -> bool {
51 matches!(self.0.as_ref(), ErrorInner::ChannelClosed)
52 }
53}
54
/// The internal error type, boxed inside [`Error`].
#[derive(Debug, Error)]
enum ErrorInner {
    /// Failed to access the guest memory backing the ring.
    #[error("guest memory access error")]
    Access(#[source] AccessError),
    /// The ring buffer failed (e.g. corrupt control state).
    #[error("ring buffer error")]
    Ring(#[source] ring::Error),
    /// The channel has been closed.
    #[error("the channel has been closed")]
    ChannelClosed,
}
67
68impl From<PollError> for ErrorInner {
69 fn from(value: PollError) -> Self {
70 match value {
71 PollError::Ring(ring) => Self::Ring(ring),
72 PollError::Closed => Self::ChannelClosed,
73 }
74 }
75}
76
/// An error returned by [`ReadHalf::try_read`] and [`ReadHalf::try_read_batch`].
#[derive(Debug, Error)]
pub enum TryReadError {
    /// The incoming ring has no packets available.
    #[error("ring is empty")]
    Empty,
    /// The queue failed.
    #[error("queue error")]
    Queue(#[source] Error),
}
87
88#[derive(Debug, Error)]
90pub enum TryWriteError {
91 #[error("ring is empty")]
93 Full(usize),
94 #[error("queue error")]
96 Queue(#[source] Error),
97}
98
/// An error returned when reading a data packet's external data ranges.
#[derive(Debug, Error)]
pub enum ExternalDataError {
    /// The packet's GPA range descriptors failed to parse.
    #[error("invalid gpa ranges")]
    GpaRange(#[source] vmbus_ring::gparange::Error),

    /// The ring memory holding the range descriptors could not be accessed.
    #[error("access error")]
    Access(#[source] AccessError),

    /// The packet is a transfer-page packet, whose ranges must be resolved
    /// against a transfer buffer via `read_transfer_ranges`.
    #[error("external data should have been read by calling read_transfer_ranges")]
    WrongExternalDataType,
}
115
/// A batch of packets available to read from the incoming ring.
///
/// Packets consumed from the batch are committed to the ring (and the
/// opposite endpoint signaled if needed) when the batch is dropped.
pub struct ReadBatch<'a, M: RingMem> {
    core: &'a Core<M>,
    read: &'a mut ReadState,
}
121
/// An iterator over the packets in a [`ReadBatch`].
pub struct ReadBatchIter<'a, 'b, M: RingMem>(&'a mut ReadBatch<'b, M>);
124
125impl<'a, M: RingMem> ReadBatch<'a, M> {
126 fn next_priv(&mut self) -> Result<Option<IncomingPacket<'a, M>>, Error> {
127 let mut ptrs = self.read.ptrs.clone();
128 match self.core.in_ring().read(&mut ptrs) {
129 Ok(packet) => {
130 let packet = IncomingPacket::parse(self.core.in_ring(), packet)?;
131 self.read.ptrs = ptrs;
132 Ok(Some(packet))
133 }
134 Err(ring::ReadError::Empty) => Ok(None),
135 Err(ring::ReadError::Corrupt(err)) => Err(ErrorInner::Ring(err).into()),
136 }
137 }
138
139 fn single_packet(mut self) -> Result<Option<PacketRef<'a, M>>, Error> {
140 if let Some(packet) = self.next_priv()? {
141 Ok(Some(PacketRef {
142 batch: self,
143 packet,
144 }))
145 } else {
146 Ok(None)
147 }
148 }
149
150 pub fn packets(&mut self) -> ReadBatchIter<'_, 'a, M> {
152 ReadBatchIter(self)
153 }
154}
155
156impl<'a, M: RingMem> Iterator for ReadBatchIter<'a, '_, M> {
157 type Item = Result<IncomingPacket<'a, M>, Error>;
158
159 fn next(&mut self) -> Option<Self::Item> {
160 self.0.next_priv().transpose()
161 }
162}
163
impl<M: RingMem> Drop for ReadBatch<'_, M> {
    fn drop(&mut self) {
        // Stop polling for more data, then commit everything consumed from
        // the batch. When the commit indicates the opposite endpoint should
        // be notified, signal it and count the signal for diagnostics.
        self.read.clear_poll(self.core);
        if self.core.in_ring().commit_read(&mut self.read.ptrs) {
            self.core.signal();
            self.read.signals.increment();
        }
    }
}
173
/// A single packet read from the ring, holding the [`ReadBatch`] it came
/// from so that the read is committed when this is dropped.
pub struct PacketRef<'a, M: RingMem> {
    batch: ReadBatch<'a, M>,
    packet: IncomingPacket<'a, M>,
}
179
impl<'a, M: RingMem> Deref for PacketRef<'a, M> {
    type Target = IncomingPacket<'a, M>;

    // Lets a PacketRef be used wherever an IncomingPacket is expected.
    fn deref(&self) -> &Self::Target {
        &self.packet
    }
}
187
impl<'a, M: RingMem> AsRef<IncomingPacket<'a, M>> for PacketRef<'a, M> {
    // Delegates through the Deref impl above.
    fn as_ref(&self) -> &IncomingPacket<'a, M> {
        self
    }
}
193
impl<M: RingMem> PacketRef<'_, M> {
    /// Reverts the batch's uncommitted read progress so this packet is not
    /// consumed and will be seen again by the next read operation.
    pub fn revert(&mut self) {
        self.batch.read.ptrs.revert();
    }
}
204
/// An incoming packet parsed from the ring.
pub enum IncomingPacket<'a, T: RingMem> {
    /// A data packet (in-band, GPA direct, or transfer pages).
    Data(DataPacket<'a, T>),
    /// A completion for a previously written packet.
    Completion(CompletionPacket<'a, T>),
}
212
/// An incoming data packet.
pub struct DataPacket<'a, T: RingMem> {
    ring: &'a IncomingRing<T>,
    // Ring region containing the in-band payload.
    payload: ring::RingRange,
    // Present when the ring packet carried a transaction ID.
    transaction_id: Option<u64>,
    // Set only for transfer-page packets.
    buffer_id: Option<u16>,
    // (range count, ring region holding the external range descriptors).
    external_data: (u32, ring::RingRange),
}
221
222impl<T: RingMem> DataPacket<'_, T> {
223 pub fn reader(&self) -> vmbus_ring::RingRangeReader<'_, T> {
229 self.payload.reader(self.ring)
230 }
231
232 pub fn transaction_id(&self) -> Option<u64> {
235 self.transaction_id
236 }
237
238 pub fn external_range_count(&self) -> usize {
240 self.external_data.0 as usize
241 }
242
243 fn read_transfer_page_ranges(
244 &self,
245 transfer_buf: PagedRange<'_>,
246 result: &mut MultiPagedRangeBuf,
247 ) -> Result<(), AccessError> {
248 let len = self.external_data.0 as usize;
249 let mut reader = self.external_data.1.reader(self.ring);
250 let available_count = reader.len() / size_of::<TransferPageRange>();
251 if available_count < len {
252 return Err(AccessError::OutOfRange(0, 0));
253 }
254
255 for _ in 0..len {
256 let range = reader.read_plain::<TransferPageRange>()?;
257 result.push_range(
258 transfer_buf
259 .try_subrange(range.byte_offset as usize, range.byte_count as usize)
260 .ok_or(AccessError::OutOfRange(
261 range.byte_offset as usize,
262 range.byte_count as usize,
263 ))?,
264 );
265 }
266 Ok(())
267 }
268
269 pub fn read_external_ranges(
271 &self,
272 buf: &mut MultiPagedRangeBuf,
273 ) -> Result<(), ExternalDataError> {
274 if self.buffer_id.is_some() {
275 return Err(ExternalDataError::WrongExternalDataType);
276 } else if self.external_data.0 == 0 {
277 return Ok(());
278 }
279
280 let mut reader = self.external_data.1.reader(self.ring);
281 buf.try_extend_with(reader.len() / 8, self.external_data.0 as usize, |b| {
282 reader
283 .read(b.as_mut_bytes())
284 .map_err(ExternalDataError::Access)?;
285 Ok(())
286 })?
287 .map_err(ExternalDataError::GpaRange)?;
288 Ok(())
289 }
290
291 pub fn transfer_buffer_id(&self) -> Option<u16> {
293 self.buffer_id
294 }
295
296 pub fn read_transfer_ranges(
302 &self,
303 transfer_buf: PagedRange<'_>,
304 result: &mut MultiPagedRangeBuf,
305 ) -> Result<(), AccessError> {
306 if self.external_data.0 != 0 {
307 self.read_transfer_page_ranges(transfer_buf, result)?;
308 }
309 Ok(())
310 }
311}
312
/// An incoming completion packet.
pub struct CompletionPacket<'a, T: RingMem> {
    ring: &'a IncomingRing<T>,
    // Ring region containing the completion payload.
    payload: ring::RingRange,
    transaction_id: u64,
}
319
impl<T: RingMem> CompletionPacket<'_, T> {
    /// Returns a reader for the completion payload.
    pub fn reader(&self) -> impl MemoryRead + '_ {
        self.payload.reader(self.ring)
    }

    /// The transaction ID of the packet this completes.
    pub fn transaction_id(&self) -> u64 {
        self.transaction_id
    }
}
331
332impl<'a, T: RingMem> IncomingPacket<'a, T> {
333 fn parse(ring: &'a IncomingRing<T>, packet: ring::IncomingPacket) -> Result<Self, Error> {
334 Ok(match packet.typ {
335 IncomingPacketType::InBand => IncomingPacket::Data(DataPacket {
336 ring,
337 payload: packet.payload,
338 transaction_id: packet.transaction_id,
339 buffer_id: None,
340 external_data: (0, ring::RingRange::empty()),
341 }),
342 IncomingPacketType::GpaDirect(count, ranges) => IncomingPacket::Data(DataPacket {
343 ring,
344 payload: packet.payload,
345 transaction_id: packet.transaction_id,
346 buffer_id: None,
347 external_data: (count, ranges),
348 }),
349 IncomingPacketType::Completion => IncomingPacket::Completion(CompletionPacket {
350 ring,
351 payload: packet.payload,
352 transaction_id: packet.transaction_id.unwrap(),
353 }),
354 IncomingPacketType::TransferPages(id, count, ranges) => {
355 IncomingPacket::Data(DataPacket {
356 ring,
357 payload: packet.payload,
358 transaction_id: packet.transaction_id,
359 buffer_id: Some(id),
360 external_data: (count, ranges),
361 })
362 }
363 })
364 }
365}
366
/// The read half of a queue, returned by [`Queue::split`].
pub struct ReadHalf<'a, M: RingMem> {
    core: &'a Core<M>,
    read: &'a mut ReadState,
}
372
impl<'a, M: RingMem> ReadHalf<'a, M> {
    /// Polls the channel for read readiness, failing if the channel has
    /// failed or closed.
    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        ready!(self.read.poll_ready(cx, self.core)).map_err(ErrorInner::from)?;
        Poll::Ready(Ok(()))
    }

    /// Polls until at least one packet can be read, returning a batch
    /// borrowing this half.
    pub fn poll_read_batch<'b>(
        &'b mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<ReadBatch<'b, M>, Error>> {
        let batch = loop {
            std::task::ready!(self.poll_ready(cx))?;
            if self
                .core
                .in_ring()
                .can_read(&mut self.read.ptrs)
                .map_err(ErrorInner::Ring)?
            {
                break ReadBatch {
                    core: self.core,
                    read: self.read,
                };
            } else {
                // Ready but nothing to read: clear the ready state and wait
                // for the next signal.
                self.read.clear_ready();
            }
        };
        Poll::Ready(Ok(batch))
    }

    /// Tries to start a read batch without waiting, failing with
    /// [`TryReadError::Empty`] if no packets are available.
    pub fn try_read_batch(&mut self) -> Result<ReadBatch<'_, M>, TryReadError> {
        if self
            .core
            .in_ring()
            .can_read(&mut self.read.ptrs)
            .map_err(|err| TryReadError::Queue(Error::from(ErrorInner::Ring(err))))?
        {
            Ok(ReadBatch {
                core: self.core,
                read: self.read,
            })
        } else {
            self.read.clear_ready();
            Err(TryReadError::Empty)
        }
    }

    /// Returns a future resolving to a non-empty batch of packets.
    pub fn read_batch<'b>(&'b mut self) -> BatchRead<'a, 'b, M> {
        BatchRead(Some(self))
    }

    /// Tries to read a single packet without waiting, failing with
    /// [`TryReadError::Empty`] if none is available.
    pub fn try_read(&mut self) -> Result<PacketRef<'_, M>, TryReadError> {
        let batch = self.try_read_batch()?;
        batch
            .single_packet()
            .map_err(TryReadError::Queue)?
            .ok_or(TryReadError::Empty)
    }

    /// Returns a future resolving to a single packet.
    pub fn read<'b>(&'b mut self) -> Read<'a, 'b, M> {
        Read(self.read_batch())
    }

    /// Returns whether the incoming ring supports setting the pending send
    /// size.
    pub fn supports_pending_send_size(&self) -> bool {
        self.core.in_ring().supports_pending_send_size()
    }
}
460
/// A future for reading a batch of packets, returned by [`ReadHalf::read_batch`].
pub struct BatchRead<'a, 'b, M: RingMem>(Option<&'a mut ReadHalf<'b, M>>);
463
impl<'a, M: RingMem> Future for BatchRead<'a, '_, M> {
    type Output = Result<ReadBatch<'a, M>, Error>;

    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        // Wait until a batch is available, then immediately drop the
        // short-lived batch so its borrow of the read half is released.
        let _ = std::task::ready!(this.0.as_mut().unwrap().poll_read_batch(cx))?;
        // Consume the stored read half (the future is now complete) and
        // rebuild the batch with the longer 'a lifetime.
        let this = this.0.take().unwrap();
        Poll::Ready(Ok(ReadBatch {
            core: this.core,
            read: this.read,
        }))
    }
}
478
/// A future for reading a single packet, returned by [`ReadHalf::read`].
pub struct Read<'a, 'b, M: RingMem>(BatchRead<'a, 'b, M>);
481
impl<'a, M: RingMem> Future for Read<'a, '_, M> {
    type Output = Result<PacketRef<'a, M>, Error>;

    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The inner future only resolves once at least one packet is
        // readable, so single_packet() returning no packet is a bug.
        let batch = std::task::ready!(self.0.poll_unpin(cx))?;
        Poll::Ready(
            batch
                .single_packet()
                .transpose()
                .expect("batch was non-empty"),
        )
    }
}
495
/// An outgoing packet to write to a queue.
pub struct OutgoingPacket<'a, 'b> {
    /// The transaction ID.
    pub transaction_id: u64,
    /// The packet type.
    pub packet_type: OutgoingPacketType<'a>,
    /// The payload, as a list of discontiguous byte slices written
    /// back-to-back.
    pub payload: &'b [&'b [u8]],
}
505
/// The write half of a queue, returned by [`Queue::split`].
pub struct WriteHalf<'a, M: RingMem> {
    core: &'a Core<M>,
    write: &'a mut WriteState,
}
511
impl<'a, M: RingMem> WriteHalf<'a, M> {
    /// Polls the outgoing ring until at least `send_size` bytes can be
    /// written.
    pub fn poll_ready(
        &mut self,
        cx: &mut Context<'_>,
        send_size: usize,
    ) -> Poll<Result<(), Error>> {
        loop {
            std::task::ready!(self.write.poll_ready(cx, self.core, send_size))
                .map_err(ErrorInner::from)?;
            // Recheck capacity after each wakeup; a wakeup may arrive before
            // enough space is actually free.
            if self.can_write(send_size)? {
                break Poll::Ready(Ok(()));
            }
        }
    }

    /// Waits until at least `send_size` bytes can be written.
    pub async fn wait_ready(&mut self, send_size: usize) -> Result<(), Error> {
        poll_fn(|cx| self.poll_ready(cx, send_size)).await
    }

    /// Returns a batch writer. Writes are committed to the ring (and the
    /// opposite endpoint signaled if needed) when the batch is dropped.
    pub fn batched(&mut self) -> WriteBatch<'_, M> {
        WriteBatch {
            core: self.core,
            write: self.write,
        }
    }

    /// Returns whether `send_size` bytes can currently be written.
    pub fn can_write(&mut self, send_size: usize) -> Result<bool, Error> {
        self.batched().can_write(send_size)
    }

    /// The maximum packet size the outgoing ring can hold.
    pub fn capacity(&self) -> usize {
        self.core.out_ring().maximum_packet_size()
    }

    /// Tries to write `packet` without waiting, failing with
    /// [`TryWriteError::Full`] if the ring lacks space.
    pub fn try_write(&mut self, packet: &OutgoingPacket<'_, '_>) -> Result<(), TryWriteError> {
        self.batched().try_write(packet)
    }

    /// Polls to write `packet`, waiting for ring space as needed.
    pub fn poll_write(
        &mut self,
        cx: &mut Context<'_>,
        packet: &OutgoingPacket<'_, '_>,
    ) -> Poll<Result<(), Error>> {
        // Start with a small space estimate; when the ring is full, try_write
        // reports the exact size needed and the next wait uses that.
        let mut send_size = 32;
        let r = loop {
            std::task::ready!(self.write.poll_ready(cx, self.core, send_size))
                .map_err(ErrorInner::from)?;
            match self.try_write(packet) {
                Ok(()) => break Ok(()),
                Err(TryWriteError::Full(len)) => send_size = len,
                Err(TryWriteError::Queue(err)) => break Err(err),
            }
        };
        Poll::Ready(r)
    }

    /// Returns a future that writes `packet` to the ring.
    pub fn write<'b, 'c>(&'b mut self, packet: OutgoingPacket<'c, 'b>) -> Write<'a, 'b, 'c, M> {
        Write {
            write: self,
            packet,
        }
    }
}
600
/// A batch writer for the outgoing ring, returned by [`WriteHalf::batched`].
/// Writes are committed (and the opposite endpoint signaled if needed) when
/// the batch is dropped.
pub struct WriteBatch<'a, M: RingMem> {
    core: &'a Core<M>,
    write: &'a mut WriteState,
}
606
607impl<'a, M: RingMem> WriteBatch<'a, M> {
608 pub fn can_write(&mut self, send_size: usize) -> Result<bool, Error> {
611 let can_write = self
612 .core
613 .out_ring()
614 .can_write(&mut self.write.ptrs, send_size)
615 .map_err(ErrorInner::Ring)?;
616
617 if !can_write {
619 self.write.clear_ready();
620 }
621 Ok(can_write)
622 }
623
624 pub fn try_write(&mut self, packet: &OutgoingPacket<'_, '_>) -> Result<(), TryWriteError> {
628 let size = packet.payload.iter().fold(0, |a, p| a + p.len());
629 let ring_packet = ring::OutgoingPacket {
630 transaction_id: packet.transaction_id,
631 size,
632 typ: packet.packet_type,
633 };
634 let mut builder = self.try_start_write(&ring_packet)?;
635 let mut writer = builder.writer();
636 for &p in packet.payload {
637 writer
638 .write(p)
639 .map_err(|err| TryWriteError::Queue(ErrorInner::Access(err).into()))?;
640 }
641 builder.finish();
642 Ok(())
643 }
644
645 pub fn try_write_aligned(
651 &mut self,
652 transaction_id: u64,
653 packet_type: OutgoingPacketType<'_>,
654 data: &[u64],
655 ) -> Result<(), TryWriteError> {
656 let size = data.len() * 8;
657 let ring_packet = ring::OutgoingPacket {
658 transaction_id,
659 size,
660 typ: packet_type,
661 };
662 let mut builder = self.try_start_write(&ring_packet)?;
663 builder.write_aligned_full(data);
664 builder.finish();
665 Ok(())
666 }
667
668 fn try_start_write(
669 &mut self,
670 packet: &ring::OutgoingPacket<'_>,
671 ) -> Result<WritePacketBuilder<'_, 'a, M>, TryWriteError> {
672 let mut ptrs = self.write.ptrs.clone();
673 match self.core.out_ring().write(&mut ptrs, packet) {
674 Ok(range) => Ok(WritePacketBuilder {
675 batch: self,
676 range,
677 ptrs,
678 }),
679 Err(ring::WriteError::Full(n)) => {
680 self.write.clear_ready();
681 Err(TryWriteError::Full(n))
682 }
683 Err(ring::WriteError::Corrupt(err)) => {
684 Err(TryWriteError::Queue(ErrorInner::Ring(err).into()))
685 }
686 }
687 }
688}
689
/// An in-progress packet write: the reserved payload range plus the updated
/// ring offset to store back once the payload is written.
struct WritePacketBuilder<'a, 'b, M: RingMem> {
    batch: &'a mut WriteBatch<'b, M>,
    range: ring::RingRange,
    ptrs: ring::OutgoingOffset,
}
695
impl<M: RingMem> WritePacketBuilder<'_, '_, M> {
    /// Returns a writer over the reserved payload range.
    fn writer(&mut self) -> vmbus_ring::RingRangeWriter<'_, M> {
        self.range.writer(self.batch.core.out_ring())
    }

    /// Writes the u64 payload into the reserved range.
    fn write_aligned_full(&mut self, data: &[u64]) {
        self.range
            .write_aligned_full(self.batch.core.out_ring(), data)
    }

    /// Completes the packet by storing the updated pointers into the batch.
    /// Dropping the builder without calling this abandons the packet.
    fn finish(self) {
        self.batch.write.clear_poll(self.batch.core);
        self.batch.write.ptrs = self.ptrs;
    }
}
711
impl<M: RingMem> Drop for WriteBatch<'_, M> {
    fn drop(&mut self) {
        // Commit everything written in the batch. When the commit indicates
        // the opposite endpoint should be notified, signal it and count the
        // signal for diagnostics.
        if self.core.out_ring().commit_write(&mut self.write.ptrs) {
            self.core.signal();
            self.write.signals.increment();
        }
    }
}
720
/// A future for writing a packet, returned by [`WriteHalf::write`].
#[must_use]
pub struct Write<'a, 'b, 'c, M: RingMem> {
    write: &'b mut WriteHalf<'a, M>,
    packet: OutgoingPacket<'c, 'b>,
}
727
impl<M: RingMem> Future for Write<'_, '_, '_, M> {
    type Output = Result<(), Error>;

    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Delegate to the write half, which waits for space and retries.
        let this = self.get_mut();
        this.write.poll_write(cx, &this.packet)
    }
}
736
/// An asynchronous vmbus channel queue, pairing an incoming and an outgoing
/// ring.
pub struct Queue<M: RingMem> {
    core: Core<M>,
    read: ReadState,
    write: WriteState,
}
747
impl<M: RingMem> Inspect for Queue<M> {
    // Exposes the channel core plus per-ring read/write state for diagnostics.
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .merge(&self.core)
            .field("incoming_ring", &self.read)
            .field("outgoing_ring", &self.write);
    }
}
756
757impl<M: RingMem> Queue<M> {
758 pub fn new(raw: RawAsyncChannel<M>) -> Result<Self, Error> {
761 let incoming = raw.in_ring.incoming().map_err(ErrorInner::Ring)?;
762 let outgoing = raw.out_ring.outgoing().map_err(ErrorInner::Ring)?;
763 let core = Core::new(raw);
764 let read = ReadState::new(incoming);
765 let write = WriteState::new(outgoing);
766
767 Ok(Self { core, read, write })
768 }
769
770 pub fn split(&mut self) -> (ReadHalf<'_, M>, WriteHalf<'_, M>) {
773 (
774 ReadHalf {
775 core: &self.core,
776 read: &mut self.read,
777 },
778 WriteHalf {
779 core: &self.core,
780 write: &mut self.write,
781 },
782 )
783 }
784}
785
786pub fn connected_queues(ring_size: usize) -> (Queue<FlatRingMem>, Queue<FlatRingMem>) {
788 let (host, guest) = connected_async_channels(ring_size);
789 (Queue::new(host).unwrap(), Queue::new(guest).unwrap())
790}
791
#[cfg(test)]
mod tests {
    use super::*;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::timer::PolledTimer;
    use ring::OutgoingPacketType;
    use std::future::poll_fn;
    use std::time::Duration;
    use vmbus_channel::gpadl::GpadlId;
    use vmbus_channel::gpadl::GpadlMap;

    // Sends a GpaDirect packet from the guest and verifies the host reads
    // back the same payload and external GPA ranges.
    #[async_test]
    async fn test_gpa_direct() {
        use guestmem::ranges::PagedRange;

        let (mut host_queue, mut guest_queue) = connected_queues(16384);

        let gpa1: Vec<u64> = vec![4096, 8192];
        let gpa2: Vec<u64> = vec![8192];
        let gpas = vec![
            PagedRange::new(20, 4096, &gpa1).unwrap(),
            PagedRange::new(0, 200, &gpa2).unwrap(),
        ];

        let payload: &[u8] = &[0xf; 24];
        guest_queue
            .split()
            .1
            .write(OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::GpaDirect(&gpas),
                payload: &[payload],
            })
            .await
            .unwrap();
        host_queue
            .split()
            .0
            .read_batch()
            .await
            .unwrap()
            .packets()
            .next()
            .map(|p| match p.unwrap() {
                IncomingPacket::Data(data) => {
                    // The in-band payload should round-trip unchanged.
                    let mut in_payload = [0_u8; 24];
                    assert_eq!(payload.len(), data.reader().len());
                    data.reader().read(&mut in_payload).unwrap();
                    assert_eq!(in_payload, payload);

                    // The external GPA ranges should round-trip unchanged.
                    assert_eq!(data.external_range_count(), 2);
                    let mut external_data = MultiPagedRangeBuf::new();
                    data.read_external_ranges(&mut external_data).unwrap();
                    let in_gpas: Vec<PagedRange<'_>> = external_data.iter().collect();
                    assert_eq!(in_gpas.len(), gpas.len());

                    for (p, q) in in_gpas.iter().zip(gpas) {
                        assert_eq!(p.offset(), q.offset());
                        assert_eq!(p.len(), q.len());
                        assert_eq!(p.gpns(), q.gpns());
                    }
                    Ok(())
                }
                _ => Err("should be data"),
            })
            .unwrap()
            .unwrap();
    }

    // A GpaDirect packet with an empty range should fail range parsing on
    // the receive side rather than succeed or panic.
    #[async_test]
    async fn test_gpa_direct_empty_external_data() {
        use guestmem::ranges::PagedRange;

        let (mut host_queue, mut guest_queue) = connected_queues(16384);

        let gpa1: Vec<u64> = vec![];
        let gpas = vec![PagedRange::new(0, 0, &gpa1).unwrap()];

        let payload: &[u8] = &[0xf; 24];
        guest_queue
            .split()
            .1
            .write(OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::GpaDirect(&gpas),
                payload: &[payload],
            })
            .await
            .unwrap();
        host_queue
            .split()
            .0
            .read_batch()
            .await
            .unwrap()
            .packets()
            .next()
            .map(|p| match p.unwrap() {
                IncomingPacket::Data(data) => {
                    // The payload itself still round-trips.
                    let mut in_payload = [0_u8; 24];
                    assert_eq!(payload.len(), data.reader().len());
                    data.reader().read(&mut in_payload).unwrap();
                    assert_eq!(in_payload, payload);

                    // The empty range is reported but fails to parse.
                    assert_eq!(data.external_range_count(), 1);
                    let mut external_data = MultiPagedRangeBuf::new();
                    let external_data_result = data.read_external_ranges(&mut external_data);
                    match external_data_result {
                        Err(ExternalDataError::GpaRange(_)) => Ok(()),
                        _ => Err("should be out of range"),
                    }
                }
                _ => Err("should be data"),
            })
            .unwrap()
            .unwrap();
    }

    // Sends a TransferPages packet and verifies the receiver resolves the
    // (offset, length) descriptors against the GPADL-mapped buffer,
    // including ranges that straddle a page boundary.
    #[async_test]
    async fn test_transfer_pages() {
        use guestmem::ranges::PagedRange;

        let (mut host_queue, mut guest_queue) = connected_queues(16384);

        let gpadl_map = GpadlMap::new();
        let buf = vec![0x3000_u64, 1, 2, 3];
        gpadl_map.add(
            GpadlId(13),
            MultiPagedRangeBuf::from_range_buffer(1, buf).unwrap(),
        );

        let ranges = vec![
            TransferPageRange {
                byte_count: 0x10,
                byte_offset: 0x10,
            },
            TransferPageRange {
                byte_count: 0x10,
                byte_offset: 0xfff,
            },
            TransferPageRange {
                byte_count: 0x10,
                byte_offset: 0x1000,
            },
        ];

        let payload: &[u8] = &[0xf; 24];
        guest_queue
            .split()
            .1
            .write(OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::TransferPages(13, &ranges),
                payload: &[payload],
            })
            .await
            .unwrap();
        host_queue
            .split()
            .0
            .read_batch()
            .await
            .unwrap()
            .packets()
            .next()
            .map(|p| match p.unwrap() {
                IncomingPacket::Data(data) => {
                    // The in-band payload should round-trip unchanged.
                    let mut in_payload = [0_u8; 24];
                    assert_eq!(payload.len(), data.reader().len());
                    data.reader().read(&mut in_payload).unwrap();
                    assert_eq!(in_payload, payload);

                    // Resolve the transfer ranges against the mapped buffer.
                    assert_eq!(data.external_range_count(), 3);
                    let gpadl_map_view = gpadl_map.view();
                    assert_eq!(data.transfer_buffer_id().unwrap(), 13);
                    let buffer_gpadl = gpadl_map_view.map(GpadlId(13)).unwrap();
                    let buffer_range = buffer_gpadl.first().unwrap();
                    let mut external_data = MultiPagedRangeBuf::new();
                    data.read_transfer_ranges(buffer_range, &mut external_data)
                        .unwrap();
                    let in_ranges: Vec<PagedRange<'_>> = external_data.iter().collect();
                    assert_eq!(in_ranges.len(), ranges.len());
                    assert_eq!(in_ranges[0].offset(), 0x10);
                    assert_eq!(in_ranges[0].len(), 0x10);
                    assert_eq!(in_ranges[0].gpns().len(), 1);
                    assert_eq!(in_ranges[0].gpns()[0], 1);

                    assert_eq!(in_ranges[1].offset(), 0xfff);
                    assert_eq!(in_ranges[1].len(), 0x10);
                    assert_eq!(in_ranges[1].gpns().len(), 2);
                    assert_eq!(in_ranges[1].gpns()[0], 1);
                    assert_eq!(in_ranges[1].gpns()[1], 2);

                    assert_eq!(in_ranges[2].offset(), 0);
                    assert_eq!(in_ranges[2].len(), 0x10);
                    assert_eq!(in_ranges[2].gpns().len(), 1);
                    assert_eq!(in_ranges[2].gpns()[0], 2);

                    Ok(())
                }
                _ => Err("should be data"),
            })
            .unwrap()
            .unwrap();
    }

    // Fills the ring, verifies writers block until the reader drains a
    // packet, then verifies writing succeeds again.
    #[async_test]
    async fn test_ring_full(driver: DefaultDriver) {
        let (mut host_queue, mut guest_queue) = connected_queues(4096);

        assert!(
            poll_fn(|cx| host_queue.split().1.poll_ready(cx, 4000))
                .now_or_never()
                .is_some()
        );

        host_queue
            .split()
            .1
            .try_write(&OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::InBandNoCompletion,
                payload: &[&[0u8; 4000]],
            })
            .unwrap();

        // The second write must fail, reporting how much space it needs.
        let n = match host_queue
            .split()
            .1
            .try_write(&OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::InBandNoCompletion,
                payload: &[&[0u8; 4000]],
            })
            .unwrap_err()
        {
            TryWriteError::Full(n) => n,
            _ => unreachable!(),
        };

        let mut poll = async move {
            let mut host_queue = host_queue;
            poll_fn(|cx| host_queue.split().1.poll_ready(cx, n))
                .await
                .unwrap();
            host_queue
        }
        .boxed();

        // The writer stays pending until the reader frees space.
        assert!(futures::poll!(&mut poll).is_pending());
        let poll = driver.spawn("test", poll);

        PolledTimer::new(&driver)
            .sleep(Duration::from_millis(50))
            .await;

        // Drain the one queued packet; the ring is then empty.
        guest_queue.split().0.read().await.unwrap();
        assert!(guest_queue.split().0.try_read().is_err());

        // The blocked writer should now complete, and writing succeeds.
        let mut host_queue = poll.await;

        host_queue
            .split()
            .1
            .try_write(&OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::InBandNoCompletion,
                payload: &[&[0u8; 4000]],
            })
            .unwrap();
    }
}