net_backend/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! This module defines a trait and implementations thereof for network
5//! backends.
6
7#![expect(missing_docs)]
8#![forbid(unsafe_code)]
9
10pub mod loopback;
11pub mod null;
12pub mod resolve;
13pub mod tests;
14
15use async_trait::async_trait;
16use bitfield_struct::bitfield;
17use futures::FutureExt;
18use futures::StreamExt;
19use futures::TryFutureExt;
20use futures::lock::Mutex;
21use futures_concurrency::future::Race;
22use guestmem::GuestMemory;
23use guestmem::GuestMemoryError;
24use inspect::InspectMut;
25use inspect_counters::Counter;
26use mesh::rpc::Rpc;
27use mesh::rpc::RpcSend;
28use null::NullEndpoint;
29use pal_async::driver::Driver;
30use std::future::pending;
31use std::sync::Arc;
32use std::task::Context;
33use std::task::Poll;
34use thiserror::Error;
35
36/// Per-queue configuration.
37pub struct QueueConfig<'a> {
38    pub pool: Box<dyn BufferAccess>,
39    pub initial_rx: &'a [RxId],
40    pub driver: Box<dyn Driver>,
41}
42
43/// A network endpoint.
44#[async_trait]
45pub trait Endpoint: Send + Sync + InspectMut {
46    /// Returns an informational endpoint type.
47    fn endpoint_type(&self) -> &'static str;
48
49    /// Initializes the queues associated with the endpoint.
50    ///
51    /// `initial_rx` contains the initial set of receives buffers that are
52    /// available.
53    async fn get_queues(
54        &mut self,
55        config: Vec<QueueConfig<'_>>,
56        rss: Option<&RssConfig<'_>>,
57        queues: &mut Vec<Box<dyn Queue>>,
58    ) -> anyhow::Result<()>;
59
60    /// Stops the endpoint.
61    ///
62    /// All queues returned via `get_queues` must have been dropped.
63    async fn stop(&mut self);
64
65    /// Specifies whether packets are always completed in order.
66    fn is_ordered(&self) -> bool {
67        false
68    }
69
70    /// Specifies the supported set of transmit offloads.
71    fn tx_offload_support(&self) -> TxOffloadSupport {
72        TxOffloadSupport::default()
73    }
74
75    /// Specifies parameters related to supporting multiple queues.
76    fn multiqueue_support(&self) -> MultiQueueSupport {
77        MultiQueueSupport {
78            max_queues: 1,
79            indirection_table_size: 0,
80        }
81    }
82
83    /// If true, transmits are guaranteed to complete quickly. This is used to
84    /// allow eliding tx notifications from the guest when there are already
85    /// some tx packets in flight.
86    fn tx_fast_completions(&self) -> bool {
87        false
88    }
89
90    /// Sets the current data path for packet flow (e.g. via vmbus synthnic or through virtual function).
91    /// This is only supported for endpoints that pair with an accelerated device.
92    async fn set_data_path_to_guest_vf(&self, _use_vf: bool) -> anyhow::Result<()> {
93        Err(anyhow::Error::msg("Unsupported in current endpoint"))
94    }
95
96    async fn get_data_path_to_guest_vf(&self) -> anyhow::Result<bool> {
97        Err(anyhow::Error::msg("Unsupported in current endpoint"))
98    }
99
100    /// On completion, the return value indicates the specific endpoint action to take.
101    async fn wait_for_endpoint_action(&mut self) -> EndpointAction {
102        pending().await
103    }
104
105    /// Link speed in bps.
106    fn link_speed(&self) -> u64 {
107        // Reporting a reasonable default value (10Gbps) here that the individual endpoints
108        // can overwrite.
109        10 * 1000 * 1000 * 1000
110    }
111}
112
/// Describes how an endpoint supports multiple queues.
#[derive(Debug, Copy, Clone)]
pub struct MultiQueueSupport {
    /// Upper bound on the number of queues the endpoint can provide.
    pub max_queues: u16,
    /// Number of entries in the endpoint's RSS indirection table.
    pub indirection_table_size: u16,
}
121
/// Transmit offloads an endpoint can perform. The default supports none.
#[derive(Debug, Copy, Clone, Default)]
pub struct TxOffloadSupport {
    /// Can compute the IPv4 header checksum.
    pub ipv4_header: bool,
    /// Can compute the TCP checksum.
    pub tcp: bool,
    /// Can compute the UDP checksum.
    pub udp: bool,
    /// Can perform TCP segmentation (TSO).
    pub tso: bool,
}
134
/// Receive-side scaling (RSS) configuration, passed to [`Endpoint::get_queues`].
#[derive(Debug, Clone)]
pub struct RssConfig<'a> {
    /// The RSS hash key.
    pub key: &'a [u8],
    /// The RSS indirection table (entries presumably select queue indices).
    pub indirection_table: &'a [u16],
    /// RSS flags.
    // TODO: the meaning of these bits is not yet defined.
    pub flags: u32,
}
141
142#[derive(Error, Debug)]
143pub enum TxError {
144    #[error("error requiring queue restart. {0}")]
145    TryRestart(#[source] anyhow::Error),
146    #[error("unrecoverable error. {0}")]
147    Fatal(#[source] anyhow::Error),
148}
149pub trait BackendQueueStats {
150    fn rx_errors(&self) -> Counter;
151    fn tx_errors(&self) -> Counter;
152    fn rx_packets(&self) -> Counter;
153    fn tx_packets(&self) -> Counter;
154}
155
156/// A trait for sending and receiving network packets.
157#[async_trait]
158pub trait Queue: Send + InspectMut {
159    /// Updates the queue's target VP.
160    async fn update_target_vp(&mut self, target_vp: u32) {
161        let _ = target_vp;
162    }
163
164    /// Polls the queue for readiness.
165    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>;
166
167    /// Makes receive buffers available for use by the device.
168    fn rx_avail(&mut self, done: &[RxId]);
169
170    /// Polls the device for receives.
171    fn rx_poll(&mut self, packets: &mut [RxId]) -> anyhow::Result<usize>;
172
173    /// Posts transmits to the device.
174    ///
175    /// Returns `Ok(false)` if the segments will complete asynchronously.
176    fn tx_avail(&mut self, segments: &[TxSegment]) -> anyhow::Result<(bool, usize)>;
177
178    /// Polls the device for transmit completions.
179    fn tx_poll(&mut self, done: &mut [TxId]) -> Result<usize, TxError>;
180
181    /// Get the buffer access.
182    fn buffer_access(&mut self) -> Option<&mut dyn BufferAccess>;
183
184    /// Get queue statistics
185    fn queue_stats(&self) -> Option<&dyn BackendQueueStats> {
186        None // Default implementation - not all queues implement stats
187    }
188}
189
190/// A trait for providing access to guest memory buffers.
191pub trait BufferAccess: 'static + Send {
192    /// The associated guest memory accessor.
193    fn guest_memory(&self) -> &GuestMemory;
194
195    /// Writes data to the specified buffer.
196    fn write_data(&mut self, id: RxId, data: &[u8]);
197
198    /// The guest addresses of the specified buffer.
199    fn guest_addresses(&mut self, id: RxId) -> &[RxBufferSegment];
200
201    /// The capacity of the specified buffer in bytes.
202    fn capacity(&self, id: RxId) -> u32;
203
204    /// Sets the packet metadata for the receive.
205    fn write_header(&mut self, id: RxId, metadata: &RxMetadata);
206
207    /// Writes the packet header and data in a single call.
208    fn write_packet(&mut self, id: RxId, metadata: &RxMetadata, data: &[u8]) {
209        self.write_data(id, data);
210        self.write_header(id, metadata);
211    }
212}
213
/// Identifies a receive buffer.
#[repr(transparent)]
#[derive(Debug, Copy, Clone)]
pub struct RxId(pub u32);
218
/// One contiguous guest-memory range belonging to a receive buffer.
#[derive(Debug, Copy, Clone)]
pub struct RxBufferSegment {
    /// Starting guest physical address of the range.
    pub gpa: u64,
    /// Size of the range in bytes.
    pub len: u32,
}
227
228/// Receive packet metadata.
229#[derive(Debug, Copy, Clone)]
230pub struct RxMetadata {
231    /// The offset of the packet data from the beginning of the receive buffer.
232    pub offset: usize,
233    /// The length of the packet in bytes.
234    pub len: usize,
235    /// The IP checksum validation state.
236    pub ip_checksum: RxChecksumState,
237    /// The L4 checksum validation state.
238    pub l4_checksum: RxChecksumState,
239    /// The L4 protocol.
240    pub l4_protocol: L4Protocol,
241}
242
243impl Default for RxMetadata {
244    fn default() -> Self {
245        Self {
246            offset: 0,
247            len: 0,
248            ip_checksum: RxChecksumState::Unknown,
249            l4_checksum: RxChecksumState::Unknown,
250            l4_protocol: L4Protocol::Unknown,
251        }
252    }
253}
254
/// The "L3" protocol: the IP layer.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum L3Protocol {
    /// Not identified as IPv4 or IPv6.
    Unknown,
    /// IP version 4.
    Ipv4,
    /// IP version 6.
    Ipv6,
}
262
/// The "L4" protocol: the TCP/UDP layer.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum L4Protocol {
    /// Not identified as TCP or UDP.
    Unknown,
    /// Transmission Control Protocol.
    Tcp,
    /// User Datagram Protocol.
    Udp,
}
270
/// The receive checksum state for a packet.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum RxChecksumState {
    /// The checksum was not evaluated.
    Unknown,
    /// The checksum value is correct.
    Good,
    /// The checksum value is incorrect.
    Bad,
    /// The checksum has been validated, but the value in the header is wrong.
    ///
    /// This occurs when LRO/RSC offload has been performed: multiple packet
    /// payloads are combined without updating the checksum in the first
    /// packet's header.
    ValidatedButWrong,
}

