virtio_net/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod buffers;
8pub mod resolver;
9
10// use anyhow::Context;
11use crate::buffers::VirtioWorkPool;
12use bitfield_struct::bitfield;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures_concurrency::future::Race;
16use guestmem::GuestMemory;
17use inspect::Inspect;
18use inspect::InspectMut;
19use inspect_counters::Counter;
20use inspect_counters::Histogram;
21use net_backend::Endpoint;
22use net_backend::EndpointAction;
23use net_backend::QueueConfig;
24use net_backend::RxId;
25use net_backend::TxId;
26use net_backend::TxMetadata;
27use net_backend::TxSegment;
28use net_backend::TxSegmentType;
29use net_backend_resources::mac_address::MacAddress;
30use pal_async::wait::PolledWait;
31use std::future::pending;
32use std::mem::offset_of;
33use std::sync::Arc;
34use std::task::Poll;
35use task_control::AsyncRun;
36use task_control::InspectTaskMut;
37use task_control::StopTask;
38use task_control::TaskControl;
39use thiserror::Error;
40use virtio::DeviceTraits;
41use virtio::DeviceTraitsSharedMemory;
42use virtio::Resources;
43use virtio::VirtioDevice;
44use virtio::VirtioQueue;
45use virtio::VirtioQueueCallbackWork;
46use vmcore::vm_task::VmTaskDriver;
47use vmcore::vm_task::VmTaskDriverSource;
48use zerocopy::FromBytes;
49use zerocopy::Immutable;
50use zerocopy::IntoBytes;
51use zerocopy::KnownLayout;
52
/// Feature bits of the virtio network device.
///
/// These correspond to VIRTIO_NET_F_ flags. The `bitfield` macro assigns bits
/// in declaration order starting at bit 0, so the `_reserved*` fields only
/// exist to skip over bit positions that have no named flag here.
#[bitfield(u64)]
#[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
struct NetworkFeatures {
    pub csum: bool,
    pub guest_csum: bool,
    pub ctrl_guest_offloads: bool,
    pub mtu: bool,
    _reserved: bool,
    pub mac: bool,
    _reserved2: bool,
    pub guest_tso4: bool,
    pub guest_tso6: bool,
    pub guest_ecn: bool,
    pub guest_ufo: bool,
    pub host_tso4: bool,
    pub host_tso6: bool,
    pub host_ecn: bool,
    pub host_ufo: bool,
    pub mrg_rxbuf: bool,
    pub status: bool,
    pub ctrl_vq: bool,
    pub ctrl_rx: bool,
    pub ctrl_vlan: bool,
    _reserved3: bool,
    pub guest_announce: bool,
    pub mq: bool,
    pub ctrl_mac_addr: bool,
    // 24 single-bit fields precede this one; the 29-bit gap places
    // `notf_coal` at bit 53.
    #[bits(29)]
    _reserved4: u64,
    pub notf_coal: bool,
    pub guest_uso4: bool,
    pub guest_uso6: bool,
    pub host_uso: bool,
    pub hash_report: bool,
    _reserved5: bool,
    pub guest_hdrlen: bool,
    pub rss: bool,
    pub rsc_ext: bool,
    pub standby: bool,
    pub speed_duplex: bool,
}
95
/// Contents of the 16-bit `status` configuration register.
///
/// These correspond to VIRTIO_NET_S_ flags.
#[bitfield(u16)]
#[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
struct NetStatus {
    pub link_up: bool,
    pub announce: bool,
    // Remaining bits have no named flag here.
    #[bits(14)]
    _reserved: u16,
}
105
// Value reported in the `mtu` configuration register by `NicBuilder::build`.
// NOTE(review): 1514 looks like a 1500-byte payload plus a 14-byte Ethernet
// header - confirm this is the intended meaning of the register.
const DEFAULT_MTU: u16 = 1514;

