#![expect(missing_docs)]
pub mod loopback;
pub mod null;
pub mod resolve;
pub mod tests;
use async_trait::async_trait;
use futures::FutureExt;
use futures::StreamExt;
use futures::TryFutureExt;
use futures::lock::Mutex;
use futures_concurrency::future::Race;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use inspect::InspectMut;
use mesh::rpc::Rpc;
use mesh::rpc::RpcSend;
use null::NullEndpoint;
use pal_async::driver::Driver;
use std::future::pending;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
/// Configuration for a single queue, passed (one entry per element of the
/// `config` vector) to [`Endpoint::get_queues`].
pub struct QueueConfig<'a> {
// NOTE(review): presumably the accessor for the receive buffers named by
// `RxId` values — confirm against the queue implementations.
/// Buffer accessor handed to the queue.
pub pool: Box<dyn BufferAccess>,
// NOTE(review): presumably buffers that are available for receives as soon
// as the queue is created — confirm.
/// Receive buffer IDs supplied at queue creation time.
pub initial_rx: &'a [RxId],
/// Async driver handed to the queue implementation.
pub driver: Box<dyn Driver>,
}
/// A network backend endpoint from which packet queues are obtained.
///
/// Only `endpoint_type`, `get_queues` and `stop` are required. The provided
/// defaults describe an endpoint with unordered completions, no transmit
/// offloads, a single queue, no fast transmit completions, and no support for
/// switching the data path to a guest VF.
#[async_trait]
pub trait Endpoint: Send + Sync + InspectMut {
/// Returns a static string identifying the kind of endpoint.
fn endpoint_type(&self) -> &'static str;
// NOTE(review): whether `queues` is cleared or appended to, and how `config`
// entries map to returned queues, is up to the implementation — confirm.
/// Creates the endpoint's queues from `config` (and the optional RSS
/// settings in `rss`), returning them through `queues`.
async fn get_queues(
&mut self,
config: Vec<QueueConfig<'_>>,
rss: Option<&RssConfig<'_>>,
queues: &mut Vec<Box<dyn Queue>>,
) -> anyhow::Result<()>;
/// Stops the endpoint.
async fn stop(&mut self);
// NOTE(review): presumably "ordered" refers to completion order — confirm.
/// Whether the endpoint is ordered. Defaults to `false`.
fn is_ordered(&self) -> bool {
false
}
/// The transmit offloads the endpoint supports. Defaults to
/// `TxOffloadSupport::default()`, i.e. every offload flag `false`.
fn tx_offload_support(&self) -> TxOffloadSupport {
TxOffloadSupport::default()
}
/// Multi-queue parameters. Defaults to one queue and an indirection table
/// size of zero.
fn multiqueue_support(&self) -> MultiQueueSupport {
MultiQueueSupport {
max_queues: 1,
indirection_table_size: 0,
}
}
// NOTE(review): presumably lets callers assume transmits complete quickly —
// confirm the exact contract with the consumers of this flag.
/// Whether the endpoint offers fast transmit completions. Defaults to
/// `false`.
fn tx_fast_completions(&self) -> bool {
false
}
/// Requests that the data path be switched to (or away from) the guest VF.
///
/// The default implementation always fails with
/// "Unsupported in current endpoint".
async fn set_data_path_to_guest_vf(&self, _use_vf: bool) -> anyhow::Result<()> {
Err(anyhow::Error::msg("Unsupported in current endpoint"))
}
/// Queries whether the data path is currently directed to the guest VF.
///
/// The default implementation always fails with
/// "Unsupported in current endpoint".
async fn get_data_path_to_guest_vf(&self) -> anyhow::Result<bool> {
Err(anyhow::Error::msg("Unsupported in current endpoint"))
}
/// Waits until the endpoint has an [`EndpointAction`] to report.
///
/// The default implementation never completes.
async fn wait_for_endpoint_action(&mut self) -> EndpointAction {
pending().await
}
// NOTE(review): units are presumably bits per second (10 Gbps) — confirm.
/// The link speed. Defaults to `10 * 1000 * 1000 * 1000`.
fn link_speed(&self) -> u64 {
10 * 1000 * 1000 * 1000
}
}
/// Multi-queue parameters reported by [`Endpoint::multiqueue_support`].
#[derive(Debug, Copy, Clone)]
pub struct MultiQueueSupport {
/// Maximum number of queues; the trait default reports `1`.
pub max_queues: u16,
// NOTE(review): presumably the number of entries in the RSS indirection
// table — confirm.
/// Indirection table size; the trait default reports `0`.
pub indirection_table_size: u16,
}
/// Transmit offload flags reported by [`Endpoint::tx_offload_support`].
///
/// The derived `Default` sets every flag to `false`.
// NOTE(review): per-field meanings below are inferred from the field names and
// the matching `offload_*` fields on `TxMetadata` — confirm.
#[derive(Debug, Copy, Clone, Default)]
pub struct TxOffloadSupport {
/// IPv4 header checksum offload.
pub ipv4_header: bool,
/// TCP checksum offload.
pub tcp: bool,
/// UDP checksum offload.
pub udp: bool,
/// TCP segmentation offload.
pub tso: bool,
}
/// RSS settings optionally passed to [`Endpoint::get_queues`].
// NOTE(review): "RSS" is presumably receive-side scaling; the layout of `key`
// and the meaning of `flags` are not visible in this file — confirm.
#[derive(Debug, Clone)]
pub struct RssConfig<'a> {
/// Raw key bytes.
pub key: &'a [u8],
/// Indirection table entries.
pub indirection_table: &'a [u16],
/// Flag bits.
pub flags: u32, }
/// A single packet queue produced by [`Endpoint::get_queues`].
// NOTE(review): the per-method contracts below are inferred from the
// signatures; this file contains no implementation — confirm against the
// concrete queues.
#[async_trait]
pub trait Queue: Send + InspectMut {
/// Informs the queue of a new target VP. The default implementation
/// ignores `target_vp`.
async fn update_target_vp(&mut self, target_vp: u32) {
let _ = target_vp;
}
/// Polls the queue for readiness, registering `cx` for wakeup when it
/// returns `Poll::Pending`.
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>;
/// Hands the receive buffers in `done` to the queue.
fn rx_avail(&mut self, done: &[RxId]);
/// Polls for receives, filling `packets` with buffer IDs; presumably the
/// returned count is the number of entries written.
fn rx_poll(&mut self, packets: &mut [RxId]) -> anyhow::Result<usize>;
// NOTE(review): the meaning of the `(bool, usize)` pair is not visible here
// — presumably (completed synchronously, segments consumed); confirm.
/// Submits transmit `segments` to the queue.
fn tx_avail(&mut self, segments: &[TxSegment]) -> anyhow::Result<(bool, usize)>;
/// Polls for transmit completions, filling `done` with transmit IDs;
/// presumably the returned count is the number of entries written.
fn tx_poll(&mut self, done: &mut [TxId]) -> anyhow::Result<usize>;
/// Returns the queue's buffer accessor, if it has one.
fn buffer_access(&mut self) -> Option<&mut dyn BufferAccess>;
}
/// Access to the buffers identified by [`RxId`] and to the guest memory
/// they live in.
pub trait BufferAccess: 'static + Send {
/// The guest memory accessor (used by `linearize` to read transmit data).
fn guest_memory(&self) -> &GuestMemory;
/// Writes `data` into the buffer identified by `id`.
fn write_data(&mut self, id: RxId, data: &[u8]);
/// Returns the segments describing buffer `id`.
fn guest_addresses(&mut self, id: RxId) -> &[RxBufferSegment];
// NOTE(review): units are presumably bytes — confirm.
/// Returns the capacity of buffer `id`.
fn capacity(&self, id: RxId) -> u32;
/// Records `metadata` for the packet in buffer `id`.
fn write_header(&mut self, id: RxId, metadata: &RxMetadata);
/// Writes a whole packet. The default implementation calls `write_data`
/// followed by `write_header`.
fn write_packet(&mut self, id: RxId, metadata: &RxMetadata, data: &[u8]) {
self.write_data(id, data);
self.write_header(id, metadata);
}
}
/// Identifier of a receive buffer; a transparent wrapper around `u32`.
#[derive(Debug, Copy, Clone)]
#[repr(transparent)]
pub struct RxId(pub u32);
/// One segment of a receive buffer, as returned by
/// [`BufferAccess::guest_addresses`].
// NOTE(review): `gpa` is presumably a guest physical address and `len` a byte
// count — confirm.
#[derive(Debug, Copy, Clone)]
pub struct RxBufferSegment {
/// Address of the segment.
pub gpa: u64,
/// Length of the segment.
pub len: u32,
}
/// Metadata describing a received packet, passed to
/// [`BufferAccess::write_header`] and [`BufferAccess::write_packet`].
#[derive(Debug, Copy, Clone)]
pub struct RxMetadata {
// NOTE(review): presumably the byte offset of the packet within its buffer
// — confirm.
/// Offset of the packet data.
pub offset: usize,
// NOTE(review): presumably in bytes — confirm.
/// Length of the packet data.
pub len: usize,
/// State of the IP checksum.
pub ip_checksum: RxChecksumState,
/// State of the L4 checksum.
pub l4_checksum: RxChecksumState,
/// The L4 protocol of the packet.
pub l4_protocol: L4Protocol,
}
impl Default for RxMetadata {
fn default() -> Self {
Self {
offset: 0,
len: 0,
ip_checksum: RxChecksumState::Unknown,
l4_checksum: RxChecksumState::Unknown,
l4_protocol: L4Protocol::Unknown,
}
}
}
/// Layer 3 protocol of a packet (see `TxMetadata::l3_protocol`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum L3Protocol {
/// Not known; this is the value used by `TxMetadata::default()`.
Unknown,
/// IPv4.
Ipv4,
/// IPv6.
Ipv6,
}
/// Layer 4 protocol of a packet (see `RxMetadata::l4_protocol`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum L4Protocol {
/// Not known; this is the value used by `RxMetadata::default()`.
Unknown,
/// TCP.
Tcp,
/// UDP.
Udp,
}
/// Result of checksum evaluation for a received packet.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum RxChecksumState {
    /// The checksum state is not known; used by `RxMetadata::default()`.
    Unknown,
    /// The checksum is good.
    Good,
    /// The checksum is bad.
    Bad,
    // NOTE(review): presumably the packet was validated but the checksum value
    // in its header is not correct (e.g. after coalescing) — confirm.
    /// The checksum was validated even though the stored value is wrong.
    ValidatedButWrong,
}

