net_backend/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! This module defines a trait and implementations thereof for network
5//! backends.
6
7#![expect(missing_docs)]
8#![forbid(unsafe_code)]
9
10pub mod loopback;
11pub mod null;
12pub mod resolve;
13pub mod tests;
14
15use async_trait::async_trait;
16use futures::FutureExt;
17use futures::StreamExt;
18use futures::TryFutureExt;
19use futures::lock::Mutex;
20use futures_concurrency::future::Race;
21use guestmem::GuestMemory;
22use guestmem::GuestMemoryError;
23use inspect::InspectMut;
24use mesh::rpc::Rpc;
25use mesh::rpc::RpcSend;
26use null::NullEndpoint;
27use pal_async::driver::Driver;
28use std::future::pending;
29use std::sync::Arc;
30use std::task::Context;
31use std::task::Poll;
32use thiserror::Error;
33
34/// Per-queue configuration.
35pub struct QueueConfig<'a> {
36    pub pool: Box<dyn BufferAccess>,
37    pub initial_rx: &'a [RxId],
38    pub driver: Box<dyn Driver>,
39}
40
41/// A network endpoint.
42#[async_trait]
43pub trait Endpoint: Send + Sync + InspectMut {
44    /// Returns an informational endpoint type.
45    fn endpoint_type(&self) -> &'static str;
46
47    /// Initializes the queues associated with the endpoint.
48    ///
49    /// `initial_rx` contains the initial set of receives buffers that are
50    /// available.
51    async fn get_queues(
52        &mut self,
53        config: Vec<QueueConfig<'_>>,
54        rss: Option<&RssConfig<'_>>,
55        queues: &mut Vec<Box<dyn Queue>>,
56    ) -> anyhow::Result<()>;
57
58    /// Stops the endpoint.
59    ///
60    /// All queues returned via `get_queues` must have been dropped.
61    async fn stop(&mut self);
62
63    /// Specifies whether packets are always completed in order.
64    fn is_ordered(&self) -> bool {
65        false
66    }
67
68    /// Specifies the supported set of transmit offloads.
69    fn tx_offload_support(&self) -> TxOffloadSupport {
70        TxOffloadSupport::default()
71    }
72
73    /// Specifies parameters related to supporting multiple queues.
74    fn multiqueue_support(&self) -> MultiQueueSupport {
75        MultiQueueSupport {
76            max_queues: 1,
77            indirection_table_size: 0,
78        }
79    }
80
81    /// If true, transmits are guaranteed to complete quickly. This is used to
82    /// allow eliding tx notifications from the guest when there are already
83    /// some tx packets in flight.
84    fn tx_fast_completions(&self) -> bool {
85        false
86    }
87
88    /// Sets the current data path for packet flow (e.g. via vmbus synthnic or through virtual function).
89    /// This is only supported for endpoints that pair with an accelerated device.
90    async fn set_data_path_to_guest_vf(&self, _use_vf: bool) -> anyhow::Result<()> {
91        Err(anyhow::Error::msg("Unsupported in current endpoint"))
92    }
93
94    async fn get_data_path_to_guest_vf(&self) -> anyhow::Result<bool> {
95        Err(anyhow::Error::msg("Unsupported in current endpoint"))
96    }
97
98    /// On completion, the return value indicates the specific endpoint action to take.
99    async fn wait_for_endpoint_action(&mut self) -> EndpointAction {
100        pending().await
101    }
102
103    /// Link speed in bps.
104    fn link_speed(&self) -> u64 {
105        // Reporting a reasonable default value (10Gbps) here that the individual endpoints
106        // can overwrite.
107        10 * 1000 * 1000 * 1000
108    }
109}
110
/// Describes how far an endpoint supports multiple queues.
#[derive(Debug, Copy, Clone)]
pub struct MultiQueueSupport {
    /// Maximum number of queues the endpoint can provide.
    pub max_queues: u16,
    /// Number of entries in the RSS indirection table.
    pub indirection_table_size: u16,
}
119
/// Transmit offloads an endpoint is able to perform.
///
/// The default value advertises no offloads.
#[derive(Debug, Copy, Clone, Default)]
pub struct TxOffloadSupport {
    /// Can compute the IPv4 header checksum.
    pub ipv4_header: bool,
    /// Can compute TCP checksums.
    pub tcp: bool,
    /// Can compute UDP checksums.
    pub udp: bool,
    /// Can perform TCP segmentation (TSO).
    pub tso: bool,
}
132
/// Receive-side scaling (RSS) configuration passed to
/// [`Endpoint::get_queues`].
#[derive(Debug, Clone)]
pub struct RssConfig<'a> {
    /// The RSS hash key.
    pub key: &'a [u8],
    /// The RSS indirection table. NOTE(review): entries presumably select a
    /// queue index; confirm.
    pub indirection_table: &'a [u16],
    /// RSS flags.
    ///
    /// TODO: the meaning of these bits is not yet defined.
    pub flags: u32,
}
139
140#[derive(Error, Debug)]
141pub enum TxError {
142    #[error("error requiring queue restart. {0}")]
143    TryRestart(#[source] anyhow::Error),
144    #[error("unrecoverable error. {0}")]
145    Fatal(#[source] anyhow::Error),
146}
147
148/// A trait for sending and receiving network packets.
149#[async_trait]
150pub trait Queue: Send + InspectMut {
151    /// Updates the queue's target VP.
152    async fn update_target_vp(&mut self, target_vp: u32) {
153        let _ = target_vp;
154    }
155
156    /// Polls the queue for readiness.
157    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>;
158
159    /// Makes receive buffers available for use by the device.
160    fn rx_avail(&mut self, done: &[RxId]);
161
162    /// Polls the device for receives.
163    fn rx_poll(&mut self, packets: &mut [RxId]) -> anyhow::Result<usize>;
164
165    /// Posts transmits to the device.
166    ///
167    /// Returns `Ok(false)` if the segments will complete asynchronously.
168    fn tx_avail(&mut self, segments: &[TxSegment]) -> anyhow::Result<(bool, usize)>;
169
170    /// Polls the device for transmit completions.
171    fn tx_poll(&mut self, done: &mut [TxId]) -> Result<usize, TxError>;
172
173    /// Get the buffer access.
174    fn buffer_access(&mut self) -> Option<&mut dyn BufferAccess>;
175}
176
177/// A trait for providing access to guest memory buffers.
178pub trait BufferAccess: 'static + Send {
179    /// The associated guest memory accessor.
180    fn guest_memory(&self) -> &GuestMemory;
181
182    /// Writes data to the specified buffer.
183    fn write_data(&mut self, id: RxId, data: &[u8]);
184
185    /// The guest addresses of the specified buffer.
186    fn guest_addresses(&mut self, id: RxId) -> &[RxBufferSegment];
187
188    /// The capacity of the specified buffer in bytes.
189    fn capacity(&self, id: RxId) -> u32;
190
191    /// Sets the packet metadata for the receive.
192    fn write_header(&mut self, id: RxId, metadata: &RxMetadata);
193
194    /// Writes the packet header and data in a single call.
195    fn write_packet(&mut self, id: RxId, metadata: &RxMetadata, data: &[u8]) {
196        self.write_data(id, data);
197        self.write_header(id, metadata);
198    }
199}
200
/// Identifies a single receive buffer.
#[repr(transparent)]
#[derive(Debug, Copy, Clone)]
pub struct RxId(pub u32);
205
/// One contiguous piece of guest memory belonging to a receive buffer.
#[derive(Debug, Copy, Clone)]
pub struct RxBufferSegment {
    /// Guest physical address where the piece starts.
    pub gpa: u64,
    /// Size of the piece in bytes.
    pub len: u32,
}
214
215/// Receive packet metadata.
216#[derive(Debug, Copy, Clone)]
217pub struct RxMetadata {
218    /// The offset of the packet data from the beginning of the receive buffer.
219    pub offset: usize,
220    /// The length of the packet in bytes.
221    pub len: usize,
222    /// The IP checksum validation state.
223    pub ip_checksum: RxChecksumState,
224    /// The L4 checksum validation state.
225    pub l4_checksum: RxChecksumState,
226    /// The L4 protocol.
227    pub l4_protocol: L4Protocol,
228}
229
230impl Default for RxMetadata {
231    fn default() -> Self {
232        Self {
233            offset: 0,
234            len: 0,
235            ip_checksum: RxChecksumState::Unknown,
236            l4_checksum: RxChecksumState::Unknown,
237            l4_protocol: L4Protocol::Unknown,
238        }
239    }
240}
241
/// The "L3" protocol, i.e. the IP layer.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum L3Protocol {
    /// The IP version is not known.
    Unknown,
    /// IP version 4.
    Ipv4,
    /// IP version 6.
    Ipv6,
}
249
/// The "L4" protocol, i.e. the TCP/UDP layer.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum L4Protocol {
    /// The transport protocol is not known.
    Unknown,
    /// TCP.
    Tcp,
    /// UDP.
    Udp,
}
257
/// How a received packet's checksum was (or was not) validated.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum RxChecksumState {
    /// No checksum evaluation took place.
    Unknown,
    /// The checksum was checked and is correct.
    Good,
    /// The checksum was checked and is incorrect.
    Bad,
    /// The data has been validated, but the checksum stored in the header is
    /// wrong.
    ///
    /// This happens with LRO/RSC offload, where several packet payloads are
    /// merged without updating the checksum in the first packet's header.
    ValidatedButWrong,
}

