net_mana/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
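//! The MANA network backend: implements `net_backend::Endpoint` and
//! `net_backend::Queue` on top of the user-mode MANA driver in `mana_driver`.
//!
//! A rough usage sketch (not compiled as a doctest; assumes `driver`, `vport`,
//! `queue_configs`, and `queues` already exist, as in the tests below):
//!
//! ```ignore
//! let mut endpoint = ManaEndpoint::new(driver.clone(), vport, GuestDmaMode::DirectDma).await;
//! endpoint.get_queues(queue_configs, None, &mut queues).await?;
//! ```
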
4#![forbid(unsafe_code)]
5#![expect(missing_docs)]
6
7use anyhow::Context as _;
8use async_trait::async_trait;
9use futures::FutureExt;
10use futures::StreamExt;
11use gdma_defs::Cqe;
12use gdma_defs::GDMA_EQE_COMPLETION;
13use gdma_defs::Sge;
14use gdma_defs::bnic::CQE_RX_OKAY;
15use gdma_defs::bnic::CQE_TX_GDMA_ERR;
16use gdma_defs::bnic::CQE_TX_OKAY;
17use gdma_defs::bnic::MANA_LONG_PKT_FMT;
18use gdma_defs::bnic::MANA_SHORT_PKT_FMT;
19use gdma_defs::bnic::ManaQueryStatisticsResponse;
20use gdma_defs::bnic::ManaRxcompOob;
21use gdma_defs::bnic::ManaTxCompOob;
22use gdma_defs::bnic::ManaTxOob;
23use guestmem::GuestMemory;
24use inspect::Inspect;
25use inspect::InspectMut;
26use inspect::SensitivityLevel;
27use mana_driver::mana::BnicEq;
28use mana_driver::mana::BnicWq;
29use mana_driver::mana::ResourceArena;
30use mana_driver::mana::RxConfig;
31use mana_driver::mana::TxConfig;
32use mana_driver::mana::Vport;
33use mana_driver::queues::Cq;
34use mana_driver::queues::Eq;
35use mana_driver::queues::Wq;
36use net_backend::BufferAccess;
37use net_backend::Endpoint;
38use net_backend::EndpointAction;
39use net_backend::L3Protocol;
40use net_backend::L4Protocol;
41use net_backend::MultiQueueSupport;
42use net_backend::Queue;
43use net_backend::QueueConfig;
44use net_backend::RssConfig;
45use net_backend::RxChecksumState;
46use net_backend::RxId;
47use net_backend::RxMetadata;
48use net_backend::TxError;
49use net_backend::TxId;
50use net_backend::TxOffloadSupport;
51use net_backend::TxSegment;
52use net_backend::TxSegmentType;
53use pal_async::task::Spawn;
54use safeatomic::AtomicSliceOps;
55use std::collections::VecDeque;
56use std::sync::Arc;
57use std::sync::Weak;
58use std::sync::atomic::AtomicU8;
59use std::sync::atomic::AtomicUsize;
60use std::sync::atomic::Ordering;
61use std::task::Context;
62use std::task::Poll;
63use thiserror::Error;
64use user_driver::DeviceBacking;
65use user_driver::DmaClient;
66use user_driver::interrupt::DeviceInterrupt;
67use user_driver::memory::MemoryBlock;
68use user_driver::memory::PAGE_SIZE32;
69use user_driver::memory::PAGE_SIZE64;
70use vmcore::slim_event::SlimEvent;
71use zerocopy::FromBytes;
72use zerocopy::FromZeros;
73
74/// Per-queue limit, in pages, used to bounce buffer non-contiguous network
75/// packet headers.
76const SPLIT_HEADER_BOUNCE_PAGE_LIMIT: u32 = 4;
77
78/// Per-queue limit for bounce buffering, in pages.
79/// Only used when bounce buffering is enabled for the device.
80const RX_BOUNCE_BUFFER_PAGE_LIMIT: u32 = 64;
81const TX_BOUNCE_BUFFER_PAGE_LIMIT: u32 = 64;
82
83pub struct ManaEndpoint<T: DeviceBacking> {
84    spawner: Box<dyn Spawn>,
85    vport: Arc<Vport<T>>,
86    queues: Vec<QueueResources>,
87    arena: ResourceArena,
88    receive_update: mesh::Receiver<bool>,
89    queue_tracker: Arc<(AtomicUsize, SlimEvent)>,
90    bounce_buffer: bool,
91}
92
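/// Per-queue device resources that must be kept alive for as long as the
/// corresponding queue is in use.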
93struct QueueResources {
94    _eq: BnicEq,
95    rxq: BnicWq,
96    _txq: BnicWq,
97}
98
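/// Selects how guest memory is accessed for DMA: either directly, or by
/// copying packets through a bounce buffer owned by the endpoint.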
99#[derive(Copy, Clone, Debug, PartialEq, Eq)]
100pub enum GuestDmaMode {
101    DirectDma,
102    BounceBuffer,
103}
104
105impl<T: DeviceBacking> ManaEndpoint<T> {
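    /// Creates a new endpoint over `vport`, registering for link status
    /// notifications.
    ///
    /// A minimal sketch (assuming an existing `driver` spawner and `vport`;
    /// not compiled as a doctest):
    ///
    /// ```ignore
    /// let endpoint = ManaEndpoint::new(driver.clone(), vport, GuestDmaMode::DirectDma).await;
    /// ```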
106    pub async fn new(
107        spawner: impl 'static + Spawn,
108        vport: Vport<T>,
109        dma_mode: GuestDmaMode,
110    ) -> Self {
111        let (endpoint_tx, endpoint_rx) = mesh::channel();
112        vport.register_link_status_notifier(endpoint_tx).await;
113        Self {
114            spawner: Box::new(spawner),
115            vport: Arc::new(vport),
116            queues: Vec::new(),
117            arena: ResourceArena::new(),
118            receive_update: endpoint_rx,
119            queue_tracker: Arc::new((AtomicUsize::new(0), SlimEvent::new())),
120            bounce_buffer: match dma_mode {
121                GuestDmaMode::DirectDma => false,
122                GuestDmaMode::BounceBuffer => true,
123            },
124        }
125    }
126}
127
128fn inspect_mana_stats(stats: &ManaQueryStatisticsResponse, req: inspect::Request<'_>) {
129    req.respond()
130        .sensitivity_counter(
131            "in_discards_no_wqe",
132            SensitivityLevel::Safe,
133            stats.in_discards_no_wqe,
134        )
135        .sensitivity_counter(
136            "in_errors_rx_vport_disabled",
137            SensitivityLevel::Safe,
138            stats.in_errors_rx_vport_disabled,
139        )
140        .sensitivity_counter("hc_in_octets", SensitivityLevel::Safe, stats.hc_in_octets)
141        .sensitivity_counter(
142            "hc_in_ucast_pkts",
143            SensitivityLevel::Safe,
144            stats.hc_in_ucast_pkts,
145        )
146        .sensitivity_counter(
147            "hc_in_ucast_octets",
148            SensitivityLevel::Safe,
149            stats.hc_in_ucast_octets,
150        )
151        .sensitivity_counter(
152            "hc_in_multicast_pkts",
153            SensitivityLevel::Safe,
154            stats.hc_in_multicast_pkts,
155        )
156        .sensitivity_counter(
157            "hc_in_multicast_octets",
158            SensitivityLevel::Safe,
159            stats.hc_in_multicast_octets,
160        )
161        .sensitivity_counter(
162            "hc_in_broadcast_pkts",
163            SensitivityLevel::Safe,
164            stats.hc_in_broadcast_pkts,
165        )
166        .sensitivity_counter(
167            "hc_in_broadcast_octets",
168            SensitivityLevel::Safe,
169            stats.hc_in_broadcast_octets,
170        )
171        .sensitivity_counter(
172            "out_errors_gf_disabled",
173            SensitivityLevel::Safe,
174            stats.out_errors_gf_disabled,
175        )
176        .sensitivity_counter(
177            "out_errors_vport_disabled",
178            SensitivityLevel::Safe,
179            stats.out_errors_vport_disabled,
180        )
181        .sensitivity_counter(
182            "out_errors_invalid_vport_offset_packets",
183            SensitivityLevel::Safe,
184            stats.out_errors_invalid_vport_offset_packets,
185        )
186        .sensitivity_counter(
187            "out_errors_vlan_enforcement",
188            SensitivityLevel::Safe,
189            stats.out_errors_vlan_enforcement,
190        )
191        .sensitivity_counter(
192            "out_errors_eth_type_enforcement",
193            SensitivityLevel::Safe,
194            stats.out_errors_eth_type_enforcement,
195        )
196        .sensitivity_counter(
197            "out_errors_sa_enforcement",
198            SensitivityLevel::Safe,
199            stats.out_errors_sa_enforcement,
200        )
201        .sensitivity_counter(
202            "out_errors_sqpdid_enforcement",
203            SensitivityLevel::Safe,
204            stats.out_errors_sqpdid_enforcement,
205        )
206        .sensitivity_counter(
207            "out_errors_cqpdid_enforcement",
208            SensitivityLevel::Safe,
209            stats.out_errors_cqpdid_enforcement,
210        )
211        .sensitivity_counter(
212            "out_errors_mtu_violation",
213            SensitivityLevel::Safe,
214            stats.out_errors_mtu_violation,
215        )
216        .sensitivity_counter(
217            "out_errors_invalid_oob",
218            SensitivityLevel::Safe,
219            stats.out_errors_invalid_oob,
220        )
221        .sensitivity_counter("hc_out_octets", SensitivityLevel::Safe, stats.hc_out_octets)
222        .sensitivity_counter(
223            "hc_out_ucast_pkts",
224            SensitivityLevel::Safe,
225            stats.hc_out_ucast_pkts,
226        )
227        .sensitivity_counter(
228            "hc_out_ucast_octets",
229            SensitivityLevel::Safe,
230            stats.hc_out_ucast_octets,
231        )
232        .sensitivity_counter(
233            "hc_out_multicast_pkts",
234            SensitivityLevel::Safe,
235            stats.hc_out_multicast_pkts,
236        )
237        .sensitivity_counter(
238            "hc_out_multicast_octets",
239            SensitivityLevel::Safe,
240            stats.hc_out_multicast_octets,
241        )
242        .sensitivity_counter(
243            "hc_out_broadcast_pkts",
244            SensitivityLevel::Safe,
245            stats.hc_out_broadcast_pkts,
246        )
247        .sensitivity_counter(
248            "hc_out_broadcast_octets",
249            SensitivityLevel::Safe,
250            stats.hc_out_broadcast_octets,
251        );
252}
253
254impl<T: DeviceBacking> InspectMut for ManaEndpoint<T> {
255    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
256        req.respond()
257            .sensitivity_child("stats", SensitivityLevel::Safe, |req| {
258                let vport = self.vport.clone();
259                let deferred = req.defer();
260                self.spawner
261                    .spawn("mana-stats", async move {
262                        let stats = if let Ok(stats) = vport.query_stats().await {
263                            stats
264                        } else {
265                            ManaQueryStatisticsResponse::new_zeroed()
266                        };
267                        deferred.inspect(inspect::adhoc(|req| inspect_mana_stats(&stats, req)));
268                    })
269                    .detach();
270            });
271    }
272}
273
274impl<T: DeviceBacking> ManaEndpoint<T> {
275    async fn new_queue(
276        &mut self,
277        tx_config: &TxConfig,
278        pool: Box<dyn BufferAccess>,
279        initial_rx: &[RxId],
280        arena: &mut ResourceArena,
281        cpu: u32,
282    ) -> anyhow::Result<(ManaQueue<T>, QueueResources)> {
283        let eq_size = 0x1000;
284        let tx_wq_size = 0x4000;
285        let tx_cq_size = 0x4000;
286        let rx_wq_size = 0x8000;
287        let rx_cq_size = 0x4000;
288
289        let eq = (self.vport.new_eq(arena, eq_size, cpu))
290            .await
291            .context("failed to create eq")?;
292        let txq = (self
293            .vport
294            .new_wq(arena, true, tx_wq_size, tx_cq_size, eq.id()))
295        .await
296        .context("failed to create tx queue")?;
297        let rxq = (self
298            .vport
299            .new_wq(arena, false, rx_wq_size, rx_cq_size, eq.id()))
300        .await
301        .context("failed to create rx queue")?;
302
303        let interrupt = eq.interrupt();
304
305        // The effective rx max may be smaller depending on the number of SGE
306        // entries used in the work queue (which depends on the NIC's configured
307        // MTU).
308        let rx_max = (rx_cq_size / size_of::<Cqe>() as u32).min(512);
309
310        let tx_max = tx_cq_size / size_of::<Cqe>() as u32;
311
312        let tx_bounce_buffer = ContiguousBufferManager::new(
313            self.vport.dma_client().await,
314            if self.bounce_buffer {
315                TX_BOUNCE_BUFFER_PAGE_LIMIT
316            } else {
317                SPLIT_HEADER_BOUNCE_PAGE_LIMIT
318            },
319        )
320        .context("failed to allocate tx bounce buffer")?;
321
322        let rx_bounce_buffer = if self.bounce_buffer {
323            Some(
324                ContiguousBufferManager::new(
325                    self.vport.dma_client().await,
326                    RX_BOUNCE_BUFFER_PAGE_LIMIT,
327                )
328                .context("failed to allocate rx bounce buffer")?,
329            )
330        } else {
331            None
332        };
333
334        let mut queue = ManaQueue {
335            guest_memory: pool.guest_memory().clone(),
336            pool,
337            rx_bounce_buffer,
338            tx_bounce_buffer,
339            vport: Arc::downgrade(&self.vport),
340            queue_tracker: self.queue_tracker.clone(),
341            eq: eq.queue(),
342            eq_armed: true,
343            interrupt,
344            tx_cq_armed: true,
345            rx_cq_armed: true,
346            vp_offset: tx_config.tx_vport_offset,
347            mem_key: self.vport.gpa_mkey(),
348            tx_wq: txq.wq(),
349            tx_cq: txq.cq(),
350            rx_wq: rxq.wq(),
351            rx_cq: rxq.cq(),
352            avail_rx: VecDeque::new(),
353            posted_rx: VecDeque::new(),
354            rx_max: rx_max as usize,
355            posted_tx: VecDeque::new(),
356            dropped_tx: VecDeque::new(),
357            tx_max: tx_max as usize,
358            force_tx_header_bounce: false,
359            stats: QueueStats::default(),
360        };
361        self.queue_tracker.0.fetch_add(1, Ordering::AcqRel);
362        queue.rx_avail(initial_rx);
363        queue.rx_wq.commit();
364
365        let resources = QueueResources {
366            _eq: eq,
367            rxq,
368            _txq: txq,
369        };
370        Ok((queue, resources))
371    }
372
373    async fn get_queues_inner(
374        &mut self,
375        arena: &mut ResourceArena,
376        config: Vec<QueueConfig<'_>>,
377        rss: Option<&RssConfig<'_>>,
378        queues: &mut Vec<Box<dyn Queue>>,
379    ) -> anyhow::Result<()> {
380        assert!(self.queues.is_empty());
381
382        let tx_config = self
383            .vport
384            .config_tx()
385            .await
386            .context("failed to configure transmit")?;
387
388        let mut queue_resources = Vec::new();
389
390        for config in config {
391            // Start the queue interrupt on CPU 0, which is already used by the
392            // HWC, so this is cheap. The actual interrupt will be allocated
393            // later when `update_target_vp` is first called.
394            let (queue, resources) = self
395                .new_queue(&tx_config, config.pool, config.initial_rx, arena, 0)
396                .await?;
397
398            queues.push(Box::new(queue));
399            queue_resources.push(resources);
400        }
401
402        let indirection_table;
403        let rx_config = if let Some(rss) = rss {
404            indirection_table = rss
405                .indirection_table
406                .iter()
407                .map(|&queue_id| {
408                    queue_resources
409                        .get(queue_id as usize)
410                        .unwrap_or_else(|| &queue_resources[0])
411                        .rxq
412                        .wq_obj()
413                })
414                .collect::<Vec<_>>();
415
416            RxConfig {
417                rx_enable: Some(true),
418                rss_enable: Some(true),
419                hash_key: Some(rss.key.try_into().ok().context("wrong hash key size")?),
420                default_rxobj: Some(queue_resources[0].rxq.wq_obj()),
421                indirection_table: Some(&indirection_table),
422            }
423        } else {
424            RxConfig {
425                rx_enable: Some(true),
426                rss_enable: Some(false),
427                hash_key: None,
428                default_rxobj: Some(queue_resources[0].rxq.wq_obj()),
429                indirection_table: None,
430            }
431        };
432
433        self.vport.config_rx(&rx_config).await?;
434        self.queues = queue_resources;
435        Ok(())
436    }
437}
438
439#[async_trait]
440impl<T: DeviceBacking> Endpoint for ManaEndpoint<T> {
441    fn endpoint_type(&self) -> &'static str {
442        "mana"
443    }
444
445    async fn get_queues(
446        &mut self,
447        config: Vec<QueueConfig<'_>>,
448        rss: Option<&RssConfig<'_>>,
449        queues: &mut Vec<Box<dyn Queue>>,
450    ) -> anyhow::Result<()> {
451        assert!(self.arena.is_empty());
452        let mut arena = ResourceArena::new();
453        match self.get_queues_inner(&mut arena, config, rss, queues).await {
454            Ok(()) => {
455                self.arena = arena;
456                Ok(())
457            }
458            Err(err) => {
459                self.vport.destroy(arena).await;
460                Err(err)
461            }
462        }
463    }
464
465    async fn stop(&mut self) {
466        if let Err(err) = self
467            .vport
468            .config_rx(&RxConfig {
469                rx_enable: Some(false),
470                rss_enable: None,
471                hash_key: None,
472                default_rxobj: None,
473                indirection_table: None,
474            })
475            .await
476        {
477            tracing::warn!(
478                error = err.as_ref() as &dyn std::error::Error,
479                "failed to stop rx"
480            );
481        }
482
483        self.queues.clear();
484        self.vport.destroy(std::mem::take(&mut self.arena)).await;
485        // Wait for all outstanding queues to be dropped. There can be a delay
486        // switching out the queues when an endpoint is removed, and the queues
487        // still have access to the vport that is being stopped here.
488        if self.queue_tracker.0.load(Ordering::Acquire) > 0 {
489            self.queue_tracker.1.wait().await;
490        }
491    }
492
493    fn is_ordered(&self) -> bool {
494        true
495    }
496
497    fn tx_offload_support(&self) -> TxOffloadSupport {
498        TxOffloadSupport {
499            ipv4_header: true,
500            tcp: true,
501            udp: true,
502            // The bounce buffer path does not support TSO.
503            tso: !self.bounce_buffer,
504        }
505    }
506
507    fn multiqueue_support(&self) -> MultiQueueSupport {
508        MultiQueueSupport {
509            max_queues: self
510                .vport
511                .max_rx_queues()
512                .min(self.vport.max_tx_queues())
513                .min(u16::MAX.into()) as u16,
514            indirection_table_size: self.vport.num_indirection_ent().min(u16::MAX.into()) as u16,
515        }
516    }
517
518    fn tx_fast_completions(&self) -> bool {
519        // The mana NIC completes packets quickly and in order.
520        true
521    }
522
523    async fn set_data_path_to_guest_vf(&self, use_vf: bool) -> anyhow::Result<()> {
524        self.vport.move_filter(if use_vf { 1 } else { 0 }).await?;
525        Ok(())
526    }
527
528    async fn get_data_path_to_guest_vf(&self) -> anyhow::Result<bool> {
529        match self.vport.get_direction_to_vtl0().await {
530            Some(to_vtl0) => Ok(to_vtl0),
531            None => Err(anyhow::anyhow!("Device does not support data path query")),
532        }
533    }
534
535    async fn wait_for_endpoint_action(&mut self) -> EndpointAction {
536        self.receive_update
537            .select_next_some()
538            .map(EndpointAction::LinkStatusNotify)
539            .await
540    }
541
542    fn link_speed(&self) -> u64 {
543        // Hard code to 200Gbps until MANA supports querying this.
544        200 * 1000 * 1000 * 1000
545    }
546}
547
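/// A single transmit/receive queue pair (plus its shared event queue), exposed
/// to the `net_backend` layer.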
548pub struct ManaQueue<T: DeviceBacking> {
549    pool: Box<dyn BufferAccess>,
550    guest_memory: GuestMemory,
551    rx_bounce_buffer: Option<ContiguousBufferManager>,
552    tx_bounce_buffer: ContiguousBufferManager,
553
554    vport: Weak<Vport<T>>,
555    queue_tracker: Arc<(AtomicUsize, SlimEvent)>,
556
557    eq: Eq,
558    eq_armed: bool,
559    interrupt: DeviceInterrupt,
560    tx_cq_armed: bool,
561    rx_cq_armed: bool,
562
563    vp_offset: u16,
564    mem_key: u32,
565
566    tx_wq: Wq,
567    tx_cq: Cq,
568
569    rx_wq: Wq,
570    rx_cq: Cq,
571
572    avail_rx: VecDeque<RxId>,
573    posted_rx: VecDeque<PostedRx>,
574    rx_max: usize,
575
576    posted_tx: VecDeque<PostedTx>,
577    dropped_tx: VecDeque<TxId>,
578    tx_max: usize,
579
580    force_tx_header_bounce: bool,
581
582    stats: QueueStats,
583}
584
585impl<T: DeviceBacking> Drop for ManaQueue<T> {
586    fn drop(&mut self) {
587        // Signal the endpoint when no more queues are active.
588        if self.queue_tracker.0.fetch_sub(1, Ordering::AcqRel) == 1 {
589            self.queue_tracker.1.signal();
590        }
591    }
592}
593
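/// Bookkeeping for a receive buffer currently posted to the hardware rx WQ.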
594struct PostedRx {
595    id: RxId,
596    wqe_len: u32,
597    bounced_len_with_padding: u32,
598    bounce_offset: u32,
599}
600
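/// Bookkeeping for a packet currently posted to the hardware tx WQ.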
601struct PostedTx {
602    id: TxId,
603    wqe_len: u32,
604    bounced_len_with_padding: u32,
605}
606
607#[derive(Default)]
608struct QueueStats {
609    tx_events: u64,
610    tx_packets: u64,
611    tx_errors: u64,
612    tx_dropped: u64,
613    tx_stuck: u64,
614
615    rx_events: u64,
616    rx_packets: u64,
617    rx_errors: u64,
618
619    interrupts: u64,
620}
621
622impl Inspect for QueueStats {
623    fn inspect(&self, req: inspect::Request<'_>) {
624        req.respond()
625            .counter("tx_events", self.tx_events)
626            .counter("tx_packets", self.tx_packets)
627            .counter("tx_errors", self.tx_errors)
628            .counter("tx_dropped", self.tx_dropped)
629            .counter("tx_stuck", self.tx_stuck)
630            .counter("rx_events", self.rx_events)
631            .counter("rx_packets", self.rx_packets)
632            .counter("rx_errors", self.rx_errors)
633            .counter("interrupts", self.interrupts);
634    }
635}
636
637impl<T: DeviceBacking> InspectMut for ManaQueue<T> {
638    // N.B. Inspect fields need to be kept in sync with
639    // Microsoft internal diagnostics testing.
640    // Search for EXPECTED_QUEUE_FIELDS_V1.
641    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
642        req.respond()
643            .merge(&self.tx_bounce_buffer)
644            .field("rx_bounce_buffer", &self.rx_bounce_buffer)
645            .merge(&self.stats)
646            .field("eq", &self.eq)
647            .field("eq/armed", self.eq_armed)
648            .field_mut("force_tx_header_bounce", &mut self.force_tx_header_bounce)
649            .field("rx_wq", &self.rx_wq)
650            .field("rx_cq", &self.rx_cq)
651            .field("rx_cq/armed", self.rx_cq_armed)
652            .field("tx_wq", &self.tx_wq)
653            .field("tx_cq", &self.tx_cq)
654            .field("tx_cq/armed", self.tx_cq_armed)
655            .field("rx_queued", self.posted_rx.len())
656            .field("rx_avail", self.avail_rx.len())
657            .field("tx_queued", self.posted_tx.len());
658    }
659}
660
661/// RWQEs cannot be larger than 256 bytes.
662pub const MAX_RWQE_SIZE: u32 = 256;
663
664/// SWQEs cannot be larger than 512 bytes.
665pub const MAX_SWQE_SIZE: u32 = 512;
666
667impl<T: DeviceBacking> ManaQueue<T> {
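    /// Posts a single receive WQE for the next available rx buffer, routing it
    /// through the rx bounce buffer when one is configured. Returns `false` if
    /// there is no available buffer or no room in the rx WQ or bounce buffer.
    ///
    /// The caller is responsible for committing the rx WQ.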
668    fn push_rqe(&mut self) -> bool {
669        // Make sure there is enough room for an entry of the maximum size. This
670        // is conservative, but it simplifies the logic.
671        if self.rx_wq.available() < MAX_RWQE_SIZE {
672            return false;
673        }
674        if let Some(id) = self.avail_rx.pop_front() {
675            let rx = if let Some(pool) = &mut self.rx_bounce_buffer {
676                let size = self.pool.capacity(id);
677                let mut pool_tx = pool.start_allocation();
678                let Ok(buffer) = pool_tx.allocate(size) else {
679                    self.avail_rx.push_front(id);
680                    return false;
681                };
682                let buffer = buffer.reserve();
683                let sqe = Sge {
684                    address: buffer.gpa,
685                    mem_key: self.mem_key,
686                    size,
687                };
688                let wqe_len = self
689                    .rx_wq
690                    .push(&(), [sqe], None, 0)
691                    .expect("rq should not be full");
692
693                PostedRx {
694                    id,
695                    wqe_len,
696                    bounce_offset: buffer.offset,
697                    bounced_len_with_padding: pool_tx.commit(),
698                }
699            } else {
700                let sgl = self.pool.guest_addresses(id).iter().map(|seg| Sge {
701                    address: self.guest_memory.iova(seg.gpa).unwrap(),
702                    mem_key: self.mem_key,
703                    size: seg.len,
704                });
705
706                let wqe_len = self
707                    .rx_wq
708                    .push(&(), sgl, None, 0)
709                    .expect("rq should not be full");
710
711                assert!(wqe_len <= MAX_RWQE_SIZE, "too many scatter/gather entries");
712                PostedRx {
713                    id,
714                    wqe_len,
715                    bounce_offset: 0,
716                    bounced_len_with_padding: 0,
717                }
718            };
719
720            self.posted_rx.push_back(rx);
721            true
722        } else {
723            false
724        }
725    }
726
727    fn trace_tx_wqe(&mut self, tx_oob: ManaTxCompOob, done_length: usize) {
728        tracelimit::error_ratelimited!(
729            cqe_hdr_type = tx_oob.cqe_hdr.cqe_type(),
730            cqe_hdr_vendor_err = tx_oob.cqe_hdr.vendor_err(),
731            tx_oob_data_offset = tx_oob.tx_data_offset,
732            tx_oob_sgl_offset = tx_oob.offsets.tx_sgl_offset(),
733            tx_oob_wqe_offset = tx_oob.offsets.tx_wqe_offset(),
734            done_length,
735            posted_tx_len = self.posted_tx.len(),
736            "tx completion error"
737        );
738
739        // TODO: Use tx_wqe_offset to read the Wqe.
740        // Use Wqe.ClientOob to read the ManaTxOob.s_oob.
741        // Log properties of s_oob like checksum, etc.
742
743        if let Some(packet) = self.posted_tx.front() {
744            tracelimit::error_ratelimited!(
745                id = packet.id.0,
746                wqe_len = packet.wqe_len,
747                bounced_len_with_padding = packet.bounced_len_with_padding,
748                "posted tx"
749            );
750        }
751    }
752}
753
754#[async_trait]
755impl<T: DeviceBacking + Send> Queue for ManaQueue<T> {
756    async fn update_target_vp(&mut self, target_vp: u32) {
757        if let Some(vport) = self.vport.upgrade() {
758            let result = vport.retarget_interrupt(self.eq.id(), target_vp).await;
759            match result {
760                Err(err) => {
761                    tracing::warn!(
762                        error = err.as_ref() as &dyn std::error::Error,
763                        "failed to retarget interrupt to cpu"
764                    );
765                }
766                Ok(None) => {}
767                Ok(Some(event)) => self.interrupt = event,
768            }
769        }
770    }
771
772    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
773        if !self.tx_cq_armed || !self.rx_cq_armed {
774            return Poll::Ready(());
775        }
776
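        // Drain the event queue, disarming any completion queue that has new
        // completions. Only arm the event queue and wait for an interrupt once
        // both completion queues are armed and there is nothing left to process.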
777        loop {
778            while let Some(eqe) = self.eq.pop() {
779                self.eq_armed = false;
780                match eqe.params.event_type() {
781                    GDMA_EQE_COMPLETION => {
782                        let cq_id =
783                            u32::from_le_bytes(eqe.data[..4].try_into().unwrap()) & 0xffffff;
784                        if cq_id == self.tx_cq.id() {
785                            self.stats.tx_events += 1;
786                            self.tx_cq_armed = false;
787                        } else if cq_id == self.rx_cq.id() {
788                            self.stats.rx_events += 1;
789                            self.rx_cq_armed = false;
790                        } else {
791                            tracing::error!(cq_id, "unknown cq id");
792                        }
793                    }
794                    ty => {
795                        tracing::error!(ty, "unknown completion type")
796                    }
797                }
798            }
799
800            if !self.tx_cq_armed || !self.rx_cq_armed {
801                // When the vp count exceeds the number of queues, the event queue can
802                // easily overflow if it is not acked prior to arming the CQ.
803                self.eq.ack();
804                return Poll::Ready(());
805            }
806
807            if !self.eq_armed {
808                self.eq.arm();
809                self.eq_armed = true;
810            }
811            std::task::ready!(self.interrupt.poll(cx));
812
813            self.stats.interrupts += 1;
814        }
815    }
816
817    fn rx_avail(&mut self, done: &[RxId]) {
818        self.avail_rx.extend(done);
819        let mut commit = false;
820        while self.posted_rx.len() < self.rx_max && self.push_rqe() {
821            commit = true;
822        }
823        if commit {
824            self.rx_wq.commit();
825        }
826    }
827
828    fn rx_poll(&mut self, packets: &mut [RxId]) -> anyhow::Result<usize> {
829        let mut i = 0;
830        let mut commit = false;
831        while i < packets.len() {
832            if let Some(cqe) = self.rx_cq.pop() {
833                let rx = self.posted_rx.pop_front().unwrap();
834                let rx_oob = ManaRxcompOob::read_from_prefix(&cqe.data[..]).unwrap().0; // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
835                match rx_oob.cqe_hdr.cqe_type() {
836                    CQE_RX_OKAY => {
837                        let ip_checksum = if rx_oob.flags.rx_iphdr_csum_succeed() {
838                            RxChecksumState::Good
839                        } else if rx_oob.flags.rx_iphdr_csum_fail() {
840                            RxChecksumState::Bad
841                        } else {
842                            RxChecksumState::Unknown
843                        };
844                        let (l4_protocol, l4_checksum) = if rx_oob.flags.rx_tcp_csum_succeed() {
845                            (L4Protocol::Tcp, RxChecksumState::Good)
846                        } else if rx_oob.flags.rx_tcp_csum_fail() {
847                            (L4Protocol::Tcp, RxChecksumState::Bad)
848                        } else if rx_oob.flags.rx_udp_csum_succeed() {
849                            (L4Protocol::Udp, RxChecksumState::Good)
850                        } else if rx_oob.flags.rx_udp_csum_fail() {
851                            (L4Protocol::Udp, RxChecksumState::Bad)
852                        } else {
853                            (L4Protocol::Unknown, RxChecksumState::Unknown)
854                        };
855                        let len = rx_oob.ppi[0].pkt_len.into();
856                        self.pool.write_header(
857                            rx.id,
858                            &RxMetadata {
859                                offset: 0,
860                                len,
861                                ip_checksum,
862                                l4_checksum,
863                                l4_protocol,
864                            },
865                        );
866                        if rx.bounced_len_with_padding > 0 {
867                            // TODO: avoid this allocation by updating
868                            // write_data to take a slice of shared memory.
869                            let mut data = vec![0; len];
870                            self.rx_bounce_buffer.as_mut().unwrap().as_slice()
871                                [rx.bounce_offset as usize..][..len]
872                                .atomic_read(&mut data);
873                            self.pool.write_data(rx.id, &data);
874                        }
875                        self.stats.rx_packets += 1;
876                        packets[i] = rx.id;
877                        i += 1;
878                    }
879                    ty => {
880                        tracelimit::error_ratelimited!(ty, "invalid rx cqe type");
881                        self.stats.rx_errors += 1;
882                        self.avail_rx.push_back(rx.id);
883                    }
884                }
885                self.rx_wq.advance_head(rx.wqe_len);
886                if rx.bounced_len_with_padding > 0 {
887                    self.rx_bounce_buffer
888                        .as_mut()
889                        .unwrap()
890                        .free(rx.bounced_len_with_padding);
891                }
892                // Replenish the rq, if possible.
893                commit |= self.push_rqe();
894            } else {
895                if !self.rx_cq_armed {
896                    self.rx_cq.arm();
897                    self.rx_cq_armed = true;
898                }
899                break;
900            }
901        }
902        if commit {
903            self.rx_wq.commit();
904        }
905        Ok(i)
906    }
907
908    fn tx_avail(&mut self, segments: &[TxSegment]) -> anyhow::Result<(bool, usize)> {
909        let mut i = 0;
910        let mut commit = false;
911        while i < segments.len()
912            && self.posted_tx.len() < self.tx_max
913            && self.tx_wq.available() >= MAX_SWQE_SIZE
914        {
915            let head = &segments[i];
916            let TxSegmentType::Head(meta) = &head.ty else {
917                unreachable!()
918            };
919
920            if let Some(tx) = self.handle_tx(&segments[i..i + meta.segment_count])? {
921                commit = true;
922                self.posted_tx.push_back(tx);
923            } else {
924                self.dropped_tx.push_back(meta.id);
925            }
926            i += meta.segment_count;
927        }
928
929        if commit {
930            self.tx_wq.commit();
931        }
932        Ok((false, i))
933    }
934
935    fn tx_poll(&mut self, done: &mut [TxId]) -> Result<usize, TxError> {
936        let mut i = 0;
937        let mut queue_stuck = false;
938        while i < done.len() {
939            let id = if let Some(cqe) = self.tx_cq.pop() {
940                let tx_oob = ManaTxCompOob::read_from_prefix(&cqe.data[..]).unwrap().0; // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
941                match tx_oob.cqe_hdr.cqe_type() {
942                    CQE_TX_OKAY => {
943                        self.stats.tx_packets += 1;
944                    }
945                    CQE_TX_GDMA_ERR => {
946                        queue_stuck = true;
947                    }
948                    ty => {
949                        tracelimit::error_ratelimited!(ty, "tx completion error");
950                        self.stats.tx_errors += 1;
951                    }
952                }
953                if queue_stuck {
954                    // The hardware hit an error with a packet coming from the guest.
955                    // CQE_TX_GDMA_ERR indicates that the hardware has disabled the queue.
956                    self.stats.tx_errors += 1;
957                    self.stats.tx_stuck += 1;
958                    self.trace_tx_wqe(tx_oob, done.len());
959                    // Return a TryRestart error to indicate that the queue needs to be restarted.
960                    return Err(TxError::TryRestart(anyhow::anyhow!("GDMA error")));
961                }
962                let packet = self.posted_tx.pop_front().unwrap();
963                self.tx_wq.advance_head(packet.wqe_len);
964                if packet.bounced_len_with_padding > 0 {
965                    self.tx_bounce_buffer.free(packet.bounced_len_with_padding);
966                }
967                packet.id
968            } else if let Some(id) = self.dropped_tx.pop_front() {
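                // This packet was dropped at `tx_avail` time and never posted
                // to the hardware, so complete it directly.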
969                self.stats.tx_dropped += 1;
970                id
971            } else {
972                if !self.tx_cq_armed {
973                    self.tx_cq.arm();
974                    self.tx_cq_armed = true;
975                }
976                break;
977            };
978
979            done[i] = id;
980            i += 1;
981        }
982        Ok(i)
983    }
984
985    fn buffer_access(&mut self) -> Option<&mut dyn BufferAccess> {
986        Some(self.pool.as_mut())
987    }
988}
989
990impl<T: DeviceBacking> ManaQueue<T> {
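    /// Builds and pushes a transmit WQE for one packet (a head segment plus
    /// its tail segments).
    ///
    /// Returns `Ok(None)` if the packet cannot be transmitted and should be
    /// dropped (e.g. it could not be bounce buffered or has too many
    /// segments); the caller completes such packets via `dropped_tx`.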
991    fn handle_tx(&mut self, segments: &[TxSegment]) -> anyhow::Result<Option<PostedTx>> {
992        let head = &segments[0];
993        let TxSegmentType::Head(meta) = &head.ty else {
994            unreachable!()
995        };
996
997        let mut oob = ManaTxOob::new_zeroed();
998        oob.s_oob.set_vcq_num(self.tx_cq.id());
999        oob.s_oob
1000            .set_vsq_frame((self.tx_wq.id() >> 10) as u16 & 0x3fff);
1001
1002        oob.s_oob
1003            .set_is_outer_ipv4(meta.l3_protocol == L3Protocol::Ipv4);
1004        oob.s_oob
1005            .set_is_outer_ipv6(meta.l3_protocol == L3Protocol::Ipv6);
1006        oob.s_oob
1007            .set_comp_iphdr_csum(meta.offload_ip_header_checksum);
1008        oob.s_oob.set_comp_tcp_csum(meta.offload_tcp_checksum);
1009        oob.s_oob.set_comp_udp_csum(meta.offload_udp_checksum);
1010        if meta.offload_tcp_checksum {
1011            oob.s_oob.set_trans_off(meta.l2_len as u16 + meta.l3_len);
1012        }
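        // Use the short OOB format when the vport offset fits in a byte;
        // otherwise fall back to the long format, which carries a 16-bit
        // vport offset.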
1013        let short_format = self.vp_offset <= 0xff;
1014        if short_format {
1015            oob.s_oob.set_pkt_fmt(MANA_SHORT_PKT_FMT);
1016            oob.s_oob.set_short_vp_offset(self.vp_offset as u8);
1017        } else {
1018            oob.s_oob.set_pkt_fmt(MANA_LONG_PKT_FMT);
1019            oob.l_oob.set_long_vp_offset(self.vp_offset);
1020        }
1021
1022        let mut bounce_buffer = self.tx_bounce_buffer.start_allocation();
1023        let tx = if self.rx_bounce_buffer.is_some() {
1024            assert!(!meta.offload_tcp_segmentation);
1025            let gd_client_unit_data = 0;
1026            let mut buf: ContiguousBuffer<'_, '_> = match bounce_buffer.allocate(meta.len as u32) {
1027                Ok(buf) => buf,
1028                Err(err) => {
1029                    tracelimit::error_ratelimited!(
1030                        err = &err as &dyn std::error::Error,
1031                        meta.len,
1032                        "failed to bounce buffer"
1033                    );
1034                    // Drop the packet
1035                    return Ok(None);
1036                }
1037            };
1038            let mut next = buf.as_slice();
1039            for seg in segments {
1040                let len = seg.len as usize;
1041                self.guest_memory.read_to_atomic(seg.gpa, &next[..len])?;
1042                next = &next[len..];
1043            }
1044            let buf = buf.reserve();
1045            let sge = Sge {
1046                address: buf.gpa,
1047                mem_key: self.mem_key,
1048                size: meta.len as u32,
1049            };
1050            let wqe_len = if short_format {
1051                self.tx_wq
1052                    .push(&oob.s_oob, [sge], None, gd_client_unit_data)
1053                    .unwrap()
1054            } else {
1055                self.tx_wq
1056                    .push(&oob, [sge], None, gd_client_unit_data)
1057                    .unwrap()
1058            };
1059            PostedTx {
1060                id: meta.id,
1061                wqe_len,
1062                bounced_len_with_padding: bounce_buffer.commit(),
1063            }
1064        } else {
1065            let mut gd_client_unit_data = 0;
1066            let mut header_len = head.len;
1067            let (header_segment_count, partial_bytes) = if meta.offload_tcp_segmentation {
1068                header_len = (meta.l2_len as u16 + meta.l3_len + meta.l4_len as u16) as u32;
1069                if header_len > PAGE_SIZE32 {
1070                    tracelimit::error_ratelimited!(
1071                        header_len,
1072                        "Header larger than PAGE_SIZE is not supported"
1073                    );
1074                    // Drop the packet
1075                    return Ok(None);
1076                }
1077
1078                let mut partial_bytes = 0;
1079                gd_client_unit_data = meta.max_tcp_segment_size;
1080                if header_len > head.len || self.force_tx_header_bounce {
1081                    let mut header_bytes_remaining = header_len;
1082                    let mut hdr_idx = 0;
1083                    while hdr_idx < segments.len() {
1084                        if header_bytes_remaining <= segments[hdr_idx].len {
1085                            if segments[hdr_idx].len > header_bytes_remaining {
1086                                partial_bytes = header_bytes_remaining;
1087                            }
1088                            header_bytes_remaining = 0;
1089                            break;
1090                        }
1091                        header_bytes_remaining -= segments[hdr_idx].len;
1092                        hdr_idx += 1;
1093                    }
1094                    if header_bytes_remaining > 0 {
1095                        tracelimit::error_ratelimited!(
1096                            header_len,
1097                            missing_header_bytes = header_bytes_remaining,
1098                            "Invalid split header"
1099                        );
1100                        // Drop the packet
1101                        return Ok(None);
1102                    }
1103                    ((hdr_idx + 1), partial_bytes)
1104                } else {
1105                    if head.len > header_len {
1106                        partial_bytes = header_len;
1107                    }
1108                    (1, partial_bytes)
1109                }
1110            } else {
1111                (1, 0)
1112            };
1113
1114            let mut last_segment_bounced = false;
1115            // The header needs to be contiguous.
1116            let head_iova = if header_len > head.len || self.force_tx_header_bounce {
1117                let mut copy = match bounce_buffer.allocate(header_len) {
1118                    Ok(buf) => buf,
1119                    Err(err) => {
1120                        tracelimit::error_ratelimited!(
1121                            err = &err as &dyn std::error::Error,
1122                            header_len,
1123                            "Failed to bounce buffer split header"
1124                        );
1125                        // Drop the packet
1126                        return Ok(None);
1127                    }
1128                };
1129                let mut next = copy.as_slice();
1130                for hdr_seg in &segments[..header_segment_count] {
1131                    let len = std::cmp::min(next.len(), hdr_seg.len as usize);
1132                    self.guest_memory
1133                        .read_to_atomic(hdr_seg.gpa, &next[..len])?;
1134                    next = &next[len..];
1135                }
1136                last_segment_bounced = true;
1137                let ContiguousBufferInUse { gpa, .. } = copy.reserve();
1138                gpa
1139            } else {
1140                self.guest_memory.iova(head.gpa).unwrap()
1141            };
1142
1143            // The max WQE size is 512 bytes. The hardware SGE limit is 31 entries
1144            // with the short oob and 30 entries with the long oob.
1145            let hardware_segment_limit = if short_format { 31 } else { 30 };
1146            let mut sgl = [Sge::new_zeroed(); 31];
1147            sgl[0] = Sge {
1148                address: head_iova,
1149                mem_key: self.mem_key,
1150                size: header_len,
1151            };
1152            let tail_sgl_offset = if partial_bytes > 0 {
1153                last_segment_bounced = false;
1154                let shared_seg = &segments[header_segment_count - 1];
1155                sgl[1] = Sge {
1156                    address: self
1157                        .guest_memory
1158                        .iova(shared_seg.gpa)
1159                        .unwrap()
1160                        .wrapping_add(partial_bytes as u64),
1161                    mem_key: self.mem_key,
1162                    size: shared_seg.len - partial_bytes,
1163                };
1164                2
1165            } else {
1166                1
1167            };
1168
1169            let mut segment_count = tail_sgl_offset + meta.segment_count - header_segment_count;
1170            let mut sgl_idx = tail_sgl_offset - 1;
1171            let sgl = if segment_count <= hardware_segment_limit {
1172                for (tail, sge) in segments[header_segment_count..]
1173                    .iter()
1174                    .zip(&mut sgl[tail_sgl_offset..])
1175                {
1176                    *sge = Sge {
1177                        address: self.guest_memory.iova(tail.gpa).unwrap(),
1178                        mem_key: self.mem_key,
1179                        size: tail.len,
1180                    };
1181                }
1182                &sgl[..segment_count]
1183            } else {
1184                let sgl = &mut sgl[..hardware_segment_limit];
1185                for tail_idx in header_segment_count..segments.len() {
1186                    let tail = &segments[tail_idx];
1187                    let cur_seg = &mut sgl[sgl_idx];
1188                    // Try to coalesce segments together if there are more than the hardware allows.
1189                    // TODO: Could use more expensive techniques such as
1190                    //       copying portions of segments to fill an entire
1191                    //       bounce page if the simple algorithm of coalescing
1192                    //       full segments together fails.
1193                    // TODO: If the header was not bounced, we could search the segments for the
1194                    //       longest sequence that can be coalesced, instead of the first sequence.
1195                    let coalesce_possible = cur_seg.size + tail.len < PAGE_SIZE32;
1196                    if segment_count > hardware_segment_limit {
1197                        if !last_segment_bounced
1198                            && coalesce_possible
1199                            && bounce_buffer.allocate(cur_seg.size + tail.len).is_ok()
1200                        {
1201                            // There is enough room to coalesce the current
1202                            // segment with the previous. The previous segment
1203                            // is not yet bounced, so bounce it now.
1204                            let last_segment_gpa = segments[tail_idx - 1].gpa;
1205                            let mut copy = bounce_buffer.allocate(cur_seg.size).unwrap();
1206                            self.guest_memory
1207                                .read_to_atomic(last_segment_gpa, copy.as_slice())?;
1208                            let ContiguousBufferInUse { gpa, .. } = copy.reserve();
1209                            cur_seg.address = gpa;
1210                            last_segment_bounced = true;
1211                        }
1212                        if last_segment_bounced {
1213                            if let Some(mut copy) = bounce_buffer.try_extend(tail.len) {
1214                                // Combine current segment with previous one using bounce buffer.
1215                                self.guest_memory
1216                                    .read_to_atomic(tail.gpa, copy.as_slice())?;
1217                                let ContiguousBufferInUse {
1218                                    len_with_padding, ..
1219                                } = copy.reserve();
1220                                assert_eq!(tail.len, len_with_padding);
1221                                cur_seg.size += len_with_padding;
1222                                segment_count -= 1;
1223                                continue;
1224                            }
1225                        }
1226                        last_segment_bounced = false;
1227                    }
1228
1229                    sgl_idx += 1;
1230                    if sgl_idx == hardware_segment_limit {
1231                        tracelimit::error_ratelimited!(
1232                            segments_remaining = segment_count - sgl_idx,
1233                            hardware_segment_limit,
1234                            "Failed to bounce buffer the packet: too many segments"
1235                        );
1236                        // Drop the packet, no need to free bounce buffer
1237                        return Ok(None);
1238                    }
1239
1240                    sgl[sgl_idx] = Sge {
1241                        address: self.guest_memory.iova(tail.gpa).unwrap(),
1242                        mem_key: self.mem_key,
1243                        size: tail.len,
1244                    };
1245                }
1246                &sgl[..segment_count]
1247            };
1248
1249            let wqe_len = if short_format {
1250                self.tx_wq
1251                    .push(
1252                        &oob.s_oob,
1253                        sgl.iter().copied(),
1254                        meta.offload_tcp_segmentation.then(|| sgl[0].size as u8),
1255                        gd_client_unit_data,
1256                    )
1257                    .unwrap()
1258            } else {
1259                self.tx_wq
1260                    .push(
1261                        &oob,
1262                        sgl.iter().copied(),
1263                        meta.offload_tcp_segmentation.then(|| sgl[0].size as u8),
1264                        gd_client_unit_data,
1265                    )
1266                    .unwrap()
1267            };
1268            PostedTx {
1269                id: meta.id,
1270                wqe_len,
1271                bounced_len_with_padding: bounce_buffer.commit(),
1272            }
1273        };
1274        Ok(Some(tx))
1275    }
1276}
1277
1278struct ContiguousBufferInUse {
1279    pub gpa: u64,
1280    pub offset: u32,
1281    pub len_with_padding: u32,
1282}
1283
1284struct ContiguousBuffer<'a, 'b> {
1285    parent: &'a mut ContiguousBufferManagerTransaction<'b>,
1286    offset: u32,
1287    len: u32,
1288    padding_len: u32,
1289}
1290
1291impl<'a, 'b> ContiguousBuffer<'a, 'b> {
1292    pub fn new(
1293        parent: &'a mut ContiguousBufferManagerTransaction<'b>,
1294        offset: u32,
1295        len: u32,
1296        padding_len: u32,
1297    ) -> Self {
1298        Self {
1299            parent,
1300            offset,
1301            len,
1302            padding_len,
1303        }
1304    }
1305
1306    pub fn as_slice(&mut self) -> &[AtomicU8] {
1307        &self.parent.as_slice()[self.offset as usize..(self.offset + self.len) as usize]
1308    }
1309
1310    pub fn reserve(self) -> ContiguousBufferInUse {
1311        let page = self.offset / PAGE_SIZE32;
1312        let offset_in_page = self.offset - page * PAGE_SIZE32;
1313        let gpa = self.parent.page_gpa(page as usize) + offset_in_page as u64;
1314        let len_with_padding = self.len + self.padding_len;
1315        self.parent.head = self.parent.head.wrapping_add(len_with_padding);
1316        ContiguousBufferInUse {
1317            gpa,
1318            offset: self.offset,
1319            len_with_padding,
1320        }
1321    }
1322}
1323
1324struct ContiguousBufferManagerTransaction<'a> {
1325    parent: &'a mut ContiguousBufferManager,
1326    pub head: u32,
1327}
1328
1329impl<'a> ContiguousBufferManagerTransaction<'a> {
1330    pub fn new(parent: &'a mut ContiguousBufferManager) -> Self {
1331        let head = parent.head;
1332        Self { parent, head }
1333    }
1334
1335    /// Allocates from the next available section of the ring buffer.
1336    pub fn allocate<'b>(&'b mut self, len: u32) -> Result<ContiguousBuffer<'b, 'a>, OutOfMemory> {
1337        assert!(len < PAGE_SIZE32);
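        // Allocations never straddle a page boundary: if the request does not
        // fit in the remainder of the current page, skip ahead to the next page
        // and account for the skipped bytes as padding.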
1338        let mut len_with_padding = len;
1339        let mut allocated_offset = self.head;
1340        let bytes_remaining_on_page = PAGE_SIZE32 - (self.head & (PAGE_SIZE32 - 1));
1341        if len > bytes_remaining_on_page {
1342            allocated_offset = allocated_offset.wrapping_add(bytes_remaining_on_page);
1343            len_with_padding += bytes_remaining_on_page;
1344        }
1345        if len_with_padding > self.parent.tail.wrapping_sub(self.head) {
1346            self.parent.failed_allocations += 1;
1347            return Err(OutOfMemory);
1348        }
1349        Ok(ContiguousBuffer::new(
1350            self,
1351            allocated_offset % self.parent.len,
1352            len,
1353            len_with_padding - len,
1354        ))
1355    }
1356
1357    pub fn try_extend<'b>(&'b mut self, len: u32) -> Option<ContiguousBuffer<'b, 'a>> {
1358        let bytes_remaining_on_page = PAGE_SIZE32 - (self.head & (PAGE_SIZE32 - 1));
1359        if bytes_remaining_on_page == PAGE_SIZE32 {
1360            // Used the entire previous page. Cannot extend onto a new page.
1361            return None;
1362        }
1363        if len <= bytes_remaining_on_page {
1364            self.allocate(len).ok()
1365        } else {
1366            None
1367        }
1368    }
1369
1370    pub fn commit(self) -> u32 {
1371        self.parent.split_headers += 1;
1372        let len_with_padding = self.head.wrapping_sub(self.parent.head);
1373        self.parent.head = self.head;
1374        len_with_padding
1375    }
1376
1377    pub fn as_slice(&self) -> &[AtomicU8] {
1378        self.parent.as_slice()
1379    }
1380
1381    pub fn page_gpa(&self, page_idx: usize) -> u64 {
1382        self.parent.mem.pfns()[page_idx] * PAGE_SIZE64
1383    }
1384}
1385
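/// A ring buffer of DMA-mapped pages used to stage ("bounce") packet data that
/// the hardware must see as contiguous. Space is reserved at the head via a
/// [`ContiguousBufferManagerTransaction`] and released from the tail via
/// [`Self::free`], so reservations must be freed in FIFO order.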
1386struct ContiguousBufferManager {
1387    len: u32,
1388    head: u32,
1389    tail: u32,
1390    mem: MemoryBlock,
1391    // Counters
1392    split_headers: u64,
1393    failed_allocations: u64,
1394}
1395
1396#[derive(Debug, Error)]
1397#[error("out of bounce buffer memory")]
1398struct OutOfMemory;
1399
1400impl ContiguousBufferManager {
1401    pub fn new(dma_client: Arc<dyn DmaClient>, page_limit: u32) -> anyhow::Result<Self> {
1402        let len = PAGE_SIZE32 * page_limit;
1403        let mem = dma_client.allocate_dma_buffer(len as usize)?;
1404        Ok(Self {
1405            len,
1406            head: 0,
1407            tail: len - 1,
1408            mem,
1409            split_headers: 0,
1410            failed_allocations: 0,
1411        })
1412    }
1413
1414    pub fn start_allocation(&mut self) -> ContiguousBufferManagerTransaction<'_> {
1415        ContiguousBufferManagerTransaction::new(self)
1416    }
1417
1418    /// Frees the oldest reserved range by advancing the tail of the ring buffer
1419    /// past it. This requires ranges to be freed in FIFO order.
1420    pub fn free(&mut self, len_with_padding: u32) {
1421        self.tail = self.tail.wrapping_add(len_with_padding);
1422    }
1423
1424    pub fn as_slice(&self) -> &[AtomicU8] {
1425        self.mem.as_slice()
1426    }
1427}
1428
1429impl Inspect for ContiguousBufferManager {
1430    fn inspect(&self, req: inspect::Request<'_>) {
1431        req.respond()
1432            .counter("split_headers", self.split_headers)
1433            .counter("failed_allocations", self.failed_allocations);
1434    }
1435}
1436
1437#[cfg(test)]
1438mod tests {
1439    use crate::GuestDmaMode;
1440    use crate::ManaEndpoint;
1441    use chipset_device::mmio::ExternallyManagedMmioIntercepts;
1442    use gdma::VportConfig;
1443    use gdma_defs::bnic::ManaQueryDeviceCfgResp;
1444    use mana_driver::mana::ManaDevice;
1445    use net_backend::Endpoint;
1446    use net_backend::QueueConfig;
1447    use net_backend::RxId;
1448    use net_backend::TxId;
1449    use net_backend::TxSegment;
1450    use net_backend::loopback::LoopbackEndpoint;
1451    use pal_async::DefaultDriver;
1452    use pal_async::async_test;
1453    use pci_core::msi::MsiInterruptSet;
1454    use std::future::poll_fn;
1455    use test_with_tracing::test;
1456    use user_driver_emulated_mock::DeviceTestMemory;
1457    use user_driver_emulated_mock::EmulatedDevice;
1458    use vmcore::vm_task::SingleDriverBackend;
1459    use vmcore::vm_task::VmTaskDriverSource;
1460
1461    /// Constructs a mana emulator backed by the loopback endpoint, then hooks a
1462    /// mana driver up to it, puts the net_mana endpoint on top of that, and
1463    /// ensures that packets can be sent and received.
1464    #[async_test]
1465    async fn test_endpoint_direct_dma(driver: DefaultDriver) {
1466        test_endpoint(driver, GuestDmaMode::DirectDma, 1138, 1).await;
1467    }
1468
1469    #[async_test]
1470    async fn test_endpoint_bounce_buffer(driver: DefaultDriver) {
1471        test_endpoint(driver, GuestDmaMode::BounceBuffer, 1138, 1).await;
1472    }
1473
1474    #[async_test]
1475    async fn test_segment_coalescing(driver: DefaultDriver) {
1476        // 34 segments of 60 bytes each == 2040
1477        test_endpoint(driver, GuestDmaMode::DirectDma, 2040, 34).await;
1478    }
1479
1480    #[async_test]
1481    async fn test_segment_coalescing_many(driver: DefaultDriver) {
1482        // 128 segments of 16 bytes each == 2048
1483        test_endpoint(driver, GuestDmaMode::DirectDma, 2048, 128).await;
1484    }
1485
1486    async fn test_endpoint(
1487        driver: DefaultDriver,
1488        dma_mode: GuestDmaMode,
1489        packet_len: usize,
1490        num_segments: usize,
1491    ) {
1492        let pages = 256; // 1MB
1493        let allow_dma = dma_mode == GuestDmaMode::DirectDma;
1494        let mem: DeviceTestMemory = DeviceTestMemory::new(pages * 2, allow_dma, "test_endpoint");
1495        let payload_mem = mem.payload_mem();
1496
1497        let mut msi_set = MsiInterruptSet::new();
1498        let device = gdma::GdmaDevice::new(
1499            &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())),
1500            mem.guest_memory(),
1501            &mut msi_set,
1502            vec![VportConfig {
1503                mac_address: [1, 2, 3, 4, 5, 6].into(),
1504                endpoint: Box::new(LoopbackEndpoint::new()),
1505            }],
1506            &mut ExternallyManagedMmioIntercepts,
1507        );
1508        let device = EmulatedDevice::new(device, msi_set, mem.dma_client());
1509        let dev_config = ManaQueryDeviceCfgResp {
1510            pf_cap_flags1: 0.into(),
1511            pf_cap_flags2: 0,
1512            pf_cap_flags3: 0,
1513            pf_cap_flags4: 0,
1514            max_num_vports: 1,
1515            reserved: 0,
1516            max_num_eqs: 64,
1517        };
1518        let mana_device = ManaDevice::new(&driver, device, 1, 1).await.unwrap();
1519        let vport = mana_device.new_vport(0, None, &dev_config).await.unwrap();
1520        let mut endpoint = ManaEndpoint::new(driver.clone(), vport, dma_mode).await;
1521        let mut queues = Vec::new();
1522        let pool = net_backend::tests::Bufs::new(payload_mem.clone());
1523        endpoint
1524            .get_queues(
1525                vec![QueueConfig {
1526                    pool: Box::new(pool),
1527                    initial_rx: &(1..128).map(RxId).collect::<Vec<_>>(),
1528                    driver: Box::new(driver.clone()),
1529                }],
1530                None,
1531                &mut queues,
1532            )
1533            .await
1534            .unwrap();
1535
1536        for i in 0..1000 {
1537            let sent_data = (0..packet_len).map(|v| (i + v) as u8).collect::<Vec<u8>>();
1538            payload_mem.write_at(0, &sent_data).unwrap();
1539
1540            let mut segments = Vec::new();
1541            let segment_len = packet_len / num_segments;
1542            assert!(packet_len % num_segments == 0);
1543            assert!(sent_data.len() == packet_len);
1544            segments.push(TxSegment {
1545                ty: net_backend::TxSegmentType::Head(net_backend::TxMetadata {
1546                    id: TxId(1),
1547                    segment_count: num_segments,
1548                    len: sent_data.len(),
1549                    ..Default::default()
1550                }),
1551                gpa: 0,
1552                len: segment_len as u32,
1553            });
1554
1555            for j in 0..(num_segments - 1) {
1556                let gpa = (j + 1) * segment_len;
1557                segments.push(TxSegment {
1558                    ty: net_backend::TxSegmentType::Tail,
1559                    gpa: gpa as u64,
1560                    len: segment_len as u32,
1561                });
1562            }
1563            assert!(segments.len() == num_segments);
1564
1565            queues[0].tx_avail(segments.as_slice()).unwrap();
1566
1567            let mut packets = [RxId(0); 2];
1568            let mut done = [TxId(0); 2];
1569            let mut done_n = 0;
1570            let mut packets_n = 0;
1571            while done_n == 0 || packets_n == 0 {
1572                poll_fn(|cx| queues[0].poll_ready(cx)).await;
1573                packets_n += queues[0].rx_poll(&mut packets[packets_n..]).unwrap();
1574                done_n += queues[0].tx_poll(&mut done[done_n..]).unwrap();
1575            }
1576            assert_eq!(packets_n, 1);
1577            let rx_id = packets[0];
1578
1579            let mut received_data = vec![0; packet_len];
1580            payload_mem
1581                .read_at(2048 * rx_id.0 as u64, &mut received_data)
1582                .unwrap();
1583            assert!(received_data.len() == packet_len);
1584            assert_eq!(&received_data[..], sent_data, "{i} {:?}", rx_id);
1585            assert_eq!(done_n, 1);
1586            assert_eq!(done[0].0, 1);
1587            queues[0].rx_avail(&[rx_id]);
1588        }
1589
1590        drop(queues);
1591        endpoint.stop().await;
1592    }
1593
1594    #[async_test]
1595    async fn test_vport_with_query_filter_state(driver: DefaultDriver) {
1596        let pages = 512; // 2MB
1597        let mem = DeviceTestMemory::new(pages, false, "test_vport_with_query_filter_state");
1598        let mut msi_set = MsiInterruptSet::new();
1599        let device = gdma::GdmaDevice::new(
1600            &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())),
1601            mem.guest_memory(),
1602            &mut msi_set,
1603            vec![VportConfig {
1604                mac_address: [1, 2, 3, 4, 5, 6].into(),
1605                endpoint: Box::new(LoopbackEndpoint::new()),
1606            }],
1607            &mut ExternallyManagedMmioIntercepts,
1608        );
1609        let dma_client = mem.dma_client();
1610        let device = EmulatedDevice::new(device, msi_set, dma_client);
1611        let cap_flags1 = gdma_defs::bnic::BasicNicDriverFlags::new().with_query_filter_state(1);
1612        let dev_config = ManaQueryDeviceCfgResp {
1613            pf_cap_flags1: cap_flags1,
1614            pf_cap_flags2: 0,
1615            pf_cap_flags3: 0,
1616            pf_cap_flags4: 0,
1617            max_num_vports: 1,
1618            reserved: 0,
1619            max_num_eqs: 64,
1620        };
1621        let mana_device = ManaDevice::new(&driver, device, 1, 1).await.unwrap();
1622        let _ = mana_device.new_vport(0, None, &dev_config).await.unwrap();
1623    }
1624}