impl RxChecksumState {
    /// Returns `true` for the two validated states, `Good` and
    /// `ValidatedButWrong`, and `false` for `Unknown` and `Bad`.
    pub fn is_valid(self) -> bool {
        matches!(self, Self::Good | Self::ValidatedButWrong)
    }
}
/// Identifier of a transmit packet; a transparent wrapper around `u32`.
#[derive(Debug, Copy, Clone)]
#[repr(transparent)]
pub struct TxId(pub u32);
/// Position of a [`TxSegment`] within its packet.
#[derive(Debug, Clone)]
pub enum TxSegmentType {
/// The first segment of a packet, carrying the packet's metadata.
/// `packet_count` and `next_packet` require each packet to start with one.
Head(TxMetadata),
/// Any segment of a packet after the head.
Tail,
}
/// Metadata for a transmit packet, carried by its head segment.
// NOTE(review): the offload and header-length field meanings are inferred from
// their names; units of the `*_len` header fields are presumably bytes —
// confirm.
#[derive(Debug, Clone)]
pub struct TxMetadata {
/// Identifier of the transmit packet.
pub id: TxId,
/// Number of segments making up the packet, counting the head segment
/// (`next_packet` splits the segment list at this index).
pub segment_count: usize,
/// Total packet length in bytes (`linearize` allocates this many bytes and
/// asserts the segment lengths add up to it).
pub len: usize,
/// Offload the IP header checksum.
pub offload_ip_header_checksum: bool,
/// Offload the TCP checksum.
pub offload_tcp_checksum: bool,
/// Offload the UDP checksum.
pub offload_udp_checksum: bool,
/// Offload TCP segmentation.
pub offload_tcp_segmentation: bool,
/// The L3 protocol of the packet.
pub l3_protocol: L3Protocol,
/// Length of the L2 header.
pub l2_len: u8,
/// Length of the L3 header.
pub l3_len: u16,
/// Length of the L4 header.
pub l4_len: u8,
/// Maximum TCP segment size.
pub max_tcp_segment_size: u16,
}
impl Default for TxMetadata {
fn default() -> Self {
Self {
id: TxId(0),
segment_count: 0,
len: 0,
offload_ip_header_checksum: false,
offload_tcp_checksum: false,
offload_udp_checksum: false,
offload_tcp_segmentation: false,
l3_protocol: L3Protocol::Unknown,
l2_len: 0,
l3_len: 0,
l4_len: 0,
max_tcp_segment_size: 0,
}
}
}
/// One segment of a transmit packet.
#[derive(Debug, Clone)]
pub struct TxSegment {
/// Whether this is the head segment (with metadata) or a tail segment.
pub ty: TxSegmentType,
/// Guest memory address of the segment data (`linearize` reads from it via
/// `GuestMemory::read_at`).
pub gpa: u64,
/// Length of the segment data in bytes.
pub len: u32,
}
/// Counts the packets contained in `segments`.
///
/// The list is walked packet by packet: each packet must begin with a
/// `Head` segment, whose `segment_count` says how many segments to skip to
/// reach the next packet.
///
/// # Panics
///
/// Panics if a packet does not start with a `Head` segment, or if a
/// `segment_count` exceeds the number of remaining segments.
///
/// A `segment_count` of zero never advances the list, so the function does
/// not return in that case.
pub fn packet_count(mut segments: &[TxSegment]) -> usize {
    let mut count = 0;
    loop {
        let Some(first) = segments.first() else {
            return count;
        };
        match &first.ty {
            TxSegmentType::Head(metadata) => {
                segments = &segments[metadata.segment_count..];
            }
            TxSegmentType::Tail => unreachable!(),
        }
        count += 1;
    }
}
/// Splits the first packet off the front of `segments`.
///
/// Returns the packet's metadata, the segments belonging to that packet, and
/// the remaining segments.
///
/// # Panics
///
/// Panics if `segments` is empty, if the first segment is not a `Head`
/// segment, or if the head's `segment_count` exceeds `segments.len()`.
pub fn next_packet(segments: &[TxSegment]) -> (&TxMetadata, &[TxSegment], &[TxSegment]) {
    let TxSegmentType::Head(metadata) = &segments[0].ty else {
        unreachable!();
    };
    let (packet, remaining) = segments.split_at(metadata.segment_count);
    (metadata, packet, remaining)
}
/// Copies the first packet in `segments` out of guest memory into one
/// contiguous buffer, then advances `segments` past that packet.
///
/// # Errors
///
/// Returns the `GuestMemoryError` from the first failed guest memory read.
/// In that case `segments` is left unchanged.
///
/// # Panics
///
/// Panics under the same conditions as [`next_packet`], and if the segment
/// lengths do not add up to exactly the `len` recorded in the packet's
/// metadata.
pub fn linearize(
    pool: &dyn BufferAccess,
    segments: &mut &[TxSegment],
) -> Result<Vec<u8>, GuestMemoryError> {
    let (metadata, packet, remaining) = next_packet(segments);
    let guest_memory = pool.guest_memory();
    let mut data = vec![0; metadata.len];
    let mut filled = 0;
    for segment in packet {
        let end = filled + segment.len as usize;
        guest_memory.read_at(segment.gpa, &mut data[filled..end])?;
        filled = end;
    }
    // Every byte of the packet must have been supplied by some segment.
    assert_eq!(data.len(), filled);
    // Only consume the packet once it has been read successfully.
    *segments = remaining;
    Ok(data)
}
/// Action reported by [`Endpoint::wait_for_endpoint_action`].
#[derive(PartialEq, Debug)]
pub enum EndpointAction {
/// A restart is required. `DisconnectableEndpoint` reports this after an
/// endpoint has been connected or disconnected.
RestartRequired,
// NOTE(review): the `bool` presumably means "link is up" — confirm.
/// The link status changed.
LinkStatusNotify(bool),
}
/// Message sent from a `DisconnectableEndpointControl` to its
/// `DisconnectableEndpoint`.
enum DisconnectableEndpointUpdate {
/// Install the given endpoint as the active one.
EndpointConnected(Box<dyn Endpoint>),
/// Remove the active endpoint; the RPC is completed with the endpoint that
/// was removed, or `None` if there was none.
EndpointDisconnected(Rpc<(), Option<Box<dyn Endpoint>>>),
}
/// Control handle, created by `DisconnectableEndpoint::new`, used to connect
/// and disconnect the endpoint behind a `DisconnectableEndpoint`.
pub struct DisconnectableEndpointControl {
/// Channel on which connect/disconnect requests are sent.
send_update: mesh::Sender<DisconnectableEndpointUpdate>,
}
impl DisconnectableEndpointControl {
pub fn connect(&mut self, endpoint: Box<dyn Endpoint>) -> anyhow::Result<()> {
self.send_update
.send(DisconnectableEndpointUpdate::EndpointConnected(endpoint));
Ok(())
}
pub async fn disconnect(&mut self) -> anyhow::Result<Option<Box<dyn Endpoint>>> {
self.send_update
.call(DisconnectableEndpointUpdate::EndpointDisconnected, ())
.map_err(anyhow::Error::from)
.await
}
}
/// Snapshot of an endpoint's capability values, captured by
/// `DisconnectableEndpoint` when an endpoint is connected and served from the
/// snapshot afterwards.
pub struct DisconnectableEndpointCachedState {
/// Value of `Endpoint::is_ordered` at connect time.
is_ordered: bool,
/// Value of `Endpoint::tx_offload_support` at connect time.
tx_offload_support: TxOffloadSupport,
/// Value of `Endpoint::multiqueue_support` at connect time.
multiqueue_support: MultiQueueSupport,
/// Value of `Endpoint::tx_fast_completions` at connect time.
tx_fast_completions: bool,
/// Value of `Endpoint::link_speed` at connect time.
link_speed: u64,
}
/// An [`Endpoint`] wrapper whose underlying endpoint can be connected and
/// disconnected at runtime through a `DisconnectableEndpointControl`.
pub struct DisconnectableEndpoint {
/// The connected endpoint, if any.
endpoint: Option<Box<dyn Endpoint>>,
/// Fallback endpoint used while nothing is connected.
null_endpoint: Box<dyn Endpoint>,
/// Capability snapshot from the most recent connect; `None` until the
/// first endpoint has been connected.
cached_state: Option<DisconnectableEndpointCachedState>,
/// Receiving side of the control channel.
receive_update: Arc<Mutex<mesh::Receiver<DisconnectableEndpointUpdate>>>,
}
/// Inspection is forwarded to whichever endpoint is currently active (the
/// connected endpoint, or the null endpoint when nothing is connected).
impl InspectMut for DisconnectableEndpoint {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        let active = self.current_mut();
        active.inspect_mut(req)
    }
}
impl DisconnectableEndpoint {
    /// Creates a disconnected endpoint together with the control handle used
    /// to connect and disconnect it.
    ///
    /// Until an endpoint is connected, calls are served by a `NullEndpoint`
    /// and no capability snapshot exists.
    pub fn new() -> (Self, DisconnectableEndpointControl) {
        let (send_update, receive_update) = mesh::channel();
        let this = Self {
            endpoint: None,
            null_endpoint: Box::new(NullEndpoint::new()),
            cached_state: None,
            receive_update: Arc::new(Mutex::new(receive_update)),
        };
        (this, DisconnectableEndpointControl { send_update })
    }