impl RxChecksumState {
    /// Returns `true` when the checksum has been validated: either
    /// [`Good`](Self::Good) or [`ValidatedButWrong`](Self::ValidatedButWrong).
    pub fn is_valid(self) -> bool {
        matches!(self, Self::Good | Self::ValidatedButWrong)
    }
}
281
/// Identifies a transmit; several segments may share the same ID at once.
#[repr(transparent)]
#[derive(Debug, Copy, Clone)]
pub struct TxId(pub u32);
286
287#[derive(Debug, Clone)]
288/// The segment type.
289pub enum TxSegmentType {
290    /// The start of a packet.
291    Head(TxMetadata),
292    /// A packet continuation.
293    Tail,
294}
295
296#[derive(Debug, Clone)]
297/// Transmit packet metadata.
298pub struct TxMetadata {
299    /// The transmit ID.
300    pub id: TxId,
301    /// The number of segments to follow.
302    pub segment_count: usize,
303    /// The total length of the packet in bytes.
304    pub len: usize,
305    /// Offload IPv4 header checksum calculation.
306    ///
307    /// l3_protocol, l2_len, and l3_len must be set.
308    pub offload_ip_header_checksum: bool,
309    /// Offload the TCP checksum calculation.
310    ///
311    /// l3_protocol, l2_len, and l3_len must be set.
312    pub offload_tcp_checksum: bool,
313    /// Offload the UDP checksum calculation.
314    ///
315    /// l3_protocol, l2_len, and l3_len must be set.
316    pub offload_udp_checksum: bool,
317    /// Offload the TCP segmentation, allowing packets to be larger than the
318    /// MTU.
319    ///
320    /// l3_protocol, l2_len, l3_len, l4_len, and tcp_segment_size must be set.
321    pub offload_tcp_segmentation: bool,
322    /// The L3 protocol, needed when performing any of the offloads.
323    pub l3_protocol: L3Protocol,
324    /// The length of the Ethernet frame header.
325    pub l2_len: u8,
326    /// The length of the IP header.
327    pub l3_len: u16,
328    /// The length of the TCP header.
329    pub l4_len: u8,
330    /// The maximum TCP segment size, used for segmentation.
331    pub max_tcp_segment_size: u16,
332}
333
334impl Default for TxMetadata {
335    fn default() -> Self {
336        Self {
337            id: TxId(0),
338            segment_count: 0,
339            len: 0,
340            offload_ip_header_checksum: false,
341            offload_tcp_checksum: false,
342            offload_udp_checksum: false,
343            offload_tcp_segmentation: false,
344            l3_protocol: L3Protocol::Unknown,
345            l2_len: 0,
346            l3_len: 0,
347            l4_len: 0,
348            max_tcp_segment_size: 0,
349        }
350    }
351}
352
353#[derive(Debug, Clone)]
354/// A transmit packet segment.
355pub struct TxSegment {
356    /// The segment type (head or tail).
357    pub ty: TxSegmentType,
358    /// The guest address of this segment.
359    pub gpa: u64,
360    /// The length of this segment.
361    pub len: u32,
362}
363
364/// Computes the number of packets in `segments`.
365pub fn packet_count(mut segments: &[TxSegment]) -> usize {
366    let mut packet_count = 0;
367    while let Some(head) = segments.first() {
368        let TxSegmentType::Head(metadata) = &head.ty else {
369            unreachable!()
370        };
371        segments = &segments[metadata.segment_count..];
372        packet_count += 1;
373    }
374    packet_count
375}
376
377/// Gets the next packet from a list of segments, returning the packet metadata,
378/// the segments in the packet, and the remaining segments.
379pub fn next_packet(segments: &[TxSegment]) -> (&TxMetadata, &[TxSegment], &[TxSegment]) {
380    let metadata = if let TxSegmentType::Head(metadata) = &segments[0].ty {
381        metadata
382    } else {
383        unreachable!();
384    };
385    let (this, rest) = segments.split_at(metadata.segment_count);
386    (metadata, this, rest)
387}
388
389/// Linearizes the next packet in a list of segments, returning the buffer data
390/// and advancing the segment list.
391pub fn linearize(
392    pool: &dyn BufferAccess,
393    segments: &mut &[TxSegment],
394) -> Result<Vec<u8>, GuestMemoryError> {
395    let (head, this, rest) = next_packet(segments);
396    let mut v = vec![0; head.len];
397    let mut offset = 0;
398    let mem = pool.guest_memory();
399    for segment in this {
400        let dest = &mut v[offset..offset + segment.len as usize];
401        mem.read_at(segment.gpa, dest)?;
402        offset += segment.len as usize;
403    }
404    assert_eq!(v.len(), offset);
405    *segments = rest;
406    Ok(v)
407}
408
/// An action an endpoint asks its user to take, as reported by
/// [`Endpoint::wait_for_endpoint_action`].
#[derive(PartialEq, Debug)]
pub enum EndpointAction {
    /// A restart is required (reported, for example, after the backing
    /// endpoint of a [`DisconnectableEndpoint`] is connected or disconnected).
    RestartRequired,
    /// Link status changed. NOTE(review): presumably `true` means link up;
    /// confirm.
    LinkStatusNotify(bool),
}
414
415enum DisconnectableEndpointUpdate {
416    EndpointConnected(Box<dyn Endpoint>),
417    EndpointDisconnected(Rpc<(), Option<Box<dyn Endpoint>>>),
418}
419
420pub struct DisconnectableEndpointControl {
421    send_update: mesh::Sender<DisconnectableEndpointUpdate>,
422}
423
424impl DisconnectableEndpointControl {
425    pub fn connect(&mut self, endpoint: Box<dyn Endpoint>) -> anyhow::Result<()> {
426        self.send_update
427            .send(DisconnectableEndpointUpdate::EndpointConnected(endpoint));
428        Ok(())
429    }
430
431    pub async fn disconnect(&mut self) -> anyhow::Result<Option<Box<dyn Endpoint>>> {
432        self.send_update
433            .call(DisconnectableEndpointUpdate::EndpointDisconnected, ())
434            .map_err(anyhow::Error::from)
435            .await
436    }
437}
438
439pub struct DisconnectableEndpointCachedState {
440    is_ordered: bool,
441    tx_offload_support: TxOffloadSupport,
442    multiqueue_support: MultiQueueSupport,
443    tx_fast_completions: bool,
444    link_speed: u64,
445}
446
447pub struct DisconnectableEndpoint {
448    endpoint: Option<Box<dyn Endpoint>>,
449    null_endpoint: Box<dyn Endpoint>,
450    cached_state: Option<DisconnectableEndpointCachedState>,
451    receive_update: Arc<Mutex<mesh::Receiver<DisconnectableEndpointUpdate>>>,
452}
453
454impl InspectMut for DisconnectableEndpoint {
455    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
456        self.current_mut().inspect_mut(req)
457    }
458}
459
460impl DisconnectableEndpoint {
461    pub fn new() -> (Self, DisconnectableEndpointControl) {
462        let (endpoint_tx, endpoint_rx) = mesh::channel();
463        let control = DisconnectableEndpointControl {
464            send_update: endpoint_tx,
465        };
466        (
467            Self {
468                endpoint: None,
469                null_endpoint: Box::new(NullEndpoint::new()),
470                cached_state: None,
471                receive_update: Arc::new(Mutex::new(endpoint_rx)),
472            },
473            control,
474        )
475    }
476
477    fn current(&self) -> &dyn Endpoint {
478        self.endpoint
479            .as_ref()
480            .unwrap_or(&self.null_endpoint)
481            .as_ref()
482    }
483
484    fn current_mut(&mut self) -> &mut dyn Endpoint {
485        self.endpoint
486            .as_mut()
487            .unwrap_or(&mut self.null_endpoint)
488            .as_mut()
489    }
490}
491
492#[async_trait]
493impl Endpoint for DisconnectableEndpoint {
494    fn endpoint_type(&self) -> &'static str {
495        self.current().endpoint_type()
496    }
497
498    async fn get_queues(
499        &mut self,
500        config: Vec<QueueConfig<'_>>,
501        rss: Option<&RssConfig<'_>>,
502        queues: &mut Vec<Box<dyn Queue>>,
503    ) -> anyhow::Result<()> {
504        self.current_mut().get_queues(config, rss, queues).await
505    }
506
507    async fn stop(&mut self) {
508        self.current_mut().stop().await
509    }
510
511    fn is_ordered(&self) -> bool {
512        self.cached_state
513            .as_ref()
514            .expect("Endpoint needs connected at least once before use")
515            .is_ordered
516    }
517
518    fn tx_offload_support(&self) -> TxOffloadSupport {
519        self.cached_state
520            .as_ref()
521            .expect("Endpoint needs connected at least once before use")
522            .tx_offload_support
523    }
524
525    fn multiqueue_support(&self) -> MultiQueueSupport {
526        self.cached_state
527            .as_ref()
528            .expect("Endpoint needs connected at least once before use")
529            .multiqueue_support
530    }
531
532    fn tx_fast_completions(&self) -> bool {
533        self.cached_state
534            .as_ref()
535            .expect("Endpoint needs connected at least once before use")
536            .tx_fast_completions
537    }
538
539    async fn set_data_path_to_guest_vf(&self, use_vf: bool) -> anyhow::Result<()> {
540        self.current().set_data_path_to_guest_vf(use_vf).await
541    }
542
543    async fn get_data_path_to_guest_vf(&self) -> anyhow::Result<bool> {
544        self.current().get_data_path_to_guest_vf().await
545    }
546
547    async fn wait_for_endpoint_action(&mut self) -> EndpointAction {
548        enum Message {
549            DisconnectableEndpointUpdate(DisconnectableEndpointUpdate),
550            UpdateFromEndpoint(EndpointAction),
551        }
552        let receiver = self.receive_update.clone();
553        let mut receive_update = receiver.lock().await;
554        let update = async {
555            match receive_update.next().await {
556                Some(m) => Message::DisconnectableEndpointUpdate(m),
557                None => {
558                    pending::<()>().await;
559                    unreachable!()
560                }
561            }
562        };
563        let ep_update = self
564            .current_mut()
565            .wait_for_endpoint_action()
566            .map(Message::UpdateFromEndpoint);
567        let m = (update, ep_update).race().await;
568        match m {
569            Message::DisconnectableEndpointUpdate(
570                DisconnectableEndpointUpdate::EndpointConnected(endpoint),
571            ) => {
572                let old_endpoint = self.endpoint.take();
573                assert!(old_endpoint.is_none());
574                self.endpoint = Some(endpoint);
575                self.cached_state = Some(DisconnectableEndpointCachedState {
576                    is_ordered: self.current().is_ordered(),
577                    tx_offload_support: self.current().tx_offload_support(),
578                    multiqueue_support: self.current().multiqueue_support(),
579                    tx_fast_completions: self.current().tx_fast_completions(),
580                    link_speed: self.current().link_speed(),
581                });
582                EndpointAction::RestartRequired
583            }
584            Message::DisconnectableEndpointUpdate(
585                DisconnectableEndpointUpdate::EndpointDisconnected(rpc),
586            ) => {
587                let old_endpoint = self.endpoint.take();
588                self.endpoint = None;
589                rpc.handle(async |_| old_endpoint).await;
590                EndpointAction::RestartRequired
591            }
592            Message::UpdateFromEndpoint(update) => update,
593        }
594    }
595
596    fn link_speed(&self) -> u64 {
597        self.cached_state
598            .as_ref()
599            .expect("Endpoint needs connected at least once before use")
600            .link_speed
601    }
602}