#![expect(missing_docs)]
#![forbid(unsafe_code)]

mod buffers;
pub mod resolver;

use crate::buffers::VirtioWorkPool;
use bitfield_struct::bitfield;
use futures::FutureExt;
use futures::StreamExt;
use futures_concurrency::future::Race;
use guestmem::GuestMemory;
use inspect::Inspect;
use inspect::InspectMut;
use inspect_counters::Counter;
use inspect_counters::Histogram;
use net_backend::Endpoint;
use net_backend::EndpointAction;
use net_backend::QueueConfig;
use net_backend::RxId;
use net_backend::TxId;
use net_backend::TxMetadata;
use net_backend::TxSegment;
use net_backend::TxSegmentType;
use net_backend_resources::mac_address::MacAddress;
use pal_async::wait::PolledWait;
use std::future::pending;
use std::mem::offset_of;
use std::sync::Arc;
use std::task::Poll;
use task_control::AsyncRun;
use task_control::InspectTaskMut;
use task_control::StopTask;
use task_control::TaskControl;
use thiserror::Error;
use virtio::DeviceTraits;
use virtio::DeviceTraitsSharedMemory;
use virtio::Resources;
use virtio::VirtioDevice;
use virtio::VirtioQueue;
use virtio::VirtioQueueCallbackWork;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

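/// Virtio network device feature bits (`VIRTIO_NET_F_*` in the virtio
/// specification); each field's position matches its feature bit number.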
#[bitfield(u64)]
#[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
struct NetworkFeatures {
    pub csum: bool,
    pub guest_csum: bool,
    pub ctrl_guest_offloads: bool,
    pub mtu: bool,
    _reserved: bool,
    pub mac: bool,
    _reserved2: bool,
    pub guest_tso4: bool,
    pub guest_tso6: bool,
    pub guest_ecn: bool,
    pub guest_ufo: bool,
    pub host_tso4: bool,
    pub host_tso6: bool,
    pub host_ecn: bool,
    pub host_ufo: bool,
    pub mrg_rxbuf: bool,
    pub status: bool,
    pub ctrl_vq: bool,
    pub ctrl_rx: bool,
    pub ctrl_vlan: bool,
    _reserved3: bool,
    pub guest_announce: bool,
    pub mq: bool,
    pub ctrl_mac_addr: bool,
    #[bits(29)]
    _reserved4: u64,
    pub notf_coal: bool,
    pub guest_uso4: bool,
    pub guest_uso6: bool,
    pub host_uso: bool,
    pub hash_report: bool,
    _reserved5: bool,
    pub guest_hdrlen: bool,
    pub rss: bool,
    pub rsc_ext: bool,
    pub standby: bool,
    pub speed_duplex: bool,
}

#[bitfield(u16)]
#[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
struct NetStatus {
    pub link_up: bool,
    pub announce: bool,
    #[bits(14)]
    _reserved: u16,
}

const DEFAULT_MTU: u16 = 1514;

#[expect(dead_code)]
const VIRTIO_NET_MAX_QUEUES: u16 = 0x8000;

#[repr(C)]
struct NetConfig {
    pub mac: [u8; 6],
    pub status: u16,
    pub max_virtqueue_pairs: u16,
    pub mtu: u16,
    pub speed: u32,
    pub duplex: u8,
    pub rss_max_key_size: u8,
    pub rss_max_indirection_table_length: u16,
    pub supported_hash_types: u32,
}

#[bitfield(u8)]
#[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
struct VirtioNetHeaderFlags {
    pub needs_csum: bool,
    pub data_valid: bool,
    pub rsc_info: bool,
    #[bits(5)]
    _reserved: u8,
}

#[bitfield(u8)]
#[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
struct VirtioNetHeaderGso {
    #[bits(3)]
    pub protocol: VirtioNetHeaderGsoProtocol,
    #[bits(4)]
    _reserved: u8,
    pub ecn: bool,
}

open_enum::open_enum! {
    #[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
    enum VirtioNetHeaderGsoProtocol: u8 {
        NONE = 0,
        TCPV4 = 1,
        UDP = 3,
        TCPV6 = 4,
        UDP_L4 = 5,
    }
}

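// `bitfield_struct` requires const `from_bits`/`into_bits` conversions on any
// type used in a `#[bits(N)]` field, such as `VirtioNetHeaderGso::protocol`.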
impl VirtioNetHeaderGsoProtocol {
    const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    const fn into_bits(self) -> u8 {
        self.0
    }
}

#[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
#[repr(C)]
struct VirtioNetHeader {
    pub flags: u8,
    pub gso_type: u8,
    pub hdr_len: u16,
    pub gso_size: u16,
    pub csum_start: u16,
    pub csum_offset: u16,
    pub num_buffers: u16,
    pub hash_value: u32,
    pub hash_report: u16,
    pub padding_reserved: u16,
}

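// The trailing hash fields are only present on the wire when
// VIRTIO_NET_F_HASH_REPORT is negotiated, which this device does not offer,
// so the effective header ends at `num_buffers`.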
fn header_size() -> usize {
    offset_of!(VirtioNetHeader, hash_value)
}

struct Adapter {
    driver: VmTaskDriver,
    max_queues: u16,
    tx_fast_completions: bool,
    mac_address: MacAddress,
}

pub struct Device {
    registers: NetConfig,
    memory: GuestMemory,
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    coordinator_send: Option<mesh::Sender<CoordinatorMessage>>,
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}

impl Drop for Device {
    fn drop(&mut self) {}
}

impl VirtioDevice for Device {
    fn traits(&self) -> DeviceTraits {
        DeviceTraits {
            // Device ID 1 is the virtio network device.
            device_id: 1,
            device_features: NetworkFeatures::new().with_mac(true).into(),
            // Each queue pair contributes one receive and one transmit queue.
            max_queues: 2 * self.registers.max_virtqueue_pairs,
            device_register_length: size_of::<NetConfig>() as u32,
            shared_memory: DeviceTraitsSharedMemory { id: 0, size: 0 },
        }
    }

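    // The configuration space is the little-endian `NetConfig` layout, read
    // in 32-bit chunks at 4-byte offsets.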
    fn read_registers_u32(&self, offset: u16) -> u32 {
        match offset {
            0 => u32::from_le_bytes(self.registers.mac[..4].try_into().unwrap()),
            4 => {
                (u16::from_le_bytes(self.registers.mac[4..].try_into().unwrap()) as u32)
                    | ((self.registers.status as u32) << 16)
            }
            8 => (self.registers.max_virtqueue_pairs as u32) | ((self.registers.mtu as u32) << 16),
            12 => self.registers.speed,
            16 => {
                (self.registers.duplex as u32)
                    | ((self.registers.rss_max_key_size as u32) << 8)
                    | ((self.registers.rss_max_indirection_table_length as u32) << 16)
            }
            20 => self.registers.supported_hash_types,
            _ => 0,
        }
    }

    fn write_registers_u32(&mut self, _offset: u16, _val: u32) {}

    fn enable(&mut self, resources: Resources) {
        // Queues are provided in receive/transmit pairs.
        let mut queue_resources: Vec<_> = resources.queues.into_iter().collect();
        let mut workers = Vec::with_capacity(queue_resources.len() / 2);
        while queue_resources.len() > 1 {
            let mut next = queue_resources.drain(..2);
            let rx_resources = next.next().unwrap();
            let tx_resources = next.next().unwrap();
            if !rx_resources.params.enable || !tx_resources.params.enable {
                continue;
            }

            let rx_queue_size = rx_resources.params.size;
            let rx_queue_event = PolledWait::new(&self.adapter.driver, rx_resources.event);
            if let Err(err) = rx_queue_event {
                tracing::error!(
                    err = &err as &dyn std::error::Error,
                    "Failed creating queue event"
                );
                continue;
            }
            let rx_queue = VirtioQueue::new(
                resources.features,
                rx_resources.params,
                self.memory.clone(),
                rx_resources.notify,
                rx_queue_event.unwrap(),
            );
            if let Err(err) = rx_queue {
                tracing::error!(
                    err = &err as &dyn std::error::Error,
                    "Failed creating virtio net receive queue"
                );
                continue;
            }

            let tx_queue_size = tx_resources.params.size;
            let tx_queue_event = PolledWait::new(&self.adapter.driver, tx_resources.event);
            if let Err(err) = tx_queue_event {
                tracing::error!(
                    err = &err as &dyn std::error::Error,
                    "Failed creating queue event"
                );
                continue;
            }
            let tx_queue = VirtioQueue::new(
                resources.features,
                tx_resources.params,
                self.memory.clone(),
                tx_resources.notify,
                tx_queue_event.unwrap(),
            );
            if let Err(err) = tx_queue {
                tracing::error!(
                    err = &err as &dyn std::error::Error,
                    "Failed creating virtio net transmit queue"
                );
                continue;
            }
            workers.push(VirtioState {
                rx_queue: rx_queue.unwrap(),
                rx_queue_size,
                tx_queue: tx_queue.unwrap(),
                tx_queue_size,
            });
        }

        let (tx, rx) = mesh::channel();
        self.coordinator_send = Some(tx);
        self.insert_coordinator(rx, workers.len() as u16);
        for (i, virtio_state) in workers.into_iter().enumerate() {
            self.insert_worker(virtio_state, i);
        }
        self.coordinator.start();
    }

    fn disable(&mut self) {
        if let Some(send) = self.coordinator_send.take() {
            send.send(CoordinatorMessage::Disable);
        }
    }
}

struct EndpointQueueState {
    queue: Box<dyn net_backend::Queue>,
}

struct NetQueue {
    state: Option<EndpointQueueState>,
}

impl InspectTaskMut<Worker> for NetQueue {
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker>) {
        if worker.is_none() && self.state.is_none() {
            req.ignore();
            return;
        }
        let mut resp = req.respond();
        if let Some(worker) = worker {
            resp.field(
                "pending_tx_packets",
                worker
                    .active_state
                    .pending_tx_packets
                    .iter()
                    .fold(0, |acc, next| acc + if next.is_some() { 1 } else { 0 }),
            )
            .field(
                "pending_rx_packets",
                worker.active_state.pending_rx_packets.ready().len(),
            )
            .field(
                "pending_tx",
                !worker.active_state.data.tx_segments.is_empty(),
            )
            .merge(&worker.active_state.stats);
        }

        if let Some(epqueue_state) = &mut self.state {
            resp.field_mut("queue", &mut epqueue_state.queue);
        }
    }
}

struct ProcessingData {
    tx_segments: Vec<TxSegment>,
    tx_done: Box<[TxId]>,
    rx_ready: Box<[RxId]>,
}

impl ProcessingData {
    fn new(rx_queue_size: u16, tx_queue_size: u16) -> Self {
        Self {
            tx_segments: Vec::new(),
            tx_done: vec![TxId(0); tx_queue_size as usize].into(),
            rx_ready: vec![RxId(0); rx_queue_size as usize].into(),
        }
    }
}

#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}

struct ActiveState {
    pending_tx_packets: Vec<Option<PendingTxPacket>>,
    pending_rx_packets: VirtioWorkPool,
    data: ProcessingData,
    stats: QueueStats,
}

impl ActiveState {
    fn new(mem: GuestMemory, rx_queue_size: u16, tx_queue_size: u16) -> Self {
        Self {
            pending_tx_packets: (0..tx_queue_size).map(|_| None).collect(),
            pending_rx_packets: VirtioWorkPool::new(mem, rx_queue_size),
            data: ProcessingData::new(rx_queue_size, tx_queue_size),
            stats: Default::default(),
        }
    }
}

struct PendingTxPacket {
    work: VirtioQueueCallbackWork,
}

pub struct NicBuilder {
    max_queues: u16,
}

impl NicBuilder {
    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        memory: GuestMemory,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
    ) -> Device {
        // Only a single queue pair is currently supported, regardless of the
        // configured maximum.
        let max_queues = 1;

        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            max_queues,
            tx_fast_completions: endpoint.tx_fast_completions(),
            mac_address,
        });

        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
        });

        let registers = NetConfig {
            mac: mac_address.to_bytes(),
            status: NetStatus::new().with_link_up(true).into(),
            max_virtqueue_pairs: max_queues,
            mtu: DEFAULT_MTU,
            // Per the virtio specification, all-ones means the speed and
            // duplex are unknown.
            speed: 0xffffffff,
            duplex: 0xff,
            rss_max_key_size: 0,
            rss_max_indirection_table_length: 0,
            supported_hash_types: 0,
        };

        Device {
            registers,
            memory,
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}

impl Device {
    pub fn builder() -> NicBuilder {
        NicBuilder { max_queues: !0 }
    }
}
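
// A minimal usage sketch (hypothetical `driver_source`, `memory`, `endpoint`,
// and `mac` values; note that `build` currently clamps the device to a single
// queue pair regardless of `max_queues`):
//
//     let nic = Device::builder()
//         .max_queues(4)
//         .build(&driver_source, memory, endpoint, mac);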

impl InspectMut for Device {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        self.coordinator.inspect_mut(req);
    }
}

impl Device {
    fn insert_coordinator(&mut self, recv: mesh::Receiver<CoordinatorMessage>, num_queues: u16) {
        self.coordinator.insert(
            &self.adapter.driver,
            "virtio-net-coordinator".to_string(),
            Coordinator {
                recv,
                workers: (0..self.adapter.max_queues)
                    .map(|_| TaskControl::new(NetQueue { state: None }))
                    .collect(),
                num_queues,
                restart: true,
            },
        );
    }

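    // Each worker gets its own task driver so its queue processing can be
    // targeted independently; when the endpoint completes transmits quickly,
    // the task is not required to run on the target VP.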
    fn insert_worker(&mut self, virtio_state: VirtioState, idx: usize) {
        let mut builder = self.driver_source.builder();
        builder.target_vp(0);
        builder.run_on_target(!self.adapter.tx_fast_completions);
        let driver = builder.build("virtio-net");

        let active_state = ActiveState::new(
            self.memory.clone(),
            virtio_state.rx_queue_size,
            virtio_state.tx_queue_size,
        );
        let worker = Worker {
            virtio_state,
            active_state,
        };
        let coordinator = self.coordinator.state_mut().unwrap();
        let worker_task = &mut coordinator.workers[idx];
        worker_task.insert(&driver, "virtio-net".to_string(), worker);
        worker_task.start();
    }
}

#[derive(PartialEq)]
enum CoordinatorMessage {
    Disable,
}

struct Coordinator {
    recv: mesh::Receiver<CoordinatorMessage>,
    workers: Vec<TaskControl<NetQueue, Worker>>,
    num_queues: u16,
    restart: bool,
}

struct CoordinatorState {
    endpoint: Box<dyn Endpoint>,
    adapter: Arc<Adapter>,
}

impl InspectTaskMut<Coordinator> for CoordinatorState {
    fn inspect_mut(&mut self, req: inspect::Request<'_>, coordinator: Option<&mut Coordinator>) {
        let mut resp = req.respond();

        let adapter = self.adapter.as_ref();
        resp.field("mac_address", adapter.mac_address)
            .field("max_queues", adapter.max_queues);

        resp.field("endpoint_type", self.endpoint.endpoint_type())
            .field(
                "endpoint_max_queues",
                self.endpoint.multiqueue_support().max_queues,
            )
            .field_mut("endpoint", self.endpoint.as_mut());

        if let Some(coordinator) = coordinator {
            resp.fields_mut(
                "queues",
                coordinator.workers[..coordinator.num_queues as usize]
                    .iter_mut()
                    .enumerate(),
            );
        }
    }
}

impl AsyncRun<Coordinator> for CoordinatorState {
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        coordinator.process(stop, self).await
    }
}

impl Coordinator {
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        loop {
            if self.restart {
                stop.until_stopped(self.stop_workers()).await?;
                if let Err(err) = self.restart_queues(state).await {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                self.restart = false;
            }
            self.start_workers();
            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
            }
            let message = {
                let wait_for_message = async {
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    (internal_msg, endpoint_restart).race().await
                };
                stop.until_stopped(wait_for_message).await?
            };
            match message {
                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(_)) => {
                    tracing::error!("unexpected link status notification")
                }
                Message::Internal(CoordinatorMessage::Disable) | Message::ChannelDisconnected => {
                    stop.until_stopped(self.stop_workers()).await?;
                    break;
                }
            };
        }
        Ok(())
    }

    async fn stop_workers(&mut self) {
        for worker in &mut self.workers {
            worker.stop().await;
        }
    }

    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
        // Drop any queues that still reference the old endpoint state before
        // requesting new ones.
        for worker in &mut self.workers {
            worker.task_mut().state = None;
        }

        // Carry each worker's receive buffer pool (and any receive buffers
        // already posted by the guest) over to the new queues.
        let (rx_pools, ready_packets): (Vec<_>, Vec<_>) = self
            .workers
            .iter()
            .map(|worker| {
                let pool = worker
                    .state()
                    .unwrap()
                    .active_state
                    .pending_rx_packets
                    .clone();
                let ready = pool.ready();
                (pool, ready)
            })
            .collect::<Vec<_>>()
            .into_iter()
            .unzip();
        let mut queue_config = Vec::with_capacity(rx_pools.len());
        for (i, pool) in rx_pools.into_iter().enumerate() {
            queue_config.push(QueueConfig {
                pool: Box::new(pool),
                initial_rx: ready_packets[i].as_slice(),
                driver: Box::new(c_state.adapter.driver.clone()),
            });
        }

        let mut queues = Vec::new();
        c_state
            .endpoint
            .get_queues(queue_config, None, &mut queues)
            .await
            .map_err(WorkerError::Endpoint)?;

        assert_eq!(queues.len(), self.workers.len());

        for (worker, queue) in self.workers.iter_mut().zip(queues) {
            worker.task_mut().state = Some(EndpointQueueState { queue });
        }

        Ok(())
    }

    fn start_workers(&mut self) {
        for worker in &mut self.workers {
            worker.start();
        }
    }
}

impl AsyncRun<Worker> for NetQueue {
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker,
    ) -> Result<(), task_control::Cancelled> {
        match worker.process(stop, self).await {
            Ok(()) => {}
            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
            Err(err) => {
                tracing::error!(err = &err as &dyn std::error::Error, "virtio net error");
            }
        }
        Ok(())
    }
}

struct VirtioState {
    rx_queue: VirtioQueue,
    rx_queue_size: u16,
    tx_queue: VirtioQueue,
    tx_queue_size: u16,
}

#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    Packet(#[from] PacketError),
    #[error("virtio queue processing error")]
    VirtioQueue(#[source] std::io::Error),
    #[error("endpoint")]
    Endpoint(#[source] anyhow::Error),
    #[error("cancelled")]
    Cancelled(task_control::Cancelled),
}

impl From<task_control::Cancelled> for WorkerError {
    fn from(value: task_control::Cancelled) -> Self {
        Self::Cancelled(value)
    }
}

#[derive(Debug, Error)]
enum PacketError {
    #[error("Empty packet")]
    Empty,
}

struct Worker {
    virtio_state: VirtioState,
    active_state: ActiveState,
}

impl Worker {
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        queue: &mut NetQueue,
    ) -> Result<(), WorkerError> {
        // Without a backend queue there is no work to do until the task is
        // restarted; park until cancelled.
        if queue.state.is_none() {
            stop.until_stopped(pending()).await?
        }

        self.main_loop(stop, queue).await?;
        Ok(())
    }

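    // Each pass drains endpoint receive completions, newly posted guest
    // receive buffers, and endpoint transmit completions, then races the
    // backend, transmit, and receive wake sources until one is ready.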
    async fn main_loop(
        &mut self,
        stop: &mut StopTask<'_>,
        queue: &mut NetQueue,
    ) -> Result<(), WorkerError> {
        let epqueue_state = queue.state.as_mut().unwrap();

        loop {
            let did_some_work = self.process_endpoint_rx(epqueue_state.queue.as_mut())?
                | self.process_virtio_rx(epqueue_state.queue.as_mut())?
                | self.process_endpoint_tx(epqueue_state.queue.as_mut())?;

            if !did_some_work {
                self.active_state.stats.spurious_wakes.increment();
            }

            stop.until_stopped(async {
                enum WakeReason {
                    PacketFromClient(Result<VirtioQueueCallbackWork, std::io::Error>),
                    PacketToClient(Result<VirtioQueueCallbackWork, std::io::Error>),
                    NetworkBackend,
                }
                loop {
                    let net_queue = std::future::poll_fn(|cx| -> Poll<()> {
                        epqueue_state.queue.poll_ready(cx)
                    })
                    .map(|_| WakeReason::NetworkBackend);
                    let to_client = self.virtio_state.rx_queue.next().map(|work| {
                        WakeReason::PacketToClient(work.expect("queue never completes"))
                    });
                    // Only accept new transmits once the previously queued
                    // segments have been handed to the endpoint.
                    let wake_reason = if self.active_state.data.tx_segments.is_empty() {
                        let from_client = self.virtio_state.tx_queue.next().map(|work| {
                            WakeReason::PacketFromClient(work.expect("queue never completes"))
                        });
                        (net_queue, from_client, to_client).race().await
                    } else {
                        (net_queue, to_client).race().await
                    };
                    match wake_reason {
                        WakeReason::NetworkBackend => {
                            tracing::trace!("endpoint ready");
                            return Ok::<(), WorkerError>(());
                        }
                        WakeReason::PacketFromClient(work) => {
                            tracing::trace!("tx packet");
                            let work = work.map_err(WorkerError::VirtioQueue)?;
                            self.queue_tx_packet(work)?;
                            self.process_virtio_rx(epqueue_state.queue.as_mut())?;
                            if !self.transmit_pending_segments(epqueue_state)? {
                                self.active_state.stats.tx_stalled.increment();
                            }
                        }
                        WakeReason::PacketToClient(work) => {
                            tracing::trace!("rx packet");
                            let work = work.map_err(WorkerError::VirtioQueue)?;
                            epqueue_state
                                .queue
                                .rx_avail(&[self.active_state.pending_rx_packets.queue_work(work)]);
                        }
                    }
                }
            })
            .await??;
        }
    }

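    // The first `header_size()` bytes of a guest transmit payload hold the
    // virtio-net header rather than frame data, so they are skipped when
    // building the endpoint's transmit segments.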
    fn queue_tx_packet(&mut self, mut work: VirtioQueueCallbackWork) -> Result<(), WorkerError> {
        let mut header_bytes_remaining = header_size() as u32;
        let mut segments = work
            .payload
            .iter()
            .filter_map(|p| {
                if p.writeable {
                    None
                } else if header_bytes_remaining >= p.length {
                    header_bytes_remaining -= p.length;
                    None
                } else if header_bytes_remaining > 0 {
                    let segment = TxSegment {
                        ty: TxSegmentType::Tail,
                        gpa: p.address + header_bytes_remaining as u64,
                        len: p.length - header_bytes_remaining,
                    };
                    header_bytes_remaining = 0;
                    Some(segment)
                } else {
                    Some(TxSegment {
                        ty: TxSegmentType::Tail,
                        gpa: p.address,
                        len: p.length,
                    })
                }
            })
            .collect::<Vec<_>>();
        if segments.is_empty() {
            work.complete(0);
            return Err(WorkerError::Packet(PacketError::Empty));
        }
        let idx = work.descriptor_index();
        segments[0].ty = TxSegmentType::Head(TxMetadata {
            id: TxId(idx.into()),
            segment_count: segments.len(),
            len: work.get_payload_length(false) as usize - header_size(),
            ..Default::default()
        });
        let state = &mut self.active_state;
        state.data.tx_segments.append(&mut segments);
        assert!(state.pending_tx_packets[idx as usize].is_none());
        state.pending_tx_packets[idx as usize] = Some(PendingTxPacket { work });
        Ok(())
    }

    fn process_virtio_rx(
        &mut self,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        // Make any receive buffers the guest has posted available to the
        // endpoint.
        let mut rx_ids = Vec::new();
        while let Some(Some(work)) = self.virtio_state.rx_queue.next().now_or_never() {
            tracing::trace!("rx packet");
            let work = work.map_err(WorkerError::VirtioQueue)?;
            rx_ids.push(self.active_state.pending_rx_packets.queue_work(work));
        }
        if !rx_ids.is_empty() {
            epqueue.rx_avail(rx_ids.as_slice());
            Ok(true)
        } else {
            Ok(false)
        }
    }

    fn process_endpoint_rx(
        &mut self,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let state = &mut self.active_state;
        let n = epqueue
            .rx_poll(&mut state.data.rx_ready)
            .map_err(WorkerError::Endpoint)?;
        if n == 0 {
            return Ok(false);
        }

        for ready_id in state.data.rx_ready[..n].iter() {
            state.stats.rx_packets.increment();
            state.pending_rx_packets.complete_packet(*ready_id);
        }

        state.stats.rx_packets_per_wake.add_sample(n as u64);
        Ok(true)
    }

    fn process_endpoint_tx(
        &mut self,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let n = epqueue
            .tx_poll(&mut self.active_state.data.tx_done)
            .map_err(|tx_error| WorkerError::Endpoint(tx_error.into()))?;
        if n == 0 {
            return Ok(false);
        }

        // If a partially transmitted packet completes, its queued segments
        // can be discarded.
        let pending_segment_id = if !self.active_state.data.tx_segments.is_empty() {
            let TxSegmentType::Head(metadata) = &self.active_state.data.tx_segments[0].ty else {
                unreachable!()
            };
            Some(metadata.id)
        } else {
            None
        };
        for i in 0..n {
            let id = self.active_state.data.tx_done[i];
            self.complete_tx_packet(id)?;
            if let Some(pending_segment_id) = pending_segment_id {
                if pending_segment_id.0 == id.0 {
                    self.active_state.data.tx_segments.clear();
                }
            }
        }
        self.active_state
            .stats
            .tx_packets_per_wake
            .add_sample(n as u64);

        Ok(true)
    }

    fn transmit_pending_segments(
        &mut self,
        queue_state: &mut EndpointQueueState,
    ) -> Result<bool, WorkerError> {
        if self.active_state.data.tx_segments.is_empty() {
            return Ok(false);
        }
        let TxSegmentType::Head(metadata) = &self.active_state.data.tx_segments[0].ty else {
            unreachable!()
        };
        let id = metadata.id;
        self.transmit_segments(queue_state, id)?;
        Ok(true)
    }

    fn transmit_segments(
        &mut self,
        queue_state: &mut EndpointQueueState,
        id: TxId,
    ) -> Result<(), WorkerError> {
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(&self.active_state.data.tx_segments)
            .map_err(WorkerError::Endpoint)?;

        assert!(segments_sent <= self.active_state.data.tx_segments.len());

        // If the endpoint completed the transmit synchronously, the packet
        // can be completed back to the guest immediately.
        if sync && segments_sent == self.active_state.data.tx_segments.len() {
            self.active_state.data.tx_segments.clear();
            self.complete_tx_packet(id)?;
        }
        Ok(())
    }

    fn complete_tx_packet(&mut self, id: TxId) -> Result<(), WorkerError> {
        let state = &mut self.active_state;
        let mut tx_packet = state.pending_tx_packets[id.0 as usize].take().unwrap();
        tx_packet.work.complete(0);
        self.active_state.stats.tx_packets.increment();
        Ok(())
    }
}
1001}