// Upper bound on queue pairs; currently only referenced from the
// commented-out multiqueue logic in `NicBuilder::build`, hence `dead_code`.
#[expect(dead_code)]
const VIRTIO_NET_MAX_QUEUES: u16 = 0x8000;
110
/// Device-specific configuration registers.
///
/// `Device::read_registers_u32` serves guest reads of this structure, one
/// 32-bit word at a time, and `Device::traits` reports its size as the
/// register length. The field order therefore defines the register layout.
#[repr(C)]
struct NetConfig {
    pub mac: [u8; 6],
    pub status: u16,
    pub max_virtqueue_pairs: u16,
    pub mtu: u16,
    pub speed: u32,                            // MBit/s; 0xffffffff - unknown speed
    pub duplex: u8,                            // 0 - half, 1 - full, 0xff - unknown
    pub rss_max_key_size: u8,                  // VIRTIO_NET_F_RSS or VIRTIO_NET_F_HASH_REPORT
    pub rss_max_indirection_table_length: u16, // VIRTIO_NET_F_RSS
    pub supported_hash_types: u32,             // VIRTIO_NET_F_RSS or VIRTIO_NET_F_HASH_REPORT
}
123
/// Bit layout of the one-byte `flags` field of `VirtioNetHeader`.
///
/// These correspond to VIRTIO_NET_HDR_F_ flags.
#[bitfield(u8)]
#[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
struct VirtioNetHeaderFlags {
    pub needs_csum: bool,
    pub data_valid: bool,
    pub rsc_info: bool,
    // Remaining bits have no named flag here.
    #[bits(5)]
    _reserved: u8,
}
134
/// Bit layout of the one-byte `gso_type` field of `VirtioNetHeader`: a
/// three-bit protocol selector in the low bits and an ECN flag in the top bit.
#[bitfield(u8)]
#[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
struct VirtioNetHeaderGso {
    // Converted to/from raw bits through
    // `VirtioNetHeaderGsoProtocol::{from_bits, into_bits}`.
    #[bits(3)]
    pub protocol: VirtioNetHeaderGsoProtocol,
    #[bits(4)]
    _reserved: u8,
    pub ecn: bool,
}
144
// Segmentation-offload protocol selector stored in `VirtioNetHeaderGso`.
// Declared as an open enum, so values outside the listed constants are still
// representable.
//
// These correspond to VIRTIO_NET_HDR_GSO_ values.
open_enum::open_enum! {
    #[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
    enum VirtioNetHeaderGsoProtocol: u8 {
        NONE = 0,
        TCPV4 = 1,
        UDP = 3,
        TCPV6 = 4,
        UDP_L4 = 5,
    }
}
156
157impl VirtioNetHeaderGsoProtocol {
158    const fn from_bits(bits: u8) -> Self {
159        Self(bits)
160    }
161
162    const fn into_bits(self) -> u8 {
163        self.0
164    }
165}
166
/// Per-packet header that precedes the frame data in a virtio-net buffer.
///
/// `header_size` only counts the fields before `hash_value`; the trailing
/// three fields are present only when hash reporting has been negotiated.
#[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
#[repr(C)]
struct VirtioNetHeader {
    pub flags: u8,
    pub gso_type: u8,
    pub hdr_len: u16,
    pub gso_size: u16,
    pub csum_start: u16,
    pub csum_offset: u16,
    pub num_buffers: u16,
    pub hash_value: u32,       // Only if VIRTIO_NET_F_HASH_REPORT negotiated
    pub hash_report: u16,      // Only if VIRTIO_NET_F_HASH_REPORT negotiated
    pub padding_reserved: u16, // Only if VIRTIO_NET_F_HASH_REPORT negotiated
}
181
182fn header_size() -> usize {
183    // TODO: Verify hash flags are not set, since header size would be larger in that case.
184    offset_of!(VirtioNetHeader, hash_value)
185}
186
/// Immutable per-device settings shared (via `Arc`) between the `Device` and
/// its coordinator state.
struct Adapter {
    // Driver used for queue events and for running the coordinator task.
    driver: VmTaskDriver,
    // Number of worker slots the coordinator allocates.
    max_queues: u16,
    // Captured from `Endpoint::tx_fast_completions` at build time; controls
    // whether workers are pinned to their target processor.
    tx_fast_completions: bool,
    mac_address: MacAddress,
}
193
/// A virtio network device backed by a `net_backend::Endpoint`.
///
/// Constructed through `Device::builder()` / `NicBuilder::build`.
pub struct Device {
    // Configuration registers served by `read_registers_u32`.
    registers: NetConfig,
    memory: GuestMemory,
    // Task that owns the endpoint and supervises the per-queue workers.
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    // Present only between `enable` and `disable`.
    coordinator_send: Option<mesh::Sender<CoordinatorMessage>>,
    adapter: Arc<Adapter>,
    // Used to build a dedicated driver for each worker.
    driver_source: VmTaskDriverSource,
}
202
impl Drop for Device {
    // Intentionally empty: no explicit teardown is performed here, the fields
    // are simply dropped afterwards.
    fn drop(&mut self) {}
}
206
207impl VirtioDevice for Device {
208    fn traits(&self) -> DeviceTraits {
209        // TODO: Add network features based on endpoint capabilities (NetworkFeatures::VIRTIO_NET_F_*)
210        DeviceTraits {
211            device_id: 1,
212            device_features: NetworkFeatures::new().with_mac(true).into(),
213            max_queues: 2 * self.registers.max_virtqueue_pairs,
214            device_register_length: size_of::<NetConfig>() as u32,
215            shared_memory: DeviceTraitsSharedMemory { id: 0, size: 0 },
216        }
217    }
218
219    fn read_registers_u32(&self, offset: u16) -> u32 {
220        match offset {
221            0 => u32::from_le_bytes(self.registers.mac[..4].try_into().unwrap()),
222            4 => {
223                (u16::from_le_bytes(self.registers.mac[4..].try_into().unwrap()) as u32)
224                    | ((self.registers.status as u32) << 16)
225            }
226            8 => (self.registers.max_virtqueue_pairs as u32) | ((self.registers.mtu as u32) << 16),
227            12 => self.registers.speed,
228            16 => {
229                (self.registers.duplex as u32)
230                    | ((self.registers.rss_max_key_size as u32) << 8)
231                    | ((self.registers.rss_max_indirection_table_length as u32) << 24)
232            }
233            20 => self.registers.supported_hash_types,
234            _ => 0,
235        }
236    }
237
    // Writes to the configuration registers are ignored.
    fn write_registers_u32(&mut self, _offset: u16, _val: u32) {}
239
240    fn enable(&mut self, resources: Resources) {
241        let mut queue_resources: Vec<_> = resources.queues.into_iter().collect();
242        let mut workers = Vec::with_capacity(queue_resources.len() / 2);
243        while queue_resources.len() > 1 {
244            let mut next = queue_resources.drain(..2);
245            let rx_resources = next.next().unwrap();
246            let tx_resources = next.next().unwrap();
247            if !rx_resources.params.enable || !tx_resources.params.enable {
248                continue;
249            }
250
251            let rx_queue_size = rx_resources.params.size;
252            let rx_queue_event = PolledWait::new(&self.adapter.driver, rx_resources.event);
253            if let Err(err) = rx_queue_event {
254                tracing::error!(
255                    err = &err as &dyn std::error::Error,
256                    "Failed creating queue event"
257                );
258                continue;
259            }
260            let rx_queue = VirtioQueue::new(
261                resources.features,
262                rx_resources.params,
263                self.memory.clone(),
264                rx_resources.notify,
265                rx_queue_event.unwrap(),
266            );
267            if let Err(err) = rx_queue {
268                tracing::error!(
269                    err = &err as &dyn std::error::Error,
270                    "Failed creating virtio net receive queue"
271                );
272                continue;
273            }
274
275            let tx_queue_size = tx_resources.params.size;
276            let tx_queue_event = PolledWait::new(&self.adapter.driver, tx_resources.event);
277            if let Err(err) = tx_queue_event {
278                tracing::error!(
279                    err = &err as &dyn std::error::Error,
280                    "Failed creating queue event"
281                );
282                continue;
283            }
284            let tx_queue = VirtioQueue::new(
285                resources.features,
286                tx_resources.params,
287                self.memory.clone(),
288                tx_resources.notify,
289                tx_queue_event.unwrap(),
290            );
291            if let Err(err) = tx_queue {
292                tracing::error!(
293                    err = &err as &dyn std::error::Error,
294                    "Failed creating virtio net transmit queue"
295                );
296                continue;
297            }
298            workers.push(VirtioState {
299                rx_queue: rx_queue.unwrap(),
300                rx_queue_size,
301                tx_queue: tx_queue.unwrap(),
302                tx_queue_size,
303            });
304        }
305
306        let (tx, rx) = mesh::channel();
307        self.coordinator_send = Some(tx);
308        self.insert_coordinator(rx, workers.len() as u16);
309        for (i, virtio_state) in workers.into_iter().enumerate() {
310            self.insert_worker(virtio_state, i);
311        }
312        self.coordinator.start();
313    }
314
315    fn disable(&mut self) {
316        if let Some(send) = self.coordinator_send.take() {
317            send.send(CoordinatorMessage::Disable);
318        }
319    }
320}
321
/// A backend queue obtained from the endpoint by
/// `Coordinator::restart_queues`.
struct EndpointQueueState {
    queue: Box<dyn net_backend::Queue>,
}
325
/// Task state for one queue-pair worker.
struct NetQueue {
    // `None` until the coordinator has (re)created the backend queues; the
    // worker idles while this is `None`.
    state: Option<EndpointQueueState>,
}
329
330impl InspectTaskMut<Worker> for NetQueue {
331    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker>) {
332        if worker.is_none() && self.state.is_none() {
333            req.ignore();
334            return;
335        }
336        let mut resp = req.respond();
337        if let Some(worker) = worker {
338            resp.field(
339                "pending_tx_packets",
340                worker
341                    .active_state
342                    .pending_tx_packets
343                    .iter()
344                    .fold(0, |acc, next| acc + if next.is_some() { 1 } else { 0 }),
345            )
346            .field(
347                "pending_rx_packets",
348                worker.active_state.pending_rx_packets.ready().len(),
349            )
350            .field(
351                "pending_tx",
352                !worker.active_state.data.tx_segments.is_empty(),
353            )
354            .merge(&worker.active_state.stats);
355        }
356
357        if let Some(epqueue_state) = &mut self.state {
358            resp.field_mut("queue", &mut epqueue_state.queue);
359        }
360    }
361}
362
/// Buffers used during packet processing.
struct ProcessingData {
    // Segments of the transmit packet that has been queued but whose
    // transmission is not yet known to be complete.
    tx_segments: Vec<TxSegment>,
    // Scratch space for `tx_poll` results, sized to the transmit queue.
    tx_done: Box<[TxId]>,
    // Scratch space for `rx_poll` results, sized to the receive queue.
    rx_ready: Box<[RxId]>,
}
369
370impl ProcessingData {
371    fn new(rx_queue_size: u16, tx_queue_size: u16) -> Self {
372        Self {
373            tx_segments: Vec::new(),
374            tx_done: vec![TxId(0); tx_queue_size as usize].into(),
375            rx_ready: vec![RxId(0); rx_queue_size as usize].into(),
376        }
377    }
378}
379
/// Per-worker counters exposed through inspect.
#[derive(Inspect, Default)]
struct QueueStats {
    // Incremented when a transmit was attempted but no segments were pending.
    tx_stalled: Counter,
    // Worker loop iterations in which no rx/tx work was found.
    spurious_wakes: Counter,
    // Receive packets completed back to the guest.
    rx_packets: Counter,
    // Transmit packets completed back to the guest.
    tx_packets: Counter,
    // Number of completions returned by each non-empty `tx_poll`.
    tx_packets_per_wake: Histogram<10>,
    // Number of packets returned by each non-empty `rx_poll`.
    rx_packets_per_wake: Histogram<10>,
}
389
/// Mutable per-worker packet tracking state.
struct ActiveState {
    // In-flight transmit work, indexed by descriptor index (which doubles as
    // the `TxId` handed to the backend).
    pending_tx_packets: Vec<Option<PendingTxPacket>>,
    // Receive buffers posted by the guest and shared with the backend.
    pending_rx_packets: VirtioWorkPool,
    data: ProcessingData,
    stats: QueueStats,
}
396
397impl ActiveState {
398    fn new(mem: GuestMemory, rx_queue_size: u16, tx_queue_size: u16) -> Self {
399        Self {
400            pending_tx_packets: (0..tx_queue_size).map(|_| None).collect(),
401            pending_rx_packets: VirtioWorkPool::new(mem, rx_queue_size),
402            data: ProcessingData::new(rx_queue_size, tx_queue_size),
403            stats: Default::default(),
404        }
405    }
406}
407
/// The state for a tx packet that's currently pending in the backend endpoint.
struct PendingTxPacket {
    // The guest's descriptor chain; completed once the transmit is done.
    work: VirtioQueueCallbackWork,
}
412
/// Builder for a virtio-net `Device`, obtained from `Device::builder()`.
pub struct NicBuilder {
    // Requested queue limit. Currently not consulted by `build`, which always
    // uses a single queue pair.
    max_queues: u16,
}
416
417impl NicBuilder {
418    pub fn max_queues(mut self, max_queues: u16) -> Self {
419        self.max_queues = max_queues;
420        self
421    }
422
423    /// Creates a new NIC.
424    pub fn build(
425        self,
426        driver_source: &VmTaskDriverSource,
427        memory: GuestMemory,
428        endpoint: Box<dyn Endpoint>,
429        mac_address: MacAddress,
430    ) -> Device {
431        // TODO: Implement VIRTIO_NET_F_MQ and VIRTIO_NET_F_RSS logic based on mulitqueue support.
432        // let multiqueue = endpoint.multiqueue_support();
433        // let max_queues = self.max_queues.clamp(1, multiqueue.max_queues.min(VIRTIO_NET_MAX_QUEUES));
434        let max_queues = 1;
435
436        let driver = driver_source.simple();
437        let adapter = Arc::new(Adapter {
438            driver,
439            max_queues,
440            tx_fast_completions: endpoint.tx_fast_completions(),
441            mac_address,
442        });
443
444        let coordinator = TaskControl::new(CoordinatorState {
445            endpoint,
446            adapter: adapter.clone(),
447        });
448
449        let registers = NetConfig {
450            mac: mac_address.to_bytes(),
451            status: NetStatus::new().with_link_up(true).into(),
452            max_virtqueue_pairs: max_queues,
453            mtu: DEFAULT_MTU,
454            speed: 0xffffffff,
455            duplex: 0xff,
456            rss_max_key_size: 0,
457            rss_max_indirection_table_length: 0,
458            supported_hash_types: 0,
459        };
460
461        Device {
462            registers,
463            memory,
464            coordinator,
465            coordinator_send: None,
466            adapter,
467            driver_source: driver_source.clone(),
468        }
469    }
470}
471
472impl Device {
473    pub fn builder() -> NicBuilder {
474        NicBuilder { max_queues: !0 }
475    }
476}
477
impl InspectMut for Device {
    // All diagnostics live under the coordinator task, so inspection is
    // delegated to it wholesale.
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        self.coordinator.inspect_mut(req);
    }
}
483
484impl Device {
485    fn insert_coordinator(&mut self, recv: mesh::Receiver<CoordinatorMessage>, num_queues: u16) {
486        self.coordinator.insert(
487            &self.adapter.driver,
488            "virtio-net-coordinator".to_string(),
489            Coordinator {
490                recv,
491                workers: (0..self.adapter.max_queues)
492                    .map(|_| TaskControl::new(NetQueue { state: None }))
493                    .collect(),
494                num_queues,
495                restart: true,
496            },
497        );
498    }
499
500    /// Allocates and inserts a worker.
501    ///
502    /// The coordinator must be stopped.
503    fn insert_worker(&mut self, virtio_state: VirtioState, idx: usize) {
504        let mut builder = self.driver_source.builder();
505        // TODO: set this correctly
506        builder.target_vp(0);
507        // If tx completions arrive quickly, then just do tx processing
508        // on whatever processor the guest happens to signal from.
509        // Subsequent transmits will be pulled from the completion
510        // processor.
511        builder.run_on_target(!self.adapter.tx_fast_completions);
512        let driver = builder.build("virtio-net");
513
514        let active_state = ActiveState::new(
515            self.memory.clone(),
516            virtio_state.rx_queue_size,
517            virtio_state.tx_queue_size,
518        );
519        let worker = Worker {
520            virtio_state,
521            active_state,
522        };
523        let coordinator = self.coordinator.state_mut().unwrap();
524        let worker_task = &mut coordinator.workers[idx];
525        worker_task.insert(&driver, "virtio-net".to_string(), worker);
526        worker_task.start();
527    }
528}
529
/// Messages sent from the `Device` to its coordinator task.
#[derive(PartialEq)]
enum CoordinatorMessage {
    // Sent by `Device::disable`; the coordinator stops its workers and exits.
    Disable,
}
534
/// Task that supervises the per-queue workers and reacts to device and
/// endpoint events.
struct Coordinator {
    recv: mesh::Receiver<CoordinatorMessage>,
    // One slot per `adapter.max_queues`.
    workers: Vec<TaskControl<NetQueue, Worker>>,
    // Number of leading `workers` entries reported through inspect.
    num_queues: u16,
    // When set, the backend queues are recreated on the next loop iteration.
    restart: bool,
}
541
/// State owned by the coordinator task for the lifetime of the device.
struct CoordinatorState {
    // Network backend from which the per-worker queues are obtained.
    endpoint: Box<dyn Endpoint>,
    adapter: Arc<Adapter>,
}
546
547impl InspectTaskMut<Coordinator> for CoordinatorState {
548    fn inspect_mut(&mut self, req: inspect::Request<'_>, coordinator: Option<&mut Coordinator>) {
549        let mut resp = req.respond();
550
551        let adapter = self.adapter.as_ref();
552        resp.field("mac_address", adapter.mac_address)
553            .field("max_queues", adapter.max_queues);
554
555        resp.field("endpoint_type", self.endpoint.endpoint_type())
556            .field(
557                "endpoint_max_queues",
558                self.endpoint.multiqueue_support().max_queues,
559            )
560            .field_mut("endpoint", self.endpoint.as_mut());
561
562        if let Some(coordinator) = coordinator {
563            resp.fields_mut(
564                "queues",
565                coordinator.workers[..coordinator.num_queues as usize]
566                    .iter_mut()
567                    .enumerate(),
568            );
569        }
570    }
571}
572
impl AsyncRun<Coordinator> for CoordinatorState {
    // Task entry point: hands control to `Coordinator::process`, passing
    // this state along so it can reach the endpoint and adapter.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        coordinator.process(stop, self).await
    }
}
582
583impl Coordinator {
    /// Coordinator main loop.
    ///
    /// Each iteration optionally recreates the backend queues (when `restart`
    /// is set), starts the workers, and then waits for either a message from
    /// the device or an action from the endpoint. Returns `Ok(())` after a
    /// `Disable` message or when the message channel is closed, and
    /// `Err(Cancelled)` if the task is stopped while waiting.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        loop {
            if self.restart {
                stop.until_stopped(self.stop_workers()).await?;
                // The queue restart operation is not restartable, so do not
                // poll on `stop` here.
                if let Err(err) = self.restart_queues(state).await {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                // Cleared even on failure, so a failed restart is not retried
                // until the endpoint requests another one.
                self.restart = false;
            }
            self.start_workers();
            // The events the coordinator can wake up for.
            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
            }
            let message = {
                let wait_for_message = async {
                    // A closed channel (`None`) is reported as a disconnect.
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    // Whichever source produces a message first wins.
                    (internal_msg, endpoint_restart).race().await
                };
                stop.until_stopped(wait_for_message).await?
            };
            match message {
                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(_)) => {
                    tracing::error!("unexpected link status notification")
                }
                Message::Internal(CoordinatorMessage::Disable) | Message::ChannelDisconnected => {
                    stop.until_stopped(self.stop_workers()).await?;
                    break;
                }
            };
        }
        Ok(())
    }
