#![expect(missing_docs)]
#![forbid(unsafe_code)]
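//! Ring buffer implementation for exchanging packets between two endpoints.
//! Packets are read through [`IncomingRing`] and written through
//! [`OutgoingRing`], each generic over the backing memory via [`RingMem`].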
pub mod gparange;
pub use pipe_protocol::*;
pub use protocol::PAGE_SIZE;
pub use protocol::TransferPageRange;
use crate::gparange::GpaRange;
use guestmem::AccessError;
use guestmem::MemoryRead;
use guestmem::MemoryWrite;
use guestmem::ranges::PagedRange;
use inspect::Inspect;
use protocol::*;
use safeatomic::AtomicSliceOps;
use std::fmt::Debug;
use std::sync::Arc;
use std::sync::atomic::AtomicU8;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use thiserror::Error;
use zerocopy::FromZeros;
use zerocopy::IntoBytes;
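/// Framing protocol for channels operated as byte pipes: each pipe packet's
/// payload begins with a [`PipeHeader`].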
mod pipe_protocol {
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;
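/// The header prefixed to the payload of each pipe packet.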
#[repr(C)]
#[derive(Debug, Copy, Clone, IntoBytes, Immutable, KnownLayout, FromBytes)]
pub struct PipeHeader {
pub packet_type: u32,
pub len: u32,
}
pub const PIPE_PACKET_TYPE_DATA: u32 = 1;
pub const PIPE_PACKET_TYPE_PARTIAL: u32 = 2;
pub const PIPE_PACKET_TYPE_SETUP_GPA_DIRECT: u32 = 3;
pub const PIPE_PACKET_TYPE_TEARDOWN_GPA_DIRECT: u32 = 4;
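/// The maximum payload size of a single pipe packet.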
pub const MAXIMUM_PIPE_PACKET_SIZE: usize = 16384;
}
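/// Low-level ring buffer protocol definitions.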
mod protocol {
use crate::CONTROL_WORD_COUNT;
use inspect::Inspect;
use std::fmt::Debug;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;
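/// The size of a ring buffer page, in bytes.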
pub const PAGE_SIZE: usize = 4096;
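/// The 16-byte descriptor at the start of every ring packet. `data_offset8`
/// and `length8` are in units of 8 bytes, measured from the start of the
/// descriptor.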
#[repr(C)]
#[derive(Copy, Clone, Debug, IntoBytes, Immutable, KnownLayout, FromBytes)]
pub struct PacketDescriptor {
pub packet_type: u16,
pub data_offset8: u16,
pub length8: u16,
pub flags: u16,
pub transaction_id: u64,
}
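/// The ring's control words: word 0 is the `in` (write) offset, word 1 the
/// `out` (read) offset, word 2 the interrupt mask, word 3 the pending send
/// size, and word 16 the feature bits.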
pub struct Control<'a>(pub &'a [AtomicU32; CONTROL_WORD_COUNT]);
impl Control<'_> {
pub fn inp(&self) -> &AtomicU32 {
&self.0[0]
}
pub fn outp(&self) -> &AtomicU32 {
&self.0[1]
}
pub fn interrupt_mask(&self) -> &AtomicU32 {
&self.0[2]
}
pub fn pending_send_size(&self) -> &AtomicU32 {
&self.0[3]
}
pub fn feature_bits(&self) -> &AtomicU32 {
&self.0[16]
}
}
impl Inspect for Control<'_> {
fn inspect(&self, req: inspect::Request<'_>) {
req.respond()
.hex("in", self.inp().load(Ordering::Relaxed))
.hex("out", self.outp().load(Ordering::Relaxed))
.hex(
"interrupt_mask",
self.interrupt_mask().load(Ordering::Relaxed),
)
.hex(
"pending_send_size",
self.pending_send_size().load(Ordering::Relaxed),
)
.hex("feature_bits", self.feature_bits().load(Ordering::Relaxed));
}
}
impl Debug for Control<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Control")
.field("inp", self.inp())
.field("outp", self.outp())
.field("interrupt_mask", self.interrupt_mask())
.field("pending_send_size", self.pending_send_size())
.field("feature_bits", self.feature_bits())
.finish()
}
}
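/// Feature bit advertising support for the `pending_send_size` mechanism.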
pub const FEATURE_SUPPORTS_PENDING_SEND_SIZE: u32 = 1;
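/// A byte range relative to a transfer page set.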
#[repr(C)]
#[derive(Debug, Copy, Clone, IntoBytes, Immutable, KnownLayout, FromBytes)]
pub struct TransferPageRange {
pub byte_count: u32,
pub byte_offset: u32,
}
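/// The header for a transfer page packet, followed by `range_count` ranges.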
#[repr(C)]
#[derive(Copy, Clone, Debug, IntoBytes, Immutable, KnownLayout, FromBytes)]
pub struct TransferPageHeader {
pub transfer_page_set_id: u16,
pub reserved: u16,
pub range_count: u32,
}
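/// The header for a GPA direct packet, followed by `range_count` GPA ranges.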
#[repr(C)]
#[derive(Copy, Clone, Debug, IntoBytes, Immutable, KnownLayout, FromBytes)]
pub struct GpaDirectHeader {
pub reserved: u32,
pub range_count: u32,
}
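/// Descriptor flag indicating the sender wants a completion packet.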
pub const PACKET_FLAG_COMPLETION_REQUESTED: u16 = 1;
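/// The footer written after each packet's aligned payload; `offset` records
/// the ring offset at which the packet started.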
#[repr(C)]
#[derive(Copy, Clone, Debug, IntoBytes, Immutable, KnownLayout, FromBytes)]
pub struct Footer {
pub reserved: u32,
pub offset: u32,
}
}
#[derive(Copy, Clone, Debug, Error)]
pub enum Error {
#[error("invalid ring buffer pointer")]
InvalidRingPointer,
#[error("invalid message length")]
InvalidMessageLength,
#[error("invalid data available")]
InvalidDataAvailable,
#[error("ring buffer too large")]
RingTooLarge,
#[error("invalid ring memory")]
InvalidRingMemory,
#[error("invalid descriptor offset or length")]
InvalidDescriptorLengths,
#[error("unknown packet descriptor flags")]
InvalidDescriptorFlags,
#[error("unknown packet descriptor type")]
InvalidDescriptorType,
#[error("invalid range count for gpa direct packet")]
InvalidDescriptorGpaDirectRangeCount,
#[error("the interrupt mask bit was supposed to be clear but it is set")]
InterruptsExternallyMasked,
}
#[derive(Copy, Clone, Debug, Error)]
pub enum ReadError {
#[error("ring buffer empty")]
Empty,
#[error(transparent)]
Corrupt(#[from] Error),
}
#[derive(Copy, Clone, Debug, Error)]
pub enum WriteError {
#[error("ring buffer full")]
Full(usize),
#[error(transparent)]
Corrupt(#[from] Error),
}
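/// A byte range within a ring buffer's data region. Access its contents with
/// [`RingRange::reader`] and [`RingRange::writer`].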
#[derive(Copy, Clone, Debug)]
pub struct RingRange {
off: u32,
size: u32,
}
impl RingRange {
pub fn empty() -> Self {
RingRange { off: 0, size: 0 }
}
pub fn writer<'a, T: Ring>(&self, ring: &'a T) -> RingRangeWriter<'a, T::Memory> {
RingRangeWriter {
start: self.off,
end: self.off + self.size,
mem: ring.mem(),
}
}
pub fn reader<'a, T: Ring>(&self, ring: &'a T) -> RingRangeReader<'a, T::Memory> {
RingRangeReader {
start: self.off,
end: self.off + self.size,
mem: ring.mem(),
}
}
pub fn len(&self) -> usize {
self.size as usize
}
pub fn is_empty(&self) -> bool {
self.size == 0
}
}
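/// A [`MemoryRead`] implementation over the contents of a [`RingRange`].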
pub struct RingRangeReader<'a, T> {
start: u32,
end: u32,
mem: &'a T,
}
impl<T: RingMem> MemoryRead for RingRangeReader<'_, T> {
fn read(&mut self, data: &mut [u8]) -> Result<&mut Self, AccessError> {
if self.len() < data.len() {
return Err(AccessError::OutOfRange(self.len(), data.len()));
}
self.mem.read_at(self.start as usize, data);
self.start += data.len() as u32;
Ok(self)
}
fn skip(&mut self, len: usize) -> Result<&mut Self, AccessError> {
if self.len() < len {
return Err(AccessError::OutOfRange(self.len(), len));
}
self.start += len as u32;
Ok(self)
}
fn len(&self) -> usize {
(self.end - self.start) as usize
}
}
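/// A [`MemoryWrite`] implementation over the contents of a [`RingRange`].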
pub struct RingRangeWriter<'a, T> {
start: u32,
end: u32,
mem: &'a T,
}
impl<T: RingMem> MemoryWrite for RingRangeWriter<'_, T> {
fn write(&mut self, data: &[u8]) -> Result<(), AccessError> {
if self.len() < data.len() {
return Err(AccessError::OutOfRange(self.len(), data.len()));
}
self.mem.write_at(self.start as usize, data);
self.start += data.len() as u32;
Ok(())
}
fn fill(&mut self, _val: u8, _len: usize) -> Result<(), AccessError> {
unimplemented!()
}
fn len(&self) -> usize {
(self.end - self.start) as usize
}
}
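/// The type of an incoming packet. For `GpaDirect` and `TransferPages`, the
/// [`RingRange`] covers the packet's range descriptors, still in ring memory.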
#[derive(Debug, Copy, Clone)]
pub enum IncomingPacketType {
InBand,
Completion,
GpaDirect(u32, RingRange),
TransferPages(u16, u32, RingRange),
}
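/// An incoming packet, with its payload left in ring memory.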
#[derive(Debug)]
pub struct IncomingPacket {
pub transaction_id: Option<u64>,
pub typ: IncomingPacketType,
pub payload: RingRange,
}
const PACKET_TYPE_IN_BAND: u16 = 0x6;
const PACKET_TYPE_TRANSFER_PAGES: u16 = 0x7;
const PACKET_TYPE_GPA_DIRECT: u16 = 0x9;
const PACKET_TYPE_COMPLETION: u16 = 0xb;
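/// Parses the packet descriptor at `ring_off`, returning the packet's length
/// on the ring (excluding the footer) and the parsed packet.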
fn parse_packet<M: RingMem>(
ring: &M,
ring_off: u32,
avail: u32,
) -> Result<(u32, IncomingPacket), ReadError> {
let mut desc = PacketDescriptor::new_zeroed();
ring.read_aligned(ring_off as usize, desc.as_mut_bytes());
let len = desc.length8 as u32 * 8;
if desc.length8 < desc.data_offset8 || desc.data_offset8 < 2 || avail < len {
return Err(ReadError::Corrupt(Error::InvalidDescriptorLengths));
}
if (desc.flags & !PACKET_FLAG_COMPLETION_REQUESTED) != 0 {
return Err(ReadError::Corrupt(Error::InvalidDescriptorFlags));
}
let transaction_id = if desc.flags & PACKET_FLAG_COMPLETION_REQUESTED != 0
|| desc.packet_type == PACKET_TYPE_COMPLETION
{
Some(desc.transaction_id)
} else {
None
};
let typ = match desc.packet_type {
PACKET_TYPE_IN_BAND => IncomingPacketType::InBand,
PACKET_TYPE_COMPLETION => IncomingPacketType::Completion,
PACKET_TYPE_TRANSFER_PAGES => {
let mut tph = TransferPageHeader::new_zeroed();
ring.read_aligned(ring_off as usize + 16, tph.as_mut_bytes());
IncomingPacketType::TransferPages(
tph.transfer_page_set_id,
tph.range_count,
RingRange {
off: ring_off + 24,
size: desc.data_offset8 as u32 * 8 - 24,
},
)
}
PACKET_TYPE_GPA_DIRECT => {
let mut gph = GpaDirectHeader::new_zeroed();
ring.read_aligned(ring_off as usize + 16, gph.as_mut_bytes());
if gph.range_count == 0 {
return Err(ReadError::Corrupt(
Error::InvalidDescriptorGpaDirectRangeCount,
));
}
IncomingPacketType::GpaDirect(
gph.range_count,
RingRange {
off: ring_off + 24,
size: desc.data_offset8 as u32 * 8 - 24,
},
)
}
_ => return Err(ReadError::Corrupt(Error::InvalidDescriptorType)),
};
let payload = RingRange {
off: ring_off + desc.data_offset8 as u32 * 8,
size: (desc.length8 - desc.data_offset8) as u32 * 8,
};
Ok((
len,
IncomingPacket {
transaction_id,
typ,
payload,
},
))
}
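/// The number of 32-bit control words at the start of a ring buffer.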
pub const CONTROL_WORD_COUNT: usize = 32;
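/// Memory backing a ring buffer: [`CONTROL_WORD_COUNT`] 32-bit control words
/// plus a circularly addressed data region.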
pub trait RingMem: Send {
fn control(&self) -> &[AtomicU32; CONTROL_WORD_COUNT];
fn read_at(&self, addr: usize, data: &mut [u8]);
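/// Reads from the data region like `read_at`, but `addr` and `data.len()`
/// must be multiples of 8. Implementations may override this with wider
/// atomic accesses.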
fn read_aligned(&self, addr: usize, data: &mut [u8]) {
debug_assert!(addr % 8 == 0);
debug_assert!(data.len() % 8 == 0);
self.read_at(addr, data)
}
fn write_at(&self, addr: usize, data: &[u8]);
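/// Writes to the data region like `write_at`, with the same 8-byte
/// alignment requirements as `read_aligned`.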
fn write_aligned(&self, addr: usize, data: &[u8]) {
debug_assert!(addr % 8 == 0);
debug_assert!(data.len() % 8 == 0);
self.write_at(addr, data)
}
fn len(&self) -> usize;
}
impl<T: RingMem + Sync> RingMem for &'_ T {
fn control(&self) -> &[AtomicU32; CONTROL_WORD_COUNT] {
(*self).control()
}
fn read_at(&self, addr: usize, data: &mut [u8]) {
(*self).read_at(addr, data)
}
fn write_at(&self, addr: usize, data: &[u8]) {
(*self).write_at(addr, data)
}
fn len(&self) -> usize {
(*self).len()
}
fn read_aligned(&self, addr: usize, data: &mut [u8]) {
(*self).read_aligned(addr, data)
}
fn write_aligned(&self, addr: usize, data: &[u8]) {
(*self).write_aligned(addr, data)
}
}
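/// A [`RingMem`] backed by a flat allocation in the local process.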
#[derive(Clone)]
pub struct FlatRingMem {
inner: Arc<FlatRingInner>,
}
struct FlatRingInner {
control: [AtomicU32; CONTROL_WORD_COUNT],
data: Vec<AtomicU8>,
}
impl FlatRingMem {
pub fn new(len: usize) -> Self {
let mut data = Vec::new();
data.resize_with(len, Default::default);
Self {
inner: Arc::new(FlatRingInner {
control: [0; CONTROL_WORD_COUNT].map(Into::into),
data,
}),
}
}
}
impl Debug for FlatRingMem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FlatRingMem").finish()
}
}
impl RingMem for FlatRingMem {
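// Offsets may point up to one ring length past the end of the buffer, so
// reduce them into range and split any access that crosses the wrap point.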
fn read_at(&self, mut addr: usize, data: &mut [u8]) {
if addr > self.len() {
addr -= self.len();
}
if addr + data.len() <= self.len() {
self.inner.data[addr..addr + data.len()].atomic_read(data);
} else {
let data_len = data.len();
let (first, last) = data.split_at_mut(self.len() - addr);
self.inner.data[addr..].atomic_read(first);
self.inner.data[..data_len - (self.len() - addr)].atomic_read(last);
}
}
fn write_at(&self, mut addr: usize, data: &[u8]) {
if addr > self.len() {
addr -= self.len();
}
if addr + data.len() <= self.len() {
self.inner.data[addr..addr + data.len()].atomic_write(data);
} else {
let (first, last) = data.split_at(self.len() - addr);
self.inner.data[addr..].atomic_write(first);
self.inner.data[..data.len() - (self.len() - addr)].atomic_write(last);
}
}
fn control(&self) -> &[AtomicU32; CONTROL_WORD_COUNT] {
&self.inner.control
}
fn len(&self) -> usize {
self.inner.data.len()
}
}
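/// Ring memory organized as a 4KB control page plus a set of 4KB data pages.
/// Accesses that wrap the ring may reference pages at indexes at or beyond
/// `data_page_count()`, so implementations are expected to resolve such
/// indexes to the corresponding wrapped data page (for example, by mapping
/// the data pages twice).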
pub trait PagedMemory: Send {
fn control(&self) -> &[AtomicU8; PAGE_SIZE];
fn data_page_count(&self) -> usize;
fn data(&self, page: usize) -> &[AtomicU8; PAGE_SIZE];
}
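/// An adapter implementing [`RingMem`] over a [`PagedMemory`], overriding
/// the aligned accessors to use one 64-bit atomic operation per 8-byte word.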
#[derive(Debug, Clone)]
pub struct PagedRingMem<T>(T);
impl<T: PagedMemory> PagedRingMem<T> {
pub fn new(inner: T) -> Self {
Self(inner)
}
}
impl<T: PagedMemory> RingMem for PagedRingMem<T> {
fn len(&self) -> usize {
self.0.data_page_count() * PAGE_SIZE
}
fn read_at(&self, mut addr: usize, mut data: &mut [u8]) {
while !data.is_empty() {
let page = addr / PAGE_SIZE;
let offset = addr % PAGE_SIZE;
let offset_end = PAGE_SIZE.min(offset + data.len());
let len = offset_end - offset;
let (this, next) = data.split_at_mut(len);
self.0.data(page)[offset..offset_end].atomic_read(this);
addr += len;
data = next;
}
}
fn write_at(&self, mut addr: usize, mut data: &[u8]) {
while !data.is_empty() {
let page = addr / PAGE_SIZE;
let offset = addr % PAGE_SIZE;
let offset_end = PAGE_SIZE.min(offset + data.len());
let len = offset_end - offset;
let (this, next) = data.split_at(len);
self.0.data(page)[offset..offset_end].atomic_write(this);
addr += len;
data = next;
}
}
#[inline]
fn read_aligned(&self, addr: usize, data: &mut [u8]) {
debug_assert!(addr % 8 == 0);
debug_assert!(data.len() % 8 == 0);
for (i, b) in data.chunks_exact_mut(8).enumerate() {
let addr = (addr & !7) + i * 8;
let page = addr / PAGE_SIZE;
let offset = addr % PAGE_SIZE;
b.copy_from_slice(
&self.0.data(page)[offset..offset + 8]
.as_atomic::<AtomicU64>()
.unwrap()
.load(Ordering::Relaxed)
.to_ne_bytes(),
);
}
}
#[inline]
fn write_aligned(&self, addr: usize, data: &[u8]) {
debug_assert!(addr % 8 == 0);
debug_assert!(data.len() % 8 == 0);
for (i, b) in data.chunks_exact(8).enumerate() {
let addr = (addr & !7) + i * 8;
let page = addr / PAGE_SIZE;
let offset = addr % PAGE_SIZE;
self.0.data(page)[offset..offset + 8]
.as_atomic::<AtomicU64>()
.unwrap()
.store(u64::from_ne_bytes(b.try_into().unwrap()), Ordering::Relaxed);
}
}
#[inline]
fn control(&self) -> &[AtomicU32; CONTROL_WORD_COUNT] {
self.0.control().as_atomic_slice().unwrap()[..CONTROL_WORD_COUNT]
.try_into()
.unwrap()
}
}
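/// A packet to be written to the outgoing ring. `size` is the payload size
/// in bytes; the payload itself is written by the caller into the
/// [`RingRange`] returned by [`OutgoingRing::write`].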
#[derive(Debug)]
pub struct OutgoingPacket<'a> {
pub transaction_id: u64,
pub size: usize,
pub typ: OutgoingPacketType<'a>,
}
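/// The type of an outgoing packet, with any type-specific data (external
/// GPA ranges or transfer page ranges).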
#[derive(Debug, Copy, Clone)]
pub enum OutgoingPacketType<'a> {
InBandNoCompletion,
InBandWithCompletion,
Completion,
GpaDirect(&'a [PagedRange<'a>]),
TransferPages(u16, &'a [TransferPageRange]),
}
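/// Computes the total ring space consumed by a packet: the descriptor, any
/// type-specific header, the payload rounded up to 8 bytes, and the footer.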
pub struct PacketSize(());
impl PacketSize {
pub const fn in_band(payload_len: usize) -> usize {
size_of::<PacketDescriptor>() + ((payload_len + 7) & !7) + size_of::<Footer>()
}
pub const fn completion(payload_len: usize) -> usize {
Self::in_band(payload_len)
}
pub const fn transfer_pages(count: usize, payload_len: usize) -> usize {
Self::in_band(payload_len)
+ size_of::<TransferPageHeader>()
+ count * size_of::<TransferPageRange>()
}
}
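/// A ring buffer half (incoming or outgoing), providing access to its
/// underlying memory for [`RingRange`] readers and writers.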
pub trait Ring {
type Memory: RingMem;
fn mem(&self) -> &Self::Memory;
}
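/// The receiving side of a ring buffer. Packets are parsed with
/// [`IncomingRing::read`] against an [`IncomingOffset`] and released back to
/// the sender with [`IncomingRing::commit_read`].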
#[derive(Debug)]
pub struct IncomingRing<M: RingMem> {
inner: InnerRing<M>,
}
impl<M: RingMem> Inspect for IncomingRing<M> {
fn inspect(&self, req: inspect::Request<'_>) {
self.inner.inspect(req);
}
}
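/// A cached snapshot of the incoming ring's offsets. [`IncomingRing::read`]
/// advances `next_out` locally; nothing is released back to the sender until
/// [`IncomingRing::commit_read`].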
#[derive(Debug, Clone, Inspect)]
pub struct IncomingOffset {
#[inspect(hex)]
cached_in: u32,
#[inspect(hex)]
committed_out: u32,
#[inspect(hex)]
next_out: u32,
}
impl IncomingOffset {
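/// Reverts any uncommitted reads, restoring `next_out` to the last
/// committed offset.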
pub fn revert(&mut self) {
self.next_out = self.committed_out;
}
}
impl<M: RingMem> Ring for IncomingRing<M> {
type Memory = M;
fn mem(&self) -> &Self::Memory {
&self.inner.mem
}
}
impl<M: RingMem> IncomingRing<M> {
pub fn new(mem: M) -> Result<Self, Error> {
let inner = InnerRing::new(mem)?;
let control = inner.control();
control.interrupt_mask().store(1, Ordering::Relaxed);
Ok(Self { inner })
}
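/// Returns true if the ring's feature bits advertise `pending_send_size`
/// support.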
pub fn supports_pending_send_size(&self) -> bool {
let feature_bits = self.inner.control().feature_bits().load(Ordering::Relaxed);
(feature_bits & FEATURE_SUPPORTS_PENDING_SEND_SIZE) != 0
}
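/// Sets or clears the interrupt mask. While the mask is set, the opposite
/// endpoint will not signal this one on writes.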
pub fn set_interrupt_mask(&self, state: bool) {
self.inner
.control()
.interrupt_mask()
.store(state as u32, Ordering::SeqCst);
}
pub fn verify_interrupts_unmasked(&self) -> Result<(), Error> {
if self
.inner
.control()
.interrupt_mask()
.load(Ordering::Relaxed)
== 0
{
Ok(())
} else {
Err(Error::InterruptsExternallyMasked)
}
}
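/// Captures a snapshot of the ring's current offsets for reading.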
pub fn incoming(&self) -> Result<IncomingOffset, Error> {
let control = self.inner.control();
let next_out = self
.inner
.validate(control.outp().load(Ordering::Relaxed))?;
let cached_in = self.inner.validate(control.inp().load(Ordering::Relaxed))?;
Ok(IncomingOffset {
next_out,
cached_in,
committed_out: next_out,
})
}
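/// Returns whether a packet is available to read, refreshing the cached
/// `in` offset from shared memory if necessary.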
pub fn can_read(&self, incoming: &mut IncomingOffset) -> Result<bool, Error> {
let can_read = if incoming.next_out != incoming.cached_in {
true
} else {
let inp = self
.inner
.validate(self.inner.control().inp().load(Ordering::Acquire))?;
incoming.cached_in = inp;
incoming.next_out != inp
};
Ok(can_read)
}
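/// Commits any reads made since the snapshot was taken or last committed,
/// freeing the ring space. Returns true if the opposite endpoint requested
/// a signal once its pending send size became available.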
pub fn commit_read(&self, ptrs: &mut IncomingOffset) -> bool {
if ptrs.committed_out == ptrs.next_out {
return false;
}
let control = self.inner.control();
control.outp().store(ptrs.next_out, Ordering::SeqCst);
let pending_send_size = control.pending_send_size().load(Ordering::SeqCst);
let pending_send_size = pending_send_size & !7;
let signal = if pending_send_size != 0 {
if let Ok(inp) = self.inner.validate(control.inp().load(Ordering::SeqCst)) {
let old_free = self.inner.free(inp, ptrs.committed_out);
let new_free = self.inner.free(inp, ptrs.next_out);
old_free < pending_send_size && new_free >= pending_send_size
} else {
false
}
} else {
false
};
ptrs.committed_out = ptrs.next_out;
signal
}
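/// Parses the next available packet, advancing the uncommitted `next_out`
/// offset past it. The payload remains in ring memory until the read is
/// committed.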
pub fn read(&self, ptrs: &mut IncomingOffset) -> Result<IncomingPacket, ReadError> {
let outp = ptrs.next_out;
let mut inp = ptrs.cached_in;
if inp == outp {
inp = self
.inner
.validate(self.inner.control().inp().load(Ordering::Acquire))?;
if inp == outp {
return Err(ReadError::Empty);
}
ptrs.cached_in = inp;
}
let avail = self.inner.available(inp, outp);
if avail < 16 {
return Err(ReadError::Corrupt(Error::InvalidDataAvailable));
}
let (len, packet) = parse_packet(&self.inner.mem, outp, avail)?;
ptrs.next_out = self
.inner
.add_pointer(outp, len + size_of::<Footer>() as u32);
Ok(packet)
}
}
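/// The sending side of a ring buffer. Packets are written with
/// [`OutgoingRing::write`] against an [`OutgoingOffset`] and published with
/// [`OutgoingRing::commit_write`].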
#[derive(Debug)]
pub struct OutgoingRing<M: RingMem> {
inner: InnerRing<M>,
}
impl<M: RingMem> Inspect for OutgoingRing<M> {
fn inspect(&self, req: inspect::Request<'_>) {
self.inner.inspect(req);
}
}
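/// A cached snapshot of the outgoing ring's offsets. [`OutgoingRing::write`]
/// advances `next_in` locally; nothing becomes visible to the receiver until
/// [`OutgoingRing::commit_write`].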
#[derive(Debug, Clone, Inspect)]
pub struct OutgoingOffset {
#[inspect(hex)]
cached_out: u32,
#[inspect(hex)]
committed_in: u32,
#[inspect(hex)]
next_in: u32,
}
impl OutgoingOffset {
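/// Reverts any uncommitted writes, restoring `next_in` to the last
/// committed offset.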
pub fn revert(&mut self) {
self.next_in = self.committed_in;
}
}
impl<M: RingMem> Ring for OutgoingRing<M> {
type Memory = M;
fn mem(&self) -> &Self::Memory {
&self.inner.mem
}
}
impl<M: RingMem> OutgoingRing<M> {
pub fn new(mem: M) -> Result<Self, Error> {
let inner = InnerRing::new(mem)?;
let control = inner.control();
control
.feature_bits()
.store(FEATURE_SUPPORTS_PENDING_SEND_SIZE, Ordering::Relaxed);
control.pending_send_size().store(0, Ordering::Relaxed);
Ok(Self { inner })
}
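/// Captures a snapshot of the ring's current offsets for writing.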
pub fn outgoing(&self) -> Result<OutgoingOffset, Error> {
let control = self.inner.control();
let next_in = self.inner.validate(control.inp().load(Ordering::Relaxed))?;
let cached_out = self
.inner
.validate(control.outp().load(Ordering::Relaxed))?;
Ok(OutgoingOffset {
cached_out,
committed_in: next_in,
next_in,
})
}
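/// Asks the opposite endpoint to signal this one when at least `len` bytes
/// become free, as evaluated when it commits reads.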
pub fn set_pending_send_size(&self, len: usize) -> Result<(), Error> {
if len > self.maximum_packet_size() {
return Err(Error::InvalidMessageLength);
}
self.inner
.control()
.pending_send_size()
.store((len as u32 + 7) & !7, Ordering::SeqCst);
Ok(())
}
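/// The maximum total packet size this ring can hold.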
pub fn maximum_packet_size(&self) -> usize {
self.inner.len() as usize - 8
}
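/// Returns whether `len` bytes of free space are available, refreshing the
/// cached `out` offset from shared memory if necessary.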
pub fn can_write(&self, ptrs: &mut OutgoingOffset, len: usize) -> Result<bool, Error> {
let can_write = if self.inner.free(ptrs.next_in, ptrs.cached_out) as usize >= len {
true
} else {
let outp = self
.inner
.validate(self.inner.control().outp().load(Ordering::Relaxed))?;
ptrs.cached_out = outp;
self.inner.free(ptrs.next_in, outp) as usize >= len
};
Ok(can_write)
}
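/// Commits any writes made since the snapshot was taken or last committed,
/// making them visible to the opposite endpoint. Returns true if the
/// endpoint needs a signal: its interrupts are unmasked and the ring was
/// empty before this commit.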
pub fn commit_write(&self, ptrs: &mut OutgoingOffset) -> bool {
if ptrs.committed_in == ptrs.next_in {
return false;
}
let inp = ptrs.next_in;
let control = self.inner.control();
control.inp().store(inp, Ordering::SeqCst);
let needs_interrupt = control.interrupt_mask().load(Ordering::SeqCst) == 0
&& control.outp().load(Ordering::SeqCst) == ptrs.committed_in;
ptrs.committed_in = inp;
needs_interrupt
}
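/// Writes a packet's descriptor, type-specific headers, and footer to the
/// ring, returning the [`RingRange`] into which the caller should write the
/// `size`-byte payload before committing.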
pub fn write(
&self,
ptrs: &mut OutgoingOffset,
packet: &OutgoingPacket<'_>,
) -> Result<RingRange, WriteError> {
const DESCRIPTOR_SIZE: usize = size_of::<PacketDescriptor>();
let (packet_type, header_size, flags) = match packet.typ {
OutgoingPacketType::InBandNoCompletion => (PACKET_TYPE_IN_BAND, DESCRIPTOR_SIZE, 0),
OutgoingPacketType::InBandWithCompletion => (
PACKET_TYPE_IN_BAND,
DESCRIPTOR_SIZE,
PACKET_FLAG_COMPLETION_REQUESTED,
),
OutgoingPacketType::Completion => (PACKET_TYPE_COMPLETION, DESCRIPTOR_SIZE, 0),
OutgoingPacketType::GpaDirect(ranges) => (
PACKET_TYPE_GPA_DIRECT,
DESCRIPTOR_SIZE
+ size_of::<GpaDirectHeader>()
+ ranges.iter().fold(0, |a, range| {
a + size_of::<GpaRange>() + size_of_val(range.gpns())
}),
PACKET_FLAG_COMPLETION_REQUESTED,
),
OutgoingPacketType::TransferPages(_, ranges) => (
PACKET_TYPE_TRANSFER_PAGES,
DESCRIPTOR_SIZE + size_of::<TransferPageHeader>() + size_of_val(ranges),
PACKET_FLAG_COMPLETION_REQUESTED,
),
};
let msg_len = (packet.size + header_size + 7) / 8 * 8;
let total_msg_len = (msg_len + size_of::<Footer>()) as u32;
if total_msg_len >= self.inner.len() - 8 {
return Err(WriteError::Corrupt(Error::InvalidMessageLength));
}
let inp = ptrs.next_in;
let mut outp = ptrs.cached_out;
if self.inner.free(inp, outp) < total_msg_len {
outp = self
.inner
.validate(self.inner.control().outp().load(Ordering::Relaxed))?;
if self.inner.free(inp, outp) < total_msg_len {
return Err(WriteError::Full(total_msg_len as usize));
}
ptrs.cached_out = outp;
}
let desc = PacketDescriptor {
packet_type,
data_offset8: header_size as u16 / 8,
length8: (msg_len / 8) as u16,
flags,
transaction_id: packet.transaction_id,
};
let footer = Footer {
reserved: 0,
offset: inp,
};
let off = inp as usize;
self.inner.mem.write_aligned(off, desc.as_bytes());
match packet.typ {
OutgoingPacketType::GpaDirect(ranges) => {
let mut writer = RingRange {
off: (off + DESCRIPTOR_SIZE) as u32,
size: header_size as u32,
}
.writer(self);
let gpa_header = GpaDirectHeader {
reserved: 0,
range_count: ranges.len() as u32,
};
writer
.write(gpa_header.as_bytes())
.map_err(|_| WriteError::Corrupt(Error::InvalidMessageLength))?;
for range in ranges {
let gpa_rng = GpaRange {
len: range.len() as u32,
offset: range.offset() as u32,
};
writer
.write(gpa_rng.as_bytes())
.map_err(|_| WriteError::Corrupt(Error::InvalidMessageLength))?;
writer
.write(range.gpns().as_bytes())
.map_err(|_| WriteError::Corrupt(Error::InvalidMessageLength))?;
}
}
OutgoingPacketType::TransferPages(tp_id, ranges) => {
let tp_header = TransferPageHeader {
transfer_page_set_id: tp_id,
reserved: 0,
range_count: ranges.len() as u32,
};
self.inner
.mem
.write_aligned(off + DESCRIPTOR_SIZE, tp_header.as_bytes());
for (i, range) in ranges.iter().enumerate() {
self.inner.mem.write_aligned(
off + DESCRIPTOR_SIZE + size_of_val(&tp_header) + i * 8,
range.as_bytes(),
);
}
}
_ => (),
}
self.inner
.mem
.write_aligned(off + msg_len, footer.as_bytes());
ptrs.next_in = self.inner.add_pointer(inp, total_msg_len);
Ok(RingRange {
off: inp + header_size as u32,
size: packet.size as u32,
})
}
}
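/// State common to both ring directions: the backing memory and the cached
/// size of the data region.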
struct InnerRing<M: RingMem> {
mem: M,
size: u32,
}
impl<M: RingMem> Inspect for InnerRing<M> {
fn inspect(&self, req: inspect::Request<'_>) {
req.respond()
.hex("ring_size", self.size)
.field("control", self.control());
}
}
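/// Inspects the state of the ring buffer in `mem` without constructing a
/// full incoming or outgoing ring.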
pub fn inspect_ring<M: RingMem>(mem: M, req: inspect::Request<'_>) {
let _ = InnerRing::new(mem).map(|ring| ring.inspect(req));
}
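/// Queries whether the ring's reader should be signaled: its interrupt mask
/// is clear and there is unread data in the ring.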
pub fn reader_needs_signal<M: RingMem>(mem: M) -> bool {
InnerRing::new(mem).is_ok_and(|ring| {
let control = ring.control();
control.interrupt_mask().load(Ordering::Relaxed) == 0
&& (control.inp().load(Ordering::Relaxed) != control.outp().load(Ordering::Relaxed))
})
}
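/// Queries whether the ring's writer should be signaled: it set a nonzero
/// pending send size that the ring's free space now satisfies.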
pub fn writer_needs_signal<M: RingMem>(mem: M) -> bool {
InnerRing::new(mem).is_ok_and(|ring| {
let control = ring.control();
let pending_size = control.pending_send_size().load(Ordering::Relaxed);
pending_size != 0
&& ring.free(
control.inp().load(Ordering::Relaxed),
control.outp().load(Ordering::Relaxed),
) >= pending_size
})
}
impl<M: RingMem> Debug for InnerRing<M> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("InnerRing")
.field("control", &self.control())
.field("size", &self.size)
.finish()
}
}
impl<M: RingMem> InnerRing<M> {
pub fn new(mem: M) -> Result<Self, Error> {
let ring_size = u32::try_from(mem.len()).map_err(|_| Error::InvalidRingMemory)?;
if ring_size % PAGE_SIZE as u32 != 0 {
return Err(Error::InvalidRingMemory);
}
let ring = InnerRing {
mem,
size: ring_size,
};
Ok(ring)
}
fn control(&self) -> Control<'_> {
Control(self.mem.control())
}
fn len(&self) -> u32 {
self.size
}
fn validate(&self, p: u32) -> Result<u32, Error> {
if p >= self.size || p % 8 != 0 {
Err(Error::InvalidRingPointer)
} else {
Ok(p)
}
}
fn add_pointer(&self, p: u32, off: u32) -> u32 {
let np = p + off;
if np >= self.size {
assert!(np < self.size * 2);
np - self.size
} else {
np
}
}
fn available(&self, inp: u32, outp: u32) -> u32 {
if inp > outp {
inp - outp
} else {
self.size + inp - outp
}
}
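/// Computes the free space between the `in` and `out` offsets. Eight bytes
/// are always left unused so that `inp == outp` unambiguously means the
/// ring is empty rather than full.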
fn free(&self, inp: u32, outp: u32) -> u32 {
if outp > inp {
outp - inp - 8
} else {
self.size - (inp - outp) - 8
}
}
}
#[cfg(test)]
mod tests {
use super::*;
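// Writes a single in-band packet, returning whether the opposite endpoint
// needs a signal, or None if the ring is full.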
fn write_simple<T: RingMem>(out_ring: &mut OutgoingRing<T>, buf: &[u8]) -> Option<bool> {
let mut outgoing = out_ring.outgoing().unwrap();
match out_ring.write(
&mut outgoing,
&OutgoingPacket {
typ: OutgoingPacketType::InBandNoCompletion,
size: buf.len(),
transaction_id: 0,
},
) {
Ok(range) => {
range.writer(out_ring).write(buf).unwrap();
Some(out_ring.commit_write(&mut outgoing))
}
Err(WriteError::Full(_)) => None,
Err(err) => panic!("{}", err),
}
}
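// Reads the next packet's payload, returning it along with whether the
// opposite endpoint needs a signal.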
fn read_simple<T: RingMem>(in_ring: &mut IncomingRing<T>) -> (Vec<u8>, bool) {
let mut incoming = in_ring.incoming().unwrap();
let msg = in_ring
.read(&mut incoming)
.unwrap()
.payload
.reader(in_ring)
.read_all()
.unwrap();
let signal = in_ring.commit_read(&mut incoming);
(msg, signal)
}
#[test]
fn test_ring() {
let rmem = FlatRingMem::new(16384);
let mut in_ring = IncomingRing::new(&rmem).unwrap();
in_ring.set_interrupt_mask(false);
let mut out_ring = OutgoingRing::new(&rmem).unwrap();
let p = &[1, 2, 3, 4, 5, 6, 7, 8];
assert!(write_simple(&mut out_ring, p).unwrap());
let (msg, signal) = read_simple(&mut in_ring);
assert!(!signal);
assert_eq!(p, &msg[..]);
}
#[test]
fn test_interrupt_mask() {
let rmem = FlatRingMem::new(16384);
let mut in_ring = IncomingRing::new(&rmem).unwrap();
let mut out_ring = OutgoingRing::new(&rmem).unwrap();
assert!(!write_simple(&mut out_ring, &[1, 2, 3]).unwrap());
assert!(!read_simple(&mut in_ring).1);
in_ring.set_interrupt_mask(false);
assert!(write_simple(&mut out_ring, &[1, 2, 3]).unwrap());
assert!(!read_simple(&mut in_ring).1);
}
#[test]
fn test_pending_send_size() {
let rmem = FlatRingMem::new(16384);
let mut in_ring = IncomingRing::new(&rmem).unwrap();
let mut out_ring = OutgoingRing::new(&rmem).unwrap();
write_simple(&mut out_ring, &[1; 4000]).unwrap();
write_simple(&mut out_ring, &[2; 4000]).unwrap();
write_simple(&mut out_ring, &[3; 4000]).unwrap();
write_simple(&mut out_ring, &[4; 4000]).unwrap();
assert!(write_simple(&mut out_ring, &[5; 4000]).is_none());
assert!(!read_simple(&mut in_ring).1);
write_simple(&mut out_ring, &[5; 4000]).unwrap();
assert!(write_simple(&mut out_ring, &[6; 4000]).is_none());
out_ring.set_pending_send_size(4024 * 2).unwrap();
assert!(!read_simple(&mut in_ring).1);
assert!(read_simple(&mut in_ring).1);
assert!(!read_simple(&mut in_ring).1);
}
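// Extra coverage (a minimal sketch, not derived from an existing test):
// reverting an IncomingOffset before commit_read rewinds the local read
// pointer, so the same packet can be parsed again ("peek" semantics).
#[test]
fn test_revert_rereads_packet() {
let rmem = FlatRingMem::new(16384);
let in_ring = IncomingRing::new(&rmem).unwrap();
let mut out_ring = OutgoingRing::new(&rmem).unwrap();
write_simple(&mut out_ring, &[9, 8, 7, 6]).unwrap();
let mut incoming = in_ring.incoming().unwrap();
let first = in_ring
.read(&mut incoming)
.unwrap()
.payload
.reader(&in_ring)
.read_all()
.unwrap();
// revert() resets next_out to the last committed offset, undoing the
// uncommitted read above.
incoming.revert();
let second = in_ring
.read(&mut incoming)
.unwrap()
.payload
.reader(&in_ring)
.read_all()
.unwrap();
assert_eq!(first, second);
assert_eq!(&first[..], &[9, 8, 7, 6]);
in_ring.commit_read(&mut incoming);
}
// Also a sketch: PacketSize::in_band should match the ring space actually
// consumed by one in-band packet (descriptor + aligned payload + footer),
// as observed through the outgoing ring's in-pointer movement.
#[test]
fn test_packet_size_accounting() {
let rmem = FlatRingMem::new(16384);
let mut out_ring = OutgoingRing::new(&rmem).unwrap();
let before = out_ring.outgoing().unwrap();
write_simple(&mut out_ring, &[0; 100]).unwrap();
let after = out_ring.outgoing().unwrap();
let used = after.next_in.wrapping_sub(before.next_in) as usize;
assert_eq!(used, PacketSize::in_band(100));
}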
}