net_consomme/lib.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#![expect(missing_docs)]
#![forbid(unsafe_code)]
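
//! A [`net_backend::Endpoint`] implementation backed by the `consomme`
//! user-mode network stack.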

pub mod resolver;

use async_trait::async_trait;
use consomme::ChecksumState;
use consomme::Consomme;
use consomme::ConsommeControl;
use consomme::ConsommeState;
use inspect::Inspect;
use inspect::InspectMut;
use inspect_counters::Counter;
use net_backend::BufferAccess;
use net_backend::L4Protocol;
use net_backend::QueueConfig;
use net_backend::RssConfig;
use net_backend::RxChecksumState;
use net_backend::RxId;
use net_backend::RxMetadata;
use net_backend::TxError;
use net_backend::TxId;
use net_backend::TxOffloadSupport;
use net_backend::TxSegment;
use net_backend::TxSegmentType;
use pal_async::driver::Driver;
use parking_lot::Mutex;
use std::collections::VecDeque;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

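/// A network endpoint backed by a `Consomme` user-mode network instance.
///
/// The instance lives in a shared slot so that the single queue created by
/// `get_queues` can take ownership of it and return it when the queue is
/// dropped.
///
/// A minimal construction sketch (in a context that can propagate
/// `consomme::Error`):
///
/// ```ignore
/// let endpoint = ConsommeEndpoint::new()?;
/// ```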
pub struct ConsommeEndpoint {
    consomme: Arc<Mutex<Option<Consomme>>>,
}

impl ConsommeEndpoint {
    pub fn new() -> Result<Self, consomme::Error> {
        Ok(Self {
            consomme: Arc::new(Mutex::new(Some(Consomme::new()?))),
        })
    }

    pub fn new_with_state(state: ConsommeState) -> Self {
        Self {
            consomme: Arc::new(Mutex::new(Some(Consomme::new_with_state(state)))),
        }
    }

    pub fn new_dynamic(state: ConsommeState) -> (Self, ConsommeControl) {
        let (consomme, control) = Consomme::new_dynamic(state);
        (
            Self {
                consomme: Arc::new(Mutex::new(Some(consomme))),
            },
            control,
        )
    }
}

impl InspectMut for ConsommeEndpoint {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        if let Some(consomme) = &mut *self.consomme.lock() {
            consomme.inspect_mut(req);
        }
    }
}

#[async_trait]
impl net_backend::Endpoint for ConsommeEndpoint {
    fn endpoint_type(&self) -> &'static str {
        "consomme"
    }

    async fn get_queues(
        &mut self,
        config: Vec<QueueConfig<'_>>,
        _rss: Option<&RssConfig<'_>>,
        queues: &mut Vec<Box<dyn net_backend::Queue>>,
    ) -> anyhow::Result<()> {
        assert_eq!(config.len(), 1);
        let config = config.into_iter().next().unwrap();
        let mut queue = Box::new(ConsommeQueue {
            slot: self.consomme.clone(),
            consomme: self.consomme.lock().take(),
            state: QueueState {
                pool: config.pool,
                rx_avail: config.initial_rx.iter().copied().collect(),
                rx_ready: VecDeque::new(),
                tx_avail: VecDeque::new(),
                tx_ready: VecDeque::new(),
            },
            stats: Default::default(),
            driver: config.driver,
        });
        queue.with_consomme(|c| c.refresh_driver());
        queues.push(queue);
        Ok(())
    }

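    // The queue, if any, must have been dropped by now, returning the
    // `Consomme` instance to the shared slot.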
    async fn stop(&mut self) {
        assert!(self.consomme.lock().is_some());
    }

    fn is_ordered(&self) -> bool {
        true
    }

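    // All tx offloads are reported as supported; the per-packet offload
    // requests are forwarded to consomme via `ChecksumState` in `poll_ready`.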
    fn tx_offload_support(&self) -> TxOffloadSupport {
        TxOffloadSupport {
            ipv4_header: true,
            tcp: true,
            udp: true,
            tso: true,
        }
    }
}

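/// The single queue for a [`ConsommeEndpoint`].
///
/// The queue owns the `Consomme` instance while it is alive and returns it to
/// the endpoint's shared slot when dropped.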
pub struct ConsommeQueue {
    slot: Arc<Mutex<Option<Consomme>>>,
    consomme: Option<Consomme>,
    state: QueueState,
    stats: Stats,
    driver: Box<dyn Driver>,
}

impl InspectMut for ConsommeQueue {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        req.respond()
            .merge(self.consomme.as_mut().unwrap())
            .field("rx_avail", self.state.rx_avail.len())
            .field("rx_ready", self.state.rx_ready.len())
            .field("tx_avail", self.state.tx_avail.len())
            .field("tx_ready", self.state.tx_ready.len())
            .field("stats", &self.stats);
    }
}

impl Drop for ConsommeQueue {
    fn drop(&mut self) {
        *self.slot.lock() = self.consomme.take();
    }
}

impl ConsommeQueue {
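    // Runs `f` with a consomme `Access` that pairs the owned `Consomme`
    // instance with this queue's buffers, statistics, and driver.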
    fn with_consomme<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(&mut consomme::Access<'_, Client<'_>>) -> R,
    {
        f(&mut self.consomme.as_mut().unwrap().access(&mut Client {
            state: &mut self.state,
            stats: &mut self.stats,
            driver: &self.driver,
        }))
    }
}

impl net_backend::Queue for ConsommeQueue {
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
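        // Drain the available tx segments, reassembling each packet and
        // handing it to consomme.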
        while let Some(head) = self.state.tx_avail.front() {
            let TxSegmentType::Head(meta) = &head.ty else {
                unreachable!()
            };
            let tx_id = meta.id;
            let checksum = ChecksumState {
                ipv4: meta.offload_ip_header_checksum,
                tcp: meta.offload_tcp_checksum,
                udp: meta.offload_udp_checksum,
                tso: meta
                    .offload_tcp_segmentation
                    .then_some(meta.max_tcp_segment_size),
            };

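            // Copy the packet out of guest memory into a contiguous buffer.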
            let mut buf = vec![0; meta.len];
            let gm = self.state.pool.guest_memory();
            let mut offset = 0;
            for segment in self.state.tx_avail.drain(..meta.segment_count) {
                let dest = &mut buf[offset..offset + segment.len as usize];
                if let Err(err) = gm.read_at(segment.gpa, dest) {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "memory read failure"
                    );
                }
                offset += segment.len as usize;
            }

            if let Err(err) = self.with_consomme(|c| c.send(&buf, &checksum)) {
                tracing::debug!(error = &err as &dyn std::error::Error, "tx packet ignored");
                match err {
                    consomme::DropReason::SendBufferFull => self.stats.tx_dropped.increment(),
                    consomme::DropReason::UnsupportedEthertype(_)
                    | consomme::DropReason::UnsupportedIpProtocol(_)
                    | consomme::DropReason::UnsupportedDhcp(_)
                    | consomme::DropReason::UnsupportedArp => self.stats.tx_unknown.increment(),
                    consomme::DropReason::Packet(_)
                    | consomme::DropReason::Ipv4Checksum
                    | consomme::DropReason::Io(_)
                    | consomme::DropReason::BadTcpState(_) => self.stats.tx_errors.increment(),
                    consomme::DropReason::PortNotBound => unreachable!(),
                }
            }

            self.state.tx_ready.push_back(tx_id);
        }

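        // Poll consomme for host-side activity; received data is delivered
        // through `Client::recv`, which fills and queues rx buffers.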
        self.with_consomme(|c| c.poll(cx));

        if !self.state.tx_ready.is_empty() || !self.state.rx_ready.is_empty() {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }

    fn rx_avail(&mut self, done: &[RxId]) {
        self.state.rx_avail.extend(done);
    }

    fn rx_poll(&mut self, packets: &mut [RxId]) -> anyhow::Result<usize> {
        let n = packets.len().min(self.state.rx_ready.len());
        for (x, y) in packets.iter_mut().zip(self.state.rx_ready.drain(..n)) {
            *x = y;
        }
        Ok(n)
    }

    fn tx_avail(&mut self, segments: &[TxSegment]) -> anyhow::Result<(bool, usize)> {
        self.state.tx_avail.extend(segments.iter().cloned());
        Ok((false, segments.len()))
    }

    fn tx_poll(&mut self, done: &mut [TxId]) -> Result<usize, TxError> {
        let n = done.len().min(self.state.tx_ready.len());
        for (x, y) in done.iter_mut().zip(self.state.tx_ready.drain(..n)) {
            *x = y;
        }
        Ok(n)
    }

    fn buffer_access(&mut self) -> Option<&mut dyn BufferAccess> {
        Some(self.state.pool.as_mut())
    }
}

struct QueueState {
    pool: Box<dyn BufferAccess>,
    rx_avail: VecDeque<RxId>,
    rx_ready: VecDeque<RxId>,
    tx_avail: VecDeque<TxSegment>,
    tx_ready: VecDeque<TxId>,
}

#[derive(Inspect, Default)]
struct Stats {
    rx_dropped: Counter,
    tx_dropped: Counter,
    tx_errors: Counter,
    tx_unknown: Counter,
}

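// Adapter passed to `Consomme::access`, giving consomme access to the queue's
// rx buffer pool, statistics, and I/O driver.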
struct Client<'a> {
    state: &'a mut QueueState,
    stats: &'a mut Stats,
    driver: &'a dyn Driver,
}

impl consomme::Client for Client<'_> {
    fn driver(&self) -> &dyn Driver {
        self.driver
    }

    fn recv(&mut self, data: &[u8], checksum: &ChecksumState) {
        let Some(rx_id) = self.state.rx_avail.pop_front() else {
            // This should be rare, only affecting unbuffered protocols. TCP and
            // UDP are buffered and they won't indicate packets unless rx_mtu()
            // returns a non-zero value.
            self.stats.rx_dropped.increment();
            return;
        };
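        // Write the packet if it fits in the rx buffer, recording which
        // checksums consomme has already handled.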
        let max = self.state.pool.capacity(rx_id) as usize;
        if data.len() <= max {
            self.state.pool.write_packet(
                rx_id,
                &RxMetadata {
                    offset: 0,
                    len: data.len(),
                    ip_checksum: if checksum.ipv4 {
                        RxChecksumState::Good
                    } else {
                        RxChecksumState::Unknown
                    },
                    l4_checksum: if checksum.tcp || checksum.udp {
                        RxChecksumState::Good
                    } else {
                        RxChecksumState::Unknown
                    },
                    l4_protocol: if checksum.tcp {
                        L4Protocol::Tcp
                    } else if checksum.udp {
                        L4Protocol::Udp
                    } else {
                        L4Protocol::Unknown
                    },
                },
                data,
            );
            self.state.rx_ready.push_back(rx_id);
        } else {
            tracing::warn!(len = data.len(), max, "dropping rx packet: too large");
            self.state.rx_avail.push_front(rx_id);
        }
    }

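    // Reports the capacity of the next available rx buffer. Returning zero
    // tells consomme not to indicate buffered (TCP/UDP) data for now.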
    fn rx_mtu(&mut self) -> usize {
        if let Some(&rx_id) = self.state.rx_avail.front() {
            self.state.pool.capacity(rx_id) as usize
        } else {
            0
        }
    }
}