635
636    async fn stop_workers(&mut self) {
637        for worker in &mut self.workers {
638            worker.stop().await;
639        }
640    }
641
642    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
643        // Drop all of the current queues.
644        for worker in &mut self.workers {
645            worker.task_mut().state = None;
646        }
647
648        let (rx_pools, ready_packets): (Vec<_>, Vec<_>) = self
649            .workers
650            .iter()
651            .map(|worker| {
652                let pool = worker
653                    .state()
654                    .unwrap()
655                    .active_state
656                    .pending_rx_packets
657                    .clone();
658                let ready = pool.ready();
659                (pool, ready)
660            })
661            .collect::<Vec<_>>()
662            .into_iter()
663            .unzip();
664        let mut queue_config = Vec::with_capacity(rx_pools.len());
665        for (i, pool) in rx_pools.into_iter().enumerate() {
666            queue_config.push(QueueConfig {
667                pool: Box::new(pool),
668                initial_rx: ready_packets[i].as_slice(),
669                driver: Box::new(c_state.adapter.driver.clone()),
670            });
671        }
672
673        let mut queues = Vec::new();
674        c_state
675            .endpoint
676            .get_queues(queue_config, None, &mut queues)
677            .await
678            .map_err(WorkerError::Endpoint)?;
679
680        assert_eq!(queues.len(), self.workers.len());
681
682        for (worker, queue) in self.workers.iter_mut().zip(queues) {
683            worker.task_mut().state = Some(EndpointQueueState { queue });
684        }
685
686        Ok(())
687    }
688
689    fn start_workers(&mut self) {
690        for worker in &mut self.workers {
691            worker.start();
692        }
693    }
694}
695
696impl AsyncRun<Worker> for NetQueue {
697    async fn run(
698        &mut self,
699        stop: &mut StopTask<'_>,
700        worker: &mut Worker,
701    ) -> Result<(), task_control::Cancelled> {
702        match worker.process(stop, self).await {
703            Ok(()) => {}
704            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
705            Err(err) => {
706                tracing::error!(err = &err as &dyn std::error::Error, "virtio net error");
707            }
708        }
709        Ok(())
710    }
711}
712
/// The guest-facing virtio queues of one queue pair, plus their sizes as
/// reported by the queue parameters at enable time.
struct VirtioState {
    rx_queue: VirtioQueue,
    rx_queue_size: u16,
    tx_queue: VirtioQueue,
    tx_queue_size: u16,
}
719
/// Errors that terminate a worker's processing loop.
#[derive(Debug, Error)]
enum WorkerError {
    // A guest packet was malformed.
    #[error("packet error")]
    Packet(#[from] PacketError),
    // Fetching work from a virtio queue failed.
    #[error("virtio queue processing error")]
    VirtioQueue(#[source] std::io::Error),
    // The backend endpoint or one of its queues reported an error.
    #[error("endpoint")]
    Endpoint(#[source] anyhow::Error),
    // The task was stopped; `NetQueue::run` passes this on instead of
    // logging it.
    #[error("cancelled")]
    Cancelled(task_control::Cancelled),
}
731
732impl From<task_control::Cancelled> for WorkerError {
733    fn from(value: task_control::Cancelled) -> Self {
734        Self::Cancelled(value)
735    }
736}
737
/// Problems with a packet supplied by the guest.
#[derive(Debug, Error)]
enum PacketError {
    // The transmit descriptor chain held no data beyond the virtio-net header.
    #[error("Empty packet")]
    Empty,
}
743
/// Processes one receive/transmit queue pair, moving packets between the
/// guest's virtio queues and a backend queue.
struct Worker {
    virtio_state: VirtioState,
    active_state: ActiveState,
}
748
749impl Worker {
750    async fn process(
751        &mut self,
752        stop: &mut StopTask<'_>,
753        queue: &mut NetQueue,
754    ) -> Result<(), WorkerError> {
755        // Be careful not to wait on actions with unbounded blocking time (e.g.
756        // guest actions, or waiting for network packets to arrive) without
757        // wrapping the wait on `stop.until_stopped`.
758        if queue.state.is_none() {
759            // wait for an active queue
760            stop.until_stopped(pending()).await?
761        }
762
763        self.main_loop(stop, queue).await?;
764        Ok(())
765    }
766
    /// Packet processing loop for one queue pair.
    ///
    /// Each outer iteration drains whatever the backend and the guest receive
    /// queue have ready, then sleeps until the backend signals readiness.
    /// While sleeping, newly posted guest buffers (receive) and newly queued
    /// guest packets (transmit) are handled inline. Only returns on error or
    /// cancellation.
    ///
    /// Panics if `queue.state` is `None`.
    async fn main_loop(
        &mut self,
        stop: &mut StopTask<'_>,
        queue: &mut NetQueue,
    ) -> Result<(), WorkerError> {
        let epqueue_state = queue.state.as_mut().unwrap();

        loop {
            // Non-short-circuiting `|` so that all three steps run on every
            // iteration regardless of the earlier results.
            let did_some_work = self.process_endpoint_rx(epqueue_state.queue.as_mut())?
                | self.process_virtio_rx(epqueue_state.queue.as_mut())?
                | self.process_endpoint_tx(epqueue_state.queue.as_mut())?;

            if !did_some_work {
                self.active_state.stats.spurious_wakes.increment();
            }

            // This should be the only await point waiting on network traffic or
            // guest actions. Wrap it in `stop.until_stopped` to allow
            // cancellation.
            stop.until_stopped(async {
                enum WakeReason {
                    PacketFromClient(Result<VirtioQueueCallbackWork, std::io::Error>),
                    PacketToClient(Result<VirtioQueueCallbackWork, std::io::Error>),
                    NetworkBackend,
                }
                loop {
                    let net_queue = std::future::poll_fn(|cx| -> Poll<()> {
                        // Check the network endpoint for tx completion or rx.
                        epqueue_state.queue.poll_ready(cx)
                    })
                    .map(|_| WakeReason::NetworkBackend);
                    let to_client = self.virtio_state.rx_queue.next().map(|work| {
                        WakeReason::PacketToClient(work.expect("queue never completes"))
                    });
                    // The guest's transmit queue is only polled while no
                    // transmit segments are pending.
                    let wake_reason = if self.active_state.data.tx_segments.is_empty() {
                        let from_client = self.virtio_state.tx_queue.next().map(|work| {
                            WakeReason::PacketFromClient(work.expect("queue never completes"))
                        });
                        (net_queue, from_client, to_client).race().await
                    } else {
                        (net_queue, to_client).race().await
                    };
                    match wake_reason {
                        WakeReason::NetworkBackend => {
                            // Leave the inner loop so the outer loop polls the
                            // backend for received packets and completions.
                            tracing::trace!("endpoint ready");
                            return Ok::<(), WorkerError>(());
                        }
                        WakeReason::PacketFromClient(work) => {
                            tracing::trace!("tx packet");
                            let work = work.map_err(WorkerError::VirtioQueue)?;
                            self.queue_tx_packet(work)?;
                            self.process_virtio_rx(epqueue_state.queue.as_mut())?;
                            if !self.transmit_pending_segments(epqueue_state)? {
                                self.active_state.stats.tx_stalled.increment();
                            }
                        }
                        WakeReason::PacketToClient(work) => {
                            // The guest posted a receive buffer; hand it to the
                            // backend.
                            tracing::trace!("rx packet");
                            let work = work.map_err(WorkerError::VirtioQueue)?;
                            epqueue_state
                                .queue
                                .rx_avail(&[self.active_state.pending_rx_packets.queue_work(work)]);
                        }
                    }
                }
            })
            // First `?` handles cancellation, second the worker error.
            .await??;
        }
    }
836
837    fn queue_tx_packet(&mut self, mut work: VirtioQueueCallbackWork) -> Result<(), WorkerError> {
838        let mut header_bytes_remaining = header_size() as u32;
839        let mut segments = work
840            .payload
841            .iter()
842            .filter_map(|p| {
843                if p.writeable {
844                    None
845                } else if header_bytes_remaining >= p.length {
846                    header_bytes_remaining -= p.length;
847                    None
848                } else if header_bytes_remaining > 0 {
849                    let segment = TxSegment {
850                        ty: TxSegmentType::Tail,
851                        gpa: p.address + header_bytes_remaining as u64,
852                        len: p.length - header_bytes_remaining,
853                    };
854                    header_bytes_remaining = 0;
855                    Some(segment)
856                } else {
857                    Some(TxSegment {
858                        ty: TxSegmentType::Tail,
859                        gpa: p.address,
860                        len: p.length,
861                    })
862                }
863            })
864            .collect::<Vec<_>>();
865        if segments.is_empty() {
866            work.complete(0);
867            return Err(WorkerError::Packet(PacketError::Empty));
868        }
869        let idx = work.descriptor_index();
870        segments[0].ty = TxSegmentType::Head(TxMetadata {
871            id: TxId(idx.into()),
872            segment_count: segments.len(),
873            len: work.get_payload_length(false) as usize - header_size(),
874            ..Default::default()
875        });
876        let state = &mut self.active_state;
877        state.data.tx_segments.append(&mut segments);
878        assert!(state.pending_tx_packets[idx as usize].is_none());
879        state.pending_tx_packets[idx as usize] = Some(PendingTxPacket { work });
880        Ok(())
881    }
882
883    fn process_virtio_rx(
884        &mut self,
885        epqueue: &mut dyn net_backend::Queue,
886    ) -> Result<bool, WorkerError> {
887        // Fill the receive queue with any available buffers.
888        let mut rx_ids = Vec::new();
889        while let Some(Some(work)) = self.virtio_state.rx_queue.next().now_or_never() {
890            tracing::trace!("rx packet");
891            let work = work.map_err(WorkerError::VirtioQueue)?;
892            rx_ids.push(self.active_state.pending_rx_packets.queue_work(work));
893        }
894        if !rx_ids.is_empty() {
895            epqueue.rx_avail(rx_ids.as_slice());
896            Ok(true)
897        } else {
898            Ok(false)
899        }
900    }
901
902    fn process_endpoint_rx(
903        &mut self,
904        epqueue: &mut dyn net_backend::Queue,
905    ) -> Result<bool, WorkerError> {
906        let state = &mut self.active_state;
907        let n = epqueue
908            .rx_poll(&mut state.data.rx_ready)
909            .map_err(WorkerError::Endpoint)?;
910        if n == 0 {
911            return Ok(false);
912        }
913
914        for ready_id in state.data.rx_ready[..n].iter() {
915            state.stats.rx_packets.increment();
916            state.pending_rx_packets.complete_packet(*ready_id);
917        }
918
919        state.stats.rx_packets_per_wake.add_sample(n as u64);
920        Ok(true)
921    }
922
923    fn process_endpoint_tx(
924        &mut self,
925        epqueue: &mut dyn net_backend::Queue,
926    ) -> Result<bool, WorkerError> {
927        // Drain completed transmits.
928        let n = epqueue
929            .tx_poll(&mut self.active_state.data.tx_done)
930            .map_err(|tx_error| WorkerError::Endpoint(tx_error.into()))?;
931        if n == 0 {
932            return Ok(false);
933        }
934
935        let pending_segment_id = if !self.active_state.data.tx_segments.is_empty() {
936            let TxSegmentType::Head(metadata) = &self.active_state.data.tx_segments[0].ty else {
937                unreachable!()
938            };
939            Some(metadata.id)
940        } else {
941            None
942        };
943        for i in 0..n {
944            let id = self.active_state.data.tx_done[i];
945            self.complete_tx_packet(id)?;
946            if let Some(pending_segment_id) = pending_segment_id {
947                if pending_segment_id.0 == id.0 {
948                    self.active_state.data.tx_segments.clear();
949                }
950            }
951        }
952        self.active_state
953            .stats
954            .tx_packets_per_wake
955            .add_sample(n as u64);
956
957        Ok(true)
958    }
959
960    fn transmit_pending_segments(
961        &mut self,
962        queue_state: &mut EndpointQueueState,
963    ) -> Result<bool, WorkerError> {
964        if self.active_state.data.tx_segments.is_empty() {
965            return Ok(false);
966        }
967        let TxSegmentType::Head(metadata) = &self.active_state.data.tx_segments[0].ty else {
968            unreachable!()
969        };
970        let id = metadata.id;
971        self.transmit_segments(queue_state, id)?;
972        Ok(true)
973    }
974
975    fn transmit_segments(
976        &mut self,
977        queue_state: &mut EndpointQueueState,
978        id: TxId,
979    ) -> Result<(), WorkerError> {
980        let (sync, segments_sent) = queue_state
981            .queue
982            .tx_avail(&self.active_state.data.tx_segments)
983            .map_err(WorkerError::Endpoint)?;
984
985        assert!(segments_sent <= self.active_state.data.tx_segments.len());
986
987        if sync && segments_sent == self.active_state.data.tx_segments.len() {
988            self.active_state.data.tx_segments.clear();
989            self.complete_tx_packet(id)?;
990        }
991        Ok(())
992    }
993
994    fn complete_tx_packet(&mut self, id: TxId) -> Result<(), WorkerError> {
995        let state = &mut self.active_state;
996        let mut tx_packet = state.pending_tx_packets[id.0 as usize].take().unwrap();
997        tx_packet.work.complete(0);
998        self.active_state.stats.tx_packets.increment();
999        Ok(())
1000    }
1001}