    /// Returns the connected endpoint, or the null endpoint if none is
    /// connected.
    fn current(&self) -> &dyn Endpoint {
        match &self.endpoint {
            Some(connected) => connected.as_ref(),
            None => self.null_endpoint.as_ref(),
        }
    }

    /// Mutable counterpart of `current`.
    fn current_mut(&mut self) -> &mut dyn Endpoint {
        match &mut self.endpoint {
            Some(connected) => connected.as_mut(),
            None => self.null_endpoint.as_mut(),
        }
    }
}
#[async_trait]
impl Endpoint for DisconnectableEndpoint {
fn endpoint_type(&self) -> &'static str {
self.current().endpoint_type()
}
async fn get_queues(
&mut self,
config: Vec<QueueConfig<'_>>,
rss: Option<&RssConfig<'_>>,
queues: &mut Vec<Box<dyn Queue>>,
) -> anyhow::Result<()> {
self.current_mut().get_queues(config, rss, queues).await
}
async fn stop(&mut self) {
self.current_mut().stop().await
}
fn is_ordered(&self) -> bool {
self.cached_state
.as_ref()
.expect("Endpoint needs connected at least once before use")
.is_ordered
}
fn tx_offload_support(&self) -> TxOffloadSupport {
self.cached_state
.as_ref()
.expect("Endpoint needs connected at least once before use")
.tx_offload_support
}
fn multiqueue_support(&self) -> MultiQueueSupport {
self.cached_state
.as_ref()
.expect("Endpoint needs connected at least once before use")
.multiqueue_support
}
fn tx_fast_completions(&self) -> bool {
self.cached_state
.as_ref()
.expect("Endpoint needs connected at least once before use")
.tx_fast_completions
}
async fn set_data_path_to_guest_vf(&self, use_vf: bool) -> anyhow::Result<()> {
self.current().set_data_path_to_guest_vf(use_vf).await
}
async fn get_data_path_to_guest_vf(&self) -> anyhow::Result<bool> {
self.current().get_data_path_to_guest_vf().await
}
async fn wait_for_endpoint_action(&mut self) -> EndpointAction {
enum Message {
DisconnectableEndpointUpdate(DisconnectableEndpointUpdate),
UpdateFromEndpoint(EndpointAction),
}
let receiver = self.receive_update.clone();
let mut receive_update = receiver.lock().await;
let update = async {
match receive_update.next().await {
Some(m) => Message::DisconnectableEndpointUpdate(m),
None => {
pending::<()>().await;
unreachable!()
}
}
};
let ep_update = self
.current_mut()
.wait_for_endpoint_action()
.map(Message::UpdateFromEndpoint);
let m = (update, ep_update).race().await;
match m {
Message::DisconnectableEndpointUpdate(
DisconnectableEndpointUpdate::EndpointConnected(endpoint),
) => {
let old_endpoint = self.endpoint.take();
assert!(old_endpoint.is_none());
self.endpoint = Some(endpoint);
self.cached_state = Some(DisconnectableEndpointCachedState {
is_ordered: self.current().is_ordered(),
tx_offload_support: self.current().tx_offload_support(),
multiqueue_support: self.current().multiqueue_support(),
tx_fast_completions: self.current().tx_fast_completions(),
link_speed: self.current().link_speed(),
});
EndpointAction::RestartRequired
}
Message::DisconnectableEndpointUpdate(
DisconnectableEndpointUpdate::EndpointDisconnected(rpc),
) => {
let old_endpoint = self.endpoint.take();
self.endpoint = None;
rpc.handle(async |_| old_endpoint).await;
EndpointAction::RestartRequired
}
Message::UpdateFromEndpoint(update) => update,
}
}
fn link_speed(&self) -> u64 {
self.cached_state
.as_ref()
.expect("Endpoint needs connected at least once before use")
.link_speed
}
}