impl RxChecksumState {
    /// Returns true for the states in which the checksum has been validated:
    /// [`Good`](Self::Good) and [`ValidatedButWrong`](Self::ValidatedButWrong).
    pub fn is_valid(self) -> bool {
        matches!(self, Self::Good | Self::ValidatedButWrong)
    }
}
294
/// Identifies a transmit. Several segments may share the same ID at once.
#[repr(transparent)]
#[derive(Debug, Copy, Clone)]
pub struct TxId(pub u32);
299
300#[derive(Debug, Clone)]
301/// The segment type.
302pub enum TxSegmentType {
303    /// The start of a packet.
304    Head(TxMetadata),
305    /// A packet continuation.
306    Tail,
307}
308
309#[derive(Debug, Clone)]
310/// Transmit packet metadata.
311pub struct TxMetadata {
312    /// The transmit ID.
313    pub id: TxId,
314    /// The number of segments, including this one.
315    pub segment_count: u8,
316    /// Flags.
317    pub flags: TxFlags,
318    /// The total length of the packet in bytes.
319    pub len: u32,
320    /// The length of the Ethernet frame header. Only guaranteed to be set if
321    /// various offload flags are set.
322    pub l2_len: u8,
323    /// The length of the IP header. Only guaranteed to be set if various
324    /// offload flags are set.
325    pub l3_len: u16,
326    /// The length of the TCP header. Only guaranteed to be set if various
327    /// offload flags are set.
328    pub l4_len: u8,
329    /// The maximum TCP segment size, used for segmentation. Only guaranteed to
330    /// be set if [`TxFlags::offload_tcp_segmentation`] is set.
331    pub max_tcp_segment_size: u16,
332}
333
334/// Flags affecting transmit behavior.
335#[bitfield(u8)]
336pub struct TxFlags {
337    /// Offload IPv4 header checksum calculation.
338    ///
339    /// `l3_protocol`, `l2_len`, and `l3_len` must be set.
340    pub offload_ip_header_checksum: bool,
341    /// Offload the TCP checksum calculation.
342    ///
343    /// `l3_protocol`, `l2_len`, and `l3_len` must be set.
344    pub offload_tcp_checksum: bool,
345    /// Offload the UDP checksum calculation.
346    ///
347    /// `l3_protocol`, `l2_len`, and `l3_len` must be set.
348    pub offload_udp_checksum: bool,
349    /// Offload the TCP segmentation, allowing packets to be larger than the
350    /// MTU.
351    ///
352    /// `l3_protocol`, `l2_len`, `l3_len`, `l4_len`, and `tcp_segment_size` must
353    /// be set.
354    pub offload_tcp_segmentation: bool,
355    /// If true, the packet is IPv4.
356    pub is_ipv4: bool,
357    /// If true, the packet is IPv6. Mutually exclusive with `is_ipv4`.
358    pub is_ipv6: bool,
359    #[bits(2)]
360    _reserved: u8,
361}
362
363impl Default for TxMetadata {
364    fn default() -> Self {
365        Self {
366            id: TxId(0),
367            segment_count: 0,
368            len: 0,
369            flags: TxFlags::new(),
370            l2_len: 0,
371            l3_len: 0,
372            l4_len: 0,
373            max_tcp_segment_size: 0,
374        }
375    }
376}
377
378#[derive(Debug, Clone)]
379/// A transmit packet segment.
380pub struct TxSegment {
381    /// The segment type (head or tail).
382    pub ty: TxSegmentType,
383    /// The guest address of this segment.
384    pub gpa: u64,
385    /// The length of this segment.
386    pub len: u32,
387}
388
389/// Computes the number of packets in `segments`.
390pub fn packet_count(mut segments: &[TxSegment]) -> usize {
391    let mut packet_count = 0;
392    while let Some(head) = segments.first() {
393        let TxSegmentType::Head(metadata) = &head.ty else {
394            unreachable!()
395        };
396        segments = &segments[metadata.segment_count as usize..];
397        packet_count += 1;
398    }
399    packet_count
400}
401
402/// Gets the next packet from a list of segments, returning the packet metadata,
403/// the segments in the packet, and the remaining segments.
404pub fn next_packet(segments: &[TxSegment]) -> (&TxMetadata, &[TxSegment], &[TxSegment]) {
405    let metadata = if let TxSegmentType::Head(metadata) = &segments[0].ty {
406        metadata
407    } else {
408        unreachable!();
409    };
410    let (this, rest) = segments.split_at(metadata.segment_count.into());
411    (metadata, this, rest)
412}
413
414/// Linearizes the next packet in a list of segments, returning the buffer data
415/// and advancing the segment list.
416pub fn linearize(
417    pool: &dyn BufferAccess,
418    segments: &mut &[TxSegment],
419) -> Result<Vec<u8>, GuestMemoryError> {
420    let (head, this, rest) = next_packet(segments);
421    let mut v = vec![0; head.len as usize];
422    let mut offset = 0;
423    let mem = pool.guest_memory();
424    for segment in this {
425        let dest = &mut v[offset..offset + segment.len as usize];
426        mem.read_at(segment.gpa, dest)?;
427        offset += segment.len as usize;
428    }
429    assert_eq!(v.len(), offset);
430    *segments = rest;
431    Ok(v)
432}
433
/// An action requested by an endpoint, as returned from
/// [`Endpoint::wait_for_endpoint_action`].
#[derive(Debug, PartialEq)]
pub enum EndpointAction {
    /// The endpoint requires a restart.
    RestartRequired,
    /// A link status notification (presumably `true` means link up).
    LinkStatusNotify(bool),
}
439
440enum DisconnectableEndpointUpdate {
441    EndpointConnected(Box<dyn Endpoint>),
442    EndpointDisconnected(Rpc<(), Option<Box<dyn Endpoint>>>),
443}
444
445pub struct DisconnectableEndpointControl {
446    send_update: mesh::Sender<DisconnectableEndpointUpdate>,
447}
448
449impl DisconnectableEndpointControl {
450    pub fn connect(&mut self, endpoint: Box<dyn Endpoint>) -> anyhow::Result<()> {
451        self.send_update
452            .send(DisconnectableEndpointUpdate::EndpointConnected(endpoint));
453        Ok(())
454    }
455
456    pub async fn disconnect(&mut self) -> anyhow::Result<Option<Box<dyn Endpoint>>> {
457        self.send_update
458            .call(DisconnectableEndpointUpdate::EndpointDisconnected, ())
459            .map_err(anyhow::Error::from)
460            .await
461    }
462}
463
464pub struct DisconnectableEndpointCachedState {
465    is_ordered: bool,
466    tx_offload_support: TxOffloadSupport,
467    multiqueue_support: MultiQueueSupport,
468    tx_fast_completions: bool,
469    link_speed: u64,
470}
471
472pub struct DisconnectableEndpoint {
473    endpoint: Option<Box<dyn Endpoint>>,
474    null_endpoint: Box<dyn Endpoint>,
475    cached_state: Option<DisconnectableEndpointCachedState>,
476    receive_update: Arc<Mutex<mesh::Receiver<DisconnectableEndpointUpdate>>>,
477    notify_disconnect_complete: Option<(
478        Rpc<(), Option<Box<dyn Endpoint>>>,
479        Option<Box<dyn Endpoint>>,
480    )>,
481}
482
483impl InspectMut for DisconnectableEndpoint {
484    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
485        self.current_mut().inspect_mut(req)
486    }
487}
488
489impl DisconnectableEndpoint {
490    pub fn new() -> (Self, DisconnectableEndpointControl) {
491        let (endpoint_tx, endpoint_rx) = mesh::channel();
492        let control = DisconnectableEndpointControl {
493            send_update: endpoint_tx,
494        };
495        (
496            Self {
497                endpoint: None,
498                null_endpoint: Box::new(NullEndpoint::new()),
499                cached_state: None,
500                receive_update: Arc::new(Mutex::new(endpoint_rx)),
501                notify_disconnect_complete: None,
502            },
503            control,
504        )
505    }
506
507    fn current(&self) -> &dyn Endpoint {
508        self.endpoint
509            .as_ref()
510            .unwrap_or(&self.null_endpoint)
511            .as_ref()
512    }
513
514    fn current_mut(&mut self) -> &mut dyn Endpoint {
515        self.endpoint
516            .as_mut()
517            .unwrap_or(&mut self.null_endpoint)
518            .as_mut()
519    }
520}
521
522#[async_trait]
523impl Endpoint for DisconnectableEndpoint {
524    fn endpoint_type(&self) -> &'static str {
525        self.current().endpoint_type()
526    }
527
528    async fn get_queues(
529        &mut self,
530        config: Vec<QueueConfig<'_>>,
531        rss: Option<&RssConfig<'_>>,
532        queues: &mut Vec<Box<dyn Queue>>,
533    ) -> anyhow::Result<()> {
534        self.current_mut().get_queues(config, rss, queues).await
535    }
536
537    async fn stop(&mut self) {
538        self.current_mut().stop().await
539    }
540
541    fn is_ordered(&self) -> bool {
542        self.cached_state
543            .as_ref()
544            .expect("Endpoint needs connected at least once before use")
545            .is_ordered
546    }
547
548    fn tx_offload_support(&self) -> TxOffloadSupport {
549        self.cached_state
550            .as_ref()
551            .expect("Endpoint needs connected at least once before use")
552            .tx_offload_support
553    }
554
555    fn multiqueue_support(&self) -> MultiQueueSupport {
556        self.cached_state
557            .as_ref()
558            .expect("Endpoint needs connected at least once before use")
559            .multiqueue_support
560    }
561
562    fn tx_fast_completions(&self) -> bool {
563        self.cached_state
564            .as_ref()
565            .expect("Endpoint needs connected at least once before use")
566            .tx_fast_completions
567    }
568
569    async fn set_data_path_to_guest_vf(&self, use_vf: bool) -> anyhow::Result<()> {
570        self.current().set_data_path_to_guest_vf(use_vf).await
571    }
572
573    async fn get_data_path_to_guest_vf(&self) -> anyhow::Result<bool> {
574        self.current().get_data_path_to_guest_vf().await
575    }
576
577    async fn wait_for_endpoint_action(&mut self) -> EndpointAction {
578        // If the previous message disconnected the endpoint, notify the caller
579        // that the operation has completed, returning the old endpoint.
580        if let Some((rpc, old_endpoint)) = self.notify_disconnect_complete.take() {
581            rpc.handle(async |_| old_endpoint).await;
582        }
583
584        enum Message {
585            DisconnectableEndpointUpdate(DisconnectableEndpointUpdate),
586            UpdateFromEndpoint(EndpointAction),
587        }
588        let receiver = self.receive_update.clone();
589        let mut receive_update = receiver.lock().await;
590        let update = async {
591            match receive_update.next().await {
592                Some(m) => Message::DisconnectableEndpointUpdate(m),
593                None => {
594                    pending::<()>().await;
595                    unreachable!()
596                }
597            }
598        };
599        let ep_update = self
600            .current_mut()
601            .wait_for_endpoint_action()
602            .map(Message::UpdateFromEndpoint);
603        let m = (update, ep_update).race().await;
604        match m {
605            Message::DisconnectableEndpointUpdate(
606                DisconnectableEndpointUpdate::EndpointConnected(endpoint),
607            ) => {
608                let old_endpoint = self.endpoint.take();
609                assert!(old_endpoint.is_none());
610                self.endpoint = Some(endpoint);
611                self.cached_state = Some(DisconnectableEndpointCachedState {
612                    is_ordered: self.current().is_ordered(),
613                    tx_offload_support: self.current().tx_offload_support(),
614                    multiqueue_support: self.current().multiqueue_support(),
615                    tx_fast_completions: self.current().tx_fast_completions(),
616                    link_speed: self.current().link_speed(),
617                });
618                EndpointAction::RestartRequired
619            }
620            Message::DisconnectableEndpointUpdate(
621                DisconnectableEndpointUpdate::EndpointDisconnected(rpc),
622            ) => {
623                let old_endpoint = self.endpoint.take();
624                // Wait until the next call into this function to notify the
625                // caller that the operation has completed. This makes it more
626                // likely that the endpoint is no longer referenced (old queues
627                // have been disposed, etc.).
628                self.notify_disconnect_complete = Some((rpc, old_endpoint));
629                EndpointAction::RestartRequired
630            }
631            Message::UpdateFromEndpoint(update) => update,
632        }
633    }
634
635    fn link_speed(&self) -> u64 {
636        self.cached_state
637            .as_ref()
638            .expect("Endpoint needs connected at least once before use")
639            .link_speed
640